diff --git a/docs/source/models/dlrm_hstu.md b/docs/source/models/dlrm_hstu.md index 40f0d3d6..afcd2e92 100644 --- a/docs/source/models/dlrm_hstu.md +++ b/docs/source/models/dlrm_hstu.md @@ -165,7 +165,11 @@ model_config { - metrics: 任务指标 - max_seq_len: 最大序列长度 -- kernel: 算子实现,可选TRITON/PYTORCH,TRITON通常比PYTORCH快2-3x,节省2-3x显存 +- kernel: 算子实现,可选TRITON/PYTORCH/CUTLASS + + - TRITON: 基于Triton的实现,通常比PYTORCH快2-3x,节省2-3x显存 + - CUTLASS: 基于CUTLASS的CUDA融合算子实现,需安装hstu_attn包(DEVICE可选cu126/cu129:`pip install hstu_attn-0.1.0+bea6b4b.${DEVICE} -f https://tzrec.oss-accelerate.aliyuncs.com/third_party/hstu/${DEVICE}/repo.html`),要求`attention_dim`等于`hidden_dim`,支持Ampere/Ada/Hopper GPU + - PYTORCH: 纯PyTorch实现,兼容性最好 ### MTGR Style 配置方式 diff --git a/requirements/extra.txt b/requirements/extra.txt index 76086107..6075f7bd 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -1,4 +1,7 @@ dynamicemb @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/dynamicemb/cu129/dynamicemb-0.0.1%2B20260407.97b80bf.cu129-cp310-cp310-linux_x86_64.whl ; python_version=="3.10" dynamicemb @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/dynamicemb/cu129/dynamicemb-0.0.1%2B20260407.97b80bf.cu129-cp311-cp311-linux_x86_64.whl ; python_version=="3.11" dynamicemb @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/dynamicemb/cu129/dynamicemb-0.0.1%2B20260407.97b80bf.cu129-cp312-cp312-linux_x86_64.whl ; python_version=="3.12" +hstu_attn @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/hstu/cu129/hstu_attn-0.1.0%2Bbea6b4b.cu12.9-cp310-cp310-linux_x86_64.whl ; python_version=="3.10" +hstu_attn @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/hstu/cu129/hstu_attn-0.1.0%2Bbea6b4b.cu12.9-cp311-cp311-linux_x86_64.whl ; python_version=="3.11" +hstu_attn @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/hstu/cu129/hstu_attn-0.1.0%2Bbea6b4b.cu12.9-cp312-cp312-linux_x86_64.whl ; python_version=="3.12" torch_fx_tool @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/rtp/torch_fx_tool-0.0.1%2B20251201.8c109c4-py3-none-any.whl diff --git a/tzrec/acc/aot_utils.py b/tzrec/acc/aot_utils.py index 45290fee..32ec50f9 100644 --- a/tzrec/acc/aot_utils.py +++ b/tzrec/acc/aot_utils.py @@ -11,15 +11,26 @@ import os -from typing import Any, Dict +from typing import Any, Dict, Optional import torch from torch import nn -from tzrec.models.model import CombinedModelWrapper +from tzrec.models.model import CombinedModelWrapper, DenseAutocastWrapper from tzrec.utils.fx_util import symbolic_trace from tzrec.utils.logging_util import logger +# Eagerly register custom ops referenced by AOT-packaged models so that +# torch._inductor.aoti_load_package() can resolve them by name. AOT packages +# reference ops via their qualified name (e.g. ``tzrec::cutlass_hstu_mha_fwd``) +# and PyTorch only knows about an op once its registering module has been +# imported. Wrap in try/except so this stays optional for environments +# without the corresponding native dependencies installed. +try: + from tzrec.ops._cuda import cutlass_hstu_attention # noqa: F401 +except ImportError: + pass + def load_model_aot(model_path: str, device: torch.device) -> CombinedModelWrapper: """Load AOTInductor model. @@ -49,6 +60,7 @@ def export_model_aot( data: Dict[str, torch.Tensor], meta_info: Dict[str, Any], save_dir: str, + mixed_precision: Optional[str] = None, ) -> str: """Export AOTInductor model. @@ -58,6 +70,12 @@ def export_model_aot( data (Dict[str, torch.Tensor]): the test data meta_info (Dict[str, Any]): split meta info save_dir (str): model save dir + mixed_precision (Optional[str]): "BF16", "FP16", or None. When set, + the dense sub-graph is wrapped in a DenseAutocastWrapper so that + torch.export captures the autocast region as a wrap_with_autocast + Higher Order Op. The sparse sub-graph is left untouched because + it is only embedding lookups, which don't benefit from AMP and + which would complicate torch.jit.script compilation. """ sparse_output, _ = sparse_model(data, "cuda:0") sparse_model_traced = symbolic_trace(sparse_model) @@ -86,12 +104,20 @@ def export_model_aot( logger.info("dynamic shapes=%s" % dynamic_shapes) + # Wrap the dense module so torch.export captures the autocast region + # as a `wrap_with_autocast` HOP that AOT Inductor lowers correctly. + dense_to_export: nn.Module = dense_model + if mixed_precision: + dense_to_export = DenseAutocastWrapper(dense_model, mixed_precision) + # pre_hook requires running arbitrary code at runtime with torch._inductor.config.patch( {"unsafe_ignore_unsupported_triton_autotune_args": True} ): exported_pg = torch.export.export( - dense_model, args=(sparse_output,), dynamic_shapes=(dynamic_shapes,) + dense_to_export, + args=(sparse_output,), + dynamic_shapes=(dynamic_shapes,), ) # AsserScalar codegen is not correct. with torch._inductor.config.patch( diff --git a/tzrec/acc/trt_utils.py b/tzrec/acc/trt_utils.py index 4ef7feb4..6aa1b3ee 100644 --- a/tzrec/acc/trt_utils.py +++ b/tzrec/acc/trt_utils.py @@ -18,7 +18,7 @@ from torch.profiler import ProfilerActivity, profile, record_function from tzrec.acc.utils import get_max_export_batch_size, is_debug_trt -from tzrec.models.model import CombinedModelWrapper +from tzrec.models.model import CombinedModelWrapper, DenseAutocastWrapper from tzrec.utils.fx_util import symbolic_trace from tzrec.utils.logging_util import logger @@ -104,6 +104,7 @@ def export_model_trt( dense_model: nn.Module, data: Dict[str, torch.Tensor], save_dir: str, + mixed_precision: Optional[str] = None, ) -> None: """Export trt model. @@ -112,6 +113,10 @@ def export_model_trt( dense_model (nn.Module): the dense part data (Dict[str, torch.Tensor]): the test data save_dir (str): model save dir + mixed_precision (Optional[str]): "BF16", "FP16", or None. When set, + only the dense sub-graph is wrapped in a DenseAutocastWrapper + before torch.export. The sparse sub-graph is only embedding + lookups and doesn't benefit from AMP. """ emb_ebc, _ = sparse_model(data, "cuda:0") sparse_model_traced = symbolic_trace(sparse_model) @@ -153,8 +158,11 @@ def export_model_trt( for i, k in enumerate(key_list): dynamic_shapes[dense_arg_name].update({k: dynamic_shapes_list[i]}) + dense_to_export: nn.Module = dense_layer + if mixed_precision: + dense_to_export = DenseAutocastWrapper(dense_layer, mixed_precision) exp_program = torch.export.export( - dense_layer, (emb_ebc,), dynamic_shapes=dynamic_shapes + dense_to_export, (emb_ebc,), dynamic_shapes=dynamic_shapes ) dense_layer_trt = trt_convert(exp_program, (emb_ebc,)) # logger.info("dense trt res: %s", dense_layer_trt(emb_ebc)) diff --git a/tzrec/acc/utils.py b/tzrec/acc/utils.py index 9480ef03..bd783be9 100644 --- a/tzrec/acc/utils.py +++ b/tzrec/acc/utils.py @@ -12,10 +12,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import json import os -from typing import Dict +from typing import Dict, Optional import torch +from tzrec.protos.pipeline_pb2 import EasyRecConfig from tzrec.protos.train_pb2 import TrainConfig @@ -165,6 +166,41 @@ def ec_quant_dtype() -> torch.dtype: return _quant_str_to_dtype[quant_dtype_str] +_MIXED_PRECISION_TO_DTYPE: Dict[str, torch.dtype] = { + "BF16": torch.bfloat16, + "FP16": torch.float16, +} + + +def mixed_precision_to_dtype(mixed_precision: Optional[str]) -> Optional[torch.dtype]: + """Convert a TrainConfig.mixed_precision string to a torch dtype. + + Returns ``None`` when ``mixed_precision`` is ``None`` or empty. + Raises ``ValueError`` on unknown values so typos fail loudly. + """ + if not mixed_precision: + return None + if mixed_precision not in _MIXED_PRECISION_TO_DTYPE: + raise ValueError( + f"Unknown mixed_precision: {mixed_precision}, " + f"available types: {list(_MIXED_PRECISION_TO_DTYPE.keys())}" + ) + return _MIXED_PRECISION_TO_DTYPE[mixed_precision] + + +def resolve_mixed_precision(pipeline_config: EasyRecConfig) -> str: + """Resolve the mixed_precision mode for export/inference. + + Precedence: ``export_config.mixed_precision`` (when set) overrides + ``train_config.mixed_precision``. Empty string means no AMP. + """ + if pipeline_config.HasField("export_config"): + export_mp = pipeline_config.export_config.mixed_precision + if export_mp: + return export_mp + return pipeline_config.train_config.mixed_precision + + def write_mapping_file_for_input_tile( state_dict: Dict[str, torch.Tensor], remap_file_path: str ) -> None: diff --git a/tzrec/models/model.py b/tzrec/models/model.py index cb8209ce..f9ce7eb1 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -12,7 +12,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from collections import OrderedDict from queue import Queue -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Final, Iterable, List, Optional, Tuple import torch import torchmetrics @@ -22,6 +22,7 @@ EmbeddingCollectionInterface, ) +from tzrec.acc import utils as acc_utils from tzrec.constant import TARGET_REPEAT_INTERLEAVE_KEY from tzrec.datasets.data_parser import DataParser from tzrec.datasets.utils import Batch @@ -236,16 +237,7 @@ def __init__( self._device_type = "cpu" if device is not None: self._device_type = device.type - if mixed_precision is None or len(mixed_precision) == 0: - self._mixed_dtype = None - elif mixed_precision == "FP16": - self._mixed_dtype = torch.float16 - elif mixed_precision == "BF16": - self._mixed_dtype = torch.bfloat16 - else: - raise ValueError( - f"mixed_precision should be FP16 or BF16, but got [{mixed_precision}]" - ) + self._mixed_dtype = acc_utils.mixed_precision_to_dtype(mixed_precision) self.pareto = None if ( hasattr(self.model, "_use_pareto_loss_weight") @@ -300,16 +292,7 @@ def __init__( self._device_type = "cpu" if device is not None: self._device_type = device.type - if mixed_precision is None or len(mixed_precision) == 0: - self._mixed_dtype = None - elif mixed_precision == "FP16": - self._mixed_dtype = torch.float16 - elif mixed_precision == "BF16": - self._mixed_dtype = torch.bfloat16 - else: - raise ValueError( - f"mixed_precision should be FP16 or BF16, but got [{mixed_precision}]" - ) + self._mixed_dtype = acc_utils.mixed_precision_to_dtype(mixed_precision) self._output_cols = output_cols def forward( @@ -389,6 +372,57 @@ def forward( return self.model.predict(batch) +class DenseAutocastWrapper(nn.Module): + """Wraps a dense-side module in a torch.autocast context for torch.export. + + Used to wrap the dense GraphModule from split_model before passing to + torch.export.export. torch.export captures the `with torch.autocast(...)` + region as a `wrap_with_autocast` Higher Order Op in the exported graph, + which AOT Inductor lowers to proper dtype casts. + + Only the dense path needs this wrapper: the sparse path is embedding + lookups which don't have autocast rules and don't benefit from AMP, and + the CUTLASS attention op lives in the dense sub-graph after the + sparse/dense split. + + The forward takes a single dict argument matching the dense_gm signature. + + Note on the ``_mixed_dtype_id: Final[int]`` encoding: ``torch.jit.script`` + only honors ``torch.autocast(dtype=...)`` when ``dtype`` is a compile-time + literal. We therefore store an integer flag (``Final[int]``) set in + ``__init__`` and branch on it in ``forward`` so each branch's + ``torch.autocast(dtype=torch.bfloat16|torch.float16)`` is a literal. + Storing the dtype directly as an attribute would not be script-friendly. + + Args: + inner (nn.Module): inner dense module to wrap. + mixed_precision (Optional[str]): one of "BF16", "FP16", or None. + """ + + _mixed_dtype_id: Final[int] + + def __init__(self, inner: nn.Module, mixed_precision: Optional[str] = None) -> None: + super().__init__() + self.inner = inner + if mixed_precision == "BF16": + self._mixed_dtype_id = 1 + elif mixed_precision == "FP16": + self._mixed_dtype_id = 2 + else: + self._mixed_dtype_id = 0 + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Forward through inner module under an autocast context.""" + if self._mixed_dtype_id == 1: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + return self.inner(x) + elif self._mixed_dtype_id == 2: + with torch.autocast(device_type="cuda", dtype=torch.float16): + return self.inner(x) + else: + return self.inner(x) + + class CombinedModelWrapper(nn.Module): """Model inference wrapper for model combined with sparse and dense part. diff --git a/tzrec/ops/__init__.py b/tzrec/ops/__init__.py index 5bf9df9a..fe62132b 100644 --- a/tzrec/ops/__init__.py +++ b/tzrec/ops/__init__.py @@ -18,4 +18,4 @@ class Kernel(Enum): TRITON = "TRITON" PYTORCH = "PYTORCH" - CUDA = "CUDA" + CUTLASS = "CUTLASS" diff --git a/tzrec/ops/_cuda/__init__.py b/tzrec/ops/_cuda/__init__.py new file mode 100644 index 00000000..47d5389a --- /dev/null +++ b/tzrec/ops/_cuda/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2025, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tzrec/ops/_cuda/cutlass_hstu_attention.py b/tzrec/ops/_cuda/cutlass_hstu_attention.py new file mode 100644 index 00000000..d3eb4a8d --- /dev/null +++ b/tzrec/ops/_cuda/cutlass_hstu_attention.py @@ -0,0 +1,379 @@ +# Copyright (c) 2025, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# We use the low-level torch.library.define / torch.library.impl / +# register_fake API here instead of @torch.library.custom_op on purpose: +# custom_op wraps the user function in @torch.compiler.disable AND adds an +# autograd_impl / forward_no_grad dispatch layer (even without +# register_autograd). The combination of those wrappers with AOT-Inductor's +# compiled-model dispatch deadlocks when the predict pipeline calls the AOT +# model from multiple worker threads. Confirmed via py-spy dump showing two +# `_forward_loop` threads stuck inside the AOTI `__call__`, one of them +# blocked inside `_cutlass_hstu_mha_fwd` and the other waiting at the entry +# of the AOTI `__call__`. Using the low-level API gives us a single-layer +# dispatch (just our impl) and the deadlock disappears. + +from typing import List, Optional + +import torch + +_LIB = torch.library.Library("tzrec", "FRAGMENT") + +_FWD_SCHEMA = ( + "cutlass_hstu_mha_fwd(" + "Tensor q, Tensor k, Tensor v, Tensor cu_seqlens, " + "SymInt max_seq_len, SymInt scaling_seqlen, " + "Tensor? num_contexts, Tensor? num_targets, " + "SymInt target_group_size, SymInt window_size_left, " + "SymInt window_size_right, float alpha" + ") -> Tensor" +) +_BWD_SCHEMA = ( + "cutlass_hstu_mha_bwd(" + "Tensor dout, Tensor q, Tensor k, Tensor v, Tensor cu_seqlens, " + "Tensor? num_contexts, Tensor? num_targets, " + "SymInt max_seq_len, SymInt scaling_seqlen, " + "SymInt target_group_size, SymInt window_size_left, " + "SymInt window_size_right, float alpha" + ") -> Tensor[]" +) + +_LIB.define(_FWD_SCHEMA) +_LIB.define(_BWD_SCHEMA) + + +def _cutlass_hstu_mha_fwd_impl( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seq_len: int, + scaling_seqlen: int, + num_contexts: Optional[torch.Tensor], + num_targets: Optional[torch.Tensor], + target_group_size: int, + window_size_left: int, + window_size_right: int, + alpha: float, +) -> torch.Tensor: + import hstu_attn_2_cuda as hstu_attn_cuda + + num_heads = q.size(1) + head_dim = q.size(2) + out, _ = hstu_attn_cuda.varlen_fwd( + q, + k, + v, + cu_seqlens, + cu_seqlens, + max_seq_len, + max_seq_len, + scaling_seqlen, + num_contexts, + num_targets, + target_group_size, + window_size_left, + window_size_right, + alpha, + None, # rab + None, # func + None, # kv_cache + None, # page_offsets + None, # page_ids + None, # last_page_lens + None, # cu_seqlens_t + ) + return out[:, :, :head_dim].reshape(-1, num_heads, head_dim) + + +def _cutlass_hstu_mha_bwd_impl( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, + num_contexts: Optional[torch.Tensor], + num_targets: Optional[torch.Tensor], + max_seq_len: int, + scaling_seqlen: int, + target_group_size: int, + window_size_left: int, + window_size_right: int, + alpha: float, +) -> List[torch.Tensor]: + import hstu_attn_2_cuda as hstu_attn_cuda + + num_heads = q.size(1) + head_dim = q.size(2) + dq, dk, dv, _ = hstu_attn_cuda.varlen_bwd( + dout.view(-1, num_heads, head_dim), + q, + k, + v, + None, + None, + None, + cu_seqlens, + cu_seqlens, + max_seq_len, + max_seq_len, + scaling_seqlen, + num_contexts, + num_targets, + target_group_size, + window_size_left, + window_size_right, + alpha, + None, # rab_padded + False, # has_drab + None, # func + False, # deterministic + ) + return [dq, dk, dv] + + +def _cutlass_hstu_mha_fwd_meta( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seq_len: int, + scaling_seqlen: int, + num_contexts: Optional[torch.Tensor], + num_targets: Optional[torch.Tensor], + target_group_size: int, + window_size_left: int, + window_size_right: int, + alpha: float, +) -> torch.Tensor: + # Output shape is (total, nheads, hidden_dim); that matches v, not q. + # Under the current attention_dim == hidden_dim constraint q and v are + # the same shape, but keying the fake off v makes the meta robust if + # that constraint is relaxed later. + return torch.empty_like(v) + + +def _cutlass_hstu_mha_bwd_meta( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, + num_contexts: Optional[torch.Tensor], + num_targets: Optional[torch.Tensor], + max_seq_len: int, + scaling_seqlen: int, + target_group_size: int, + window_size_left: int, + window_size_right: int, + alpha: float, +) -> List[torch.Tensor]: + return [torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)] + + +_LIB.impl("cutlass_hstu_mha_fwd", _cutlass_hstu_mha_fwd_impl, "CUDA") +_LIB.impl("cutlass_hstu_mha_bwd", _cutlass_hstu_mha_bwd_impl, "CUDA") +torch.library.register_fake("tzrec::cutlass_hstu_mha_fwd")(_cutlass_hstu_mha_fwd_meta) +torch.library.register_fake("tzrec::cutlass_hstu_mha_bwd")(_cutlass_hstu_mha_bwd_meta) + + +class _CutlassHstuMhaFunction(torch.autograd.Function): + """Python autograd.Function wrapping the cutlass low-level torch ops. + + Backward is implemented at the Python autograd level (not via + ``register_autograd`` on the op) so that the inference dispatch goes + straight to the registered impl, avoiding the autograd_impl / + forward_no_grad wrapper layer that triggers the multi-threaded AOTI + deadlock under predict workloads. + """ + + @staticmethod + def forward( # type: ignore[override] + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seq_len: int, + scaling_seqlen: int, + num_contexts: Optional[torch.Tensor], + num_targets: Optional[torch.Tensor], + target_group_size: int, + window_size_left: int, + window_size_right: int, + alpha: float, + ) -> torch.Tensor: + out = torch.ops.tzrec.cutlass_hstu_mha_fwd( + q, + k, + v, + cu_seqlens, + max_seq_len, + scaling_seqlen, + num_contexts, + num_targets, + target_group_size, + window_size_left, + window_size_right, + alpha, + ) + ctx.save_for_backward(q, k, v, cu_seqlens, num_contexts, num_targets) + ctx.max_seq_len = max_seq_len + ctx.scaling_seqlen = scaling_seqlen + ctx.target_group_size = target_group_size + ctx.window_size_left = window_size_left + ctx.window_size_right = window_size_right + ctx.alpha = alpha + return out + + @staticmethod + def backward(ctx, grad_output): # type: ignore[override] + q, k, v, cu_seqlens, num_contexts, num_targets = ctx.saved_tensors + dq, dk, dv = torch.ops.tzrec.cutlass_hstu_mha_bwd( + grad_output.contiguous(), + q, + k, + v, + cu_seqlens, + num_contexts, + num_targets, + ctx.max_seq_len, + ctx.scaling_seqlen, + ctx.target_group_size, + ctx.window_size_left, + ctx.window_size_right, + ctx.alpha, + ) + return ( + dq, + dk, + dv, + None, # cu_seqlens + None, # max_seq_len + None, # scaling_seqlen + None, # num_contexts + None, # num_targets + None, # target_group_size + None, # window_size_left + None, # window_size_right + None, # alpha + ) + + +@torch.fx.wrap +def cutlass_hstu_mha( + max_seq_len: int, + alpha: float, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + causal: bool = True, + num_targets: Optional[torch.Tensor] = None, + max_attn_len: int = 0, + contextual_seq_len: int = 0, +) -> torch.Tensor: + """CUTLASS-based HSTU multi-head attention. + + The CUTLASS kernel uses int32 cu_seqlens internally, so the cumulative + sum ``seq_offsets[-1]`` (total token count in the batch) must fit in + int32 (< 2**31 ≈ 2.1B tokens). + + Args: + max_seq_len: maximum sequence length in the batch. + alpha: scaling factor for attention scores. + q: query tensor of shape (total, nheads, attn_dim). + k: key tensor of shape (total, nheads, attn_dim). + v: value tensor of shape (total, nheads, hidden_dim). + seq_offsets: cumulative sequence offsets of shape (batch_size + 1,). + causal: whether to apply causal masking. + num_targets: number of target tokens per batch element. + max_attn_len: maximum attention window length (0 means unlimited). + contextual_seq_len: number of contextual tokens per sequence. + + Returns: + output tensor of shape (total, nheads, hidden_dim). + """ + if q.shape[2] != v.shape[2]: + raise ValueError( + f"CUTLASS hstu_attn requires attention_dim == hidden_dim, " + f"got q.shape[2]={q.shape[2]} != v.shape[2]={v.shape[2]}" + ) + if q.dtype not in (torch.float16, torch.bfloat16): + raise ValueError( + f"CUTLASS hstu_attn only supports fp16 and bf16, got {q.dtype}. " + f"Set train_config.mixed_precision to 'BF16' or 'FP16'." + ) + + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + # ``.to(torch.int32)`` already allocates a fresh contiguous tensor when + # the source isn't contiguous, so an explicit ``.contiguous()`` on + # ``seq_offsets`` is redundant. Total token count must fit in int32 + # (~2.1B), which is well beyond any realistic batch. + cu_seqlens = seq_offsets.to(torch.int32) + + if causal: + if max_attn_len > 0: + window_size_left, window_size_right = max_attn_len, 0 + else: + window_size_left, window_size_right = -1, 0 + else: + window_size_left, window_size_right = -1, -1 + + num_contexts_tensor: Optional[torch.Tensor] = None + if contextual_seq_len > 0: + batch_size = seq_offsets.size(0) - 1 + num_contexts_tensor = torch.full( + (batch_size,), + contextual_seq_len, + dtype=torch.int32, + device=q.device, + ) + + num_targets_int32: Optional[torch.Tensor] = None + if num_targets is not None: + num_targets_int32 = num_targets.to(torch.int32) + + # In autograd-enabled context (training), go through the + # _CutlassHstuMhaFunction so backward is wired up. Under no_grad / + # inference / FX-traced graphs, we still call the underlying op + # directly, which dispatches straight to the CUDA impl. + if torch.is_grad_enabled() and any(t.requires_grad for t in (q, k, v)): + return _CutlassHstuMhaFunction.apply( + q, + k, + v, + cu_seqlens, + max_seq_len, + max_seq_len, # scaling_seqlen + num_contexts_tensor, + num_targets_int32, + 1, # target_group_size + window_size_left, + window_size_right, + alpha, + ) + return torch.ops.tzrec.cutlass_hstu_mha_fwd( + q, + k, + v, + cu_seqlens, + max_seq_len, + max_seq_len, # scaling_seqlen + num_contexts_tensor, + num_targets_int32, + 1, # target_group_size + window_size_left, + window_size_right, + alpha, + ) diff --git a/tzrec/ops/benchmarks/hstu_attention_bench.py b/tzrec/ops/benchmarks/hstu_attention_bench.py index 7e22d37c..642299be 100644 --- a/tzrec/ops/benchmarks/hstu_attention_bench.py +++ b/tzrec/ops/benchmarks/hstu_attention_bench.py @@ -44,6 +44,8 @@ def _get_kernel(provider: str) -> Kernel: return Kernel.TRITON elif provider == "pytorch": return Kernel.PYTORCH + elif provider == "cutlass": + return Kernel.CUTLASS else: raise ValueError(f"Unknown provider {provider}") @@ -93,6 +95,7 @@ def _flops( @click.option("--bench-backward", type=bool, default=True) @click.option("--bench-forward", type=bool, default=True) @click.option("--bench-pytorch", type=bool, default=False) +@click.option("--bench-cutlass", type=bool, default=False) @click.option("--report-flops", type=bool, default=False) @click.option("--return-result", type=bool, default=False) @click.option("--max-attn-len", type=int, default=0) @@ -112,6 +115,7 @@ def main( # noqa: C901 bench_backward: bool, bench_forward: bool, bench_pytorch: bool, + bench_cutlass: bool, report_flops: bool, return_result: bool, max_attn_len: int, @@ -132,6 +136,10 @@ def main( # noqa: C901 line_vals = ["triton"] line_names = ["Triton"] styles = [("red", "-")] + if bench_cutlass: + line_vals.append("cutlass") + line_names.append("CUTLASS") + styles.append(("blue", "-")) if bench_pytorch: line_vals.append("pytorch") line_names.append("PyTorch") @@ -252,7 +260,7 @@ def _bench_hstu_attention( q = q.requires_grad_(True) k = k.requires_grad_(True) v = v.requires_grad_(True) - assert provider in ["triton", "pytorch"] + assert provider in ["triton", "pytorch", "cutlass"] if has_delta_q: fn = lambda: delta_hstu_mha( # noqa E731 max_seq_len=seq_len, diff --git a/tzrec/ops/hstu_attention.py b/tzrec/ops/hstu_attention.py index 8317ade9..c8fc297c 100644 --- a/tzrec/ops/hstu_attention.py +++ b/tzrec/ops/hstu_attention.py @@ -24,6 +24,9 @@ pytorch_hstu_mha, ) from tzrec.ops.utils import switch_to_contiguous_if_needed +from tzrec.utils.logging_util import logger + +_cutlass_local_window_fallback_warned = False def hstu_mha( @@ -54,8 +57,48 @@ def hstu_mha( torch._assert(v.shape[1] == H, "wrong v shape[1]") torch._assert(causal, "only support causal attention") - if kernel in [Kernel.TRITON]: - if not is_fx_tracing() and kernel == Kernel.TRITON: + if kernel == Kernel.CUTLASS: + # CUTLASS kernel does not support combining local window attention + # (max_attn_len > 0) with context/target masking, fall back to Triton. + _has_local_window = max_attn_len > 0 + _has_ctx_or_tgt = contextual_seq_len > 0 or num_targets is not None + if _has_local_window and _has_ctx_or_tgt: + global _cutlass_local_window_fallback_warned + if not _cutlass_local_window_fallback_warned: + logger.warning( + "CUTLASS kernel does not support combining local " + "window attention (max_attn_len > 0) with " + "context/target masking, falling back to Triton." + ) + _cutlass_local_window_fallback_warned = True + kernel = Kernel.TRITON + + if kernel == Kernel.CUTLASS: + # cutlass_hstu_mha is @torch.fx.wrap'd; FX treats it as a leaf so + # we call it directly without going through the contiguous/assertion + # preprocessing block below (which has control flow that would + # break FX symbolic tracing). The CUTLASS kernel requires fp16/bf16 + # inputs; we rely on the DenseAutocastWrapper applied in + # tzrec/acc/aot_utils.py and trt_utils.py (driven by + # export_config.mixed_precision / train_config.mixed_precision) to + # ensure q/k/v are bf16/fp16 when reaching this dispatch. + from tzrec.ops._cuda.cutlass_hstu_attention import cutlass_hstu_mha + + return cutlass_hstu_mha( + max_seq_len=max_seq_len, + alpha=alpha, + q=q, + k=k, + v=v, + seq_offsets=seq_offsets, + causal=causal, + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + ) + + if kernel == Kernel.TRITON: + if not is_fx_tracing(): torch._assert(q.is_cuda, "q must be CUDA tensor") torch._assert(k.is_cuda, "k must be CUDA tensor") torch._assert(v.is_cuda, "v must be CUDA tensor") @@ -117,6 +160,8 @@ def delta_hstu_mha( kernel: Kernel = Kernel.PYTORCH, enable_tma: bool = False, ) -> torch.Tensor: + if kernel == Kernel.CUTLASS: + kernel = Kernel.TRITON L, H, D = delta_q.shape B = seq_offsets.size(0) - 1 DeltaSize = L // B # NOQA @@ -129,8 +174,9 @@ def delta_hstu_mha( torch._assert(k.shape[2] == D, "wrong k shape[2]") torch._assert(v.dim() == 3, "v must be 3-D") torch._assert(v.shape[1] == H, "wrong v shape[1]") + if kernel in [Kernel.TRITON]: - if not is_fx_tracing() and kernel == Kernel.TRITON: + if not is_fx_tracing(): torch._assert(delta_q.is_cuda, "q must be CUDA tensor") torch._assert(seq_offsets.is_cuda, "seq_offsets must be CUDA tensor") if num_targets is not None: diff --git a/tzrec/ops/hstu_attention_test.py b/tzrec/ops/hstu_attention_test.py index c2fac5dc..ccb9d4ca 100644 --- a/tzrec/ops/hstu_attention_test.py +++ b/tzrec/ops/hstu_attention_test.py @@ -487,6 +487,42 @@ def test_cache( real_delta_out, ) + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore + @given( + batch_size=st.integers(4, 8), + heads=st.integers(1, 4), + max_uih_len=st.sampled_from([20, 100, 128, 256]), + max_targets=st.sampled_from([20, 512]), + attn_dim=st.sampled_from([32, 64, 128]), + causal=st.sampled_from([True]), + has_multiple_targets=st.sampled_from([True, False]), + dtype=st.sampled_from(get_test_dtypes([torch.bfloat16])), + has_max_attn_len=st.sampled_from([True, False]), + contextual_seq_len=st.sampled_from([0, 10]), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=20, + deadline=None, + ) + # pyre-ignore[2] + def test_attn_cutlass(self, *args, **kwargs) -> None: + hidden_dim = kwargs.pop("attn_dim") + test_attn( + *args, + **kwargs, + attn_dim=hidden_dim, + hidden_dim=hidden_dim, + test_backward=True, + ref_kernel=Kernel.PYTORCH, + real_kernel=Kernel.CUTLASS, + ) + + # NOTE: no ``test_delta_attn_cutlass`` — ``delta_hstu_mha`` has no + # CUTLASS implementation and falls back to Triton internally. The + # delta/cached path is already covered by ``test_delta_attn_triton``. + if __name__ == "__main__": unittest.main() diff --git a/tzrec/ops/hstu_compute.py b/tzrec/ops/hstu_compute.py index 962cf619..891999ad 100644 --- a/tzrec/ops/hstu_compute.py +++ b/tzrec/ops/hstu_compute.py @@ -40,6 +40,8 @@ def hstu_compute_uqvk( uvqk_bias: torch.Tensor, kernel: Kernel = Kernel.PYTORCH, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if kernel == Kernel.CUTLASS: + kernel = Kernel.TRITON normed_x = layer_norm( x, weight=norm_weight, @@ -87,6 +89,8 @@ def hstu_compute_output( recompute_y_in_backward: bool, kernel: Kernel = Kernel.PYTORCH, ) -> torch.Tensor: + if kernel == Kernel.CUTLASS: + kernel = Kernel.TRITON if kernel == Kernel.TRITON: from tzrec.ops._triton.triton_hstu_linear import ( triton_hstu_compute_output, diff --git a/tzrec/ops/jagged_tensors.py b/tzrec/ops/jagged_tensors.py index 9e26b3c4..b484a512 100644 --- a/tzrec/ops/jagged_tensors.py +++ b/tzrec/ops/jagged_tensors.py @@ -35,6 +35,8 @@ def concat_2D_jagged( offsets_right: Optional[torch.Tensor] = None, kernel: Kernel = Kernel.PYTORCH, ) -> torch.Tensor: + if kernel == Kernel.CUTLASS: + kernel = Kernel.TRITON if not is_fx_tracing(): torch._assert(values_left.dim() == 2, "values_left must be 2D") torch._assert(values_right.dim() == 2, "values_right must be 2D") @@ -75,6 +77,8 @@ def split_2D_jagged( offsets_right: Optional[torch.Tensor] = None, kernel: Kernel = Kernel.PYTORCH, ) -> Tuple[torch.Tensor, torch.Tensor]: + if kernel == Kernel.CUTLASS: + kernel = Kernel.TRITON if not is_fx_tracing(): torch._assert(values.dim() == 2, "values must be 2D") torch._assert( @@ -137,6 +141,8 @@ def jagged_dense_bmm_broadcast_add( jagged has shape (sum_B(M_i), K), dense has shape (B, K, N), and bias has shape (B, N), out has shape (sum_B(M_i), N) """ + if kernel == Kernel.CUTLASS: + kernel = Kernel.TRITON if not is_fx_tracing(): _, K = jagged.shape B, _, N = dense.shape diff --git a/tzrec/ops/layer_norm.py b/tzrec/ops/layer_norm.py index de61e942..4099a2ef 100644 --- a/tzrec/ops/layer_norm.py +++ b/tzrec/ops/layer_norm.py @@ -32,6 +32,8 @@ def layer_norm( eps: float = 1e-5, kernel: Kernel = Kernel.PYTORCH, ) -> torch.Tensor: + if kernel == Kernel.CUTLASS: + kernel = Kernel.TRITON if kernel == Kernel.TRITON: if not is_fx_tracing(): torch._assert(not x.is_cpu, "x must not be cpu tensor") @@ -60,6 +62,8 @@ def rms_norm( eps: float = 1e-5, kernel: Kernel = Kernel.PYTORCH, ) -> torch.Tensor: + if kernel == Kernel.CUTLASS: + kernel = Kernel.TRITON if kernel == Kernel.TRITON: if not is_fx_tracing(): torch._assert(not x.is_cpu, "x must not be cpu tensor") @@ -87,6 +91,8 @@ def swish_layer_norm( eps: float = 1e-5, kernel: Kernel = Kernel.PYTORCH, ) -> torch.Tensor: + if kernel == Kernel.CUTLASS: + kernel = Kernel.TRITON if kernel == Kernel.TRITON: if not is_fx_tracing(): torch._assert(not x.is_cpu, "x must not be cpu tensor") diff --git a/tzrec/ops/mm.py b/tzrec/ops/mm.py index 335a5d41..e2ed2570 100644 --- a/tzrec/ops/mm.py +++ b/tzrec/ops/mm.py @@ -24,6 +24,8 @@ def addmm( mat2: torch.Tensor, kernel: Kernel = Kernel.PYTORCH, ) -> torch.Tensor: + if kernel == Kernel.CUTLASS: + kernel = Kernel.TRITON if kernel == Kernel.TRITON: from tzrec.ops._triton.triton_addmm import triton_addmm diff --git a/tzrec/ops/position.py b/tzrec/ops/position.py index d2a2b3b8..618831ef 100644 --- a/tzrec/ops/position.py +++ b/tzrec/ops/position.py @@ -53,6 +53,8 @@ def add_positional_embeddings( interleave_targets: bool, kernel: Kernel = Kernel.PYTORCH, ) -> torch.Tensor: + if kernel == Kernel.CUTLASS: + kernel = Kernel.TRITON high_inds = _get_high_inds( seq_lengths, position_embeddings_weight, num_targets, interleave_targets ) @@ -104,6 +106,8 @@ def add_timestamp_positional_embeddings( kernel: Kernel = Kernel.PYTORCH, ) -> torch.Tensor: assert time_bucket_fn in ["sqrt", "log"] + if kernel == Kernel.CUTLASS: + kernel = Kernel.TRITON seq_embeddings = seq_embeddings * alpha if kernel == Kernel.TRITON: from tzrec.ops._triton.triton_position import ( diff --git a/tzrec/protos/export.proto b/tzrec/protos/export.proto index cbbd8a7e..7580e2fe 100644 --- a/tzrec/protos/export.proto +++ b/tzrec/protos/export.proto @@ -10,4 +10,10 @@ message ExportConfig { optional string best_exporter_metric = 2 [default = 'auc']; // metric value the bigger the best optional bool metric_larger_is_better = 3 [default = true]; + // mixed precision mode for inference/export [BF16 | FP16 | ""]. + // If empty, falls back to train_config.mixed_precision. When set, the + // dense sub-graph is wrapped in torch.autocast before torch.export so + // that AOT Inductor captures dtype-promoting casts as a + // wrap_with_autocast HOP. + optional string mixed_precision = 4 [default = '']; } diff --git a/tzrec/protos/model.proto b/tzrec/protos/model.proto index f3c5c2f0..a363077d 100644 --- a/tzrec/protos/model.proto +++ b/tzrec/protos/model.proto @@ -34,7 +34,7 @@ message FeatureGroupConfig { enum Kernel { TRITON = 0; PYTORCH = 1; - CUDA = 2; + CUTLASS = 2; } message ModelConfig { diff --git a/tzrec/tests/configs/dlrm_hstu_cutlass_kuairand_1k.config b/tzrec/tests/configs/dlrm_hstu_cutlass_kuairand_1k.config new file mode 100644 index 00000000..9e8c2008 --- /dev/null +++ b/tzrec/tests/configs/dlrm_hstu_cutlass_kuairand_1k.config @@ -0,0 +1,343 @@ +train_input_path: "data/test/kuairand-1k-train-c4096-s100.parquet" +eval_input_path: "data/test/kuairand-1k-eval-c4096-s100.parquet" +model_dir: "experiments/kuairand/dlrm_hstu" +train_config { + sparse_optimizer { + rowwise_adagrad_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + num_epochs: 1 + save_checkpoints_epochs: 1 + mixed_precision: "BF16" +} +data_config { + batch_size: 8 + dataset_type: ParquetDataset + num_workers: 2 + fg_mode: FG_DAG + label_fields: ["cand_seq__action_weight", "cand_seq__watch_time"] +} +feature_configs { + sequence_feature { + sequence_name: "uih_seq" + sequence_length: 8000 + sequence_delim: "|" + features { + id_feature { + feature_name: "video_id" + expression: "item:video_id" + embedding_name: "video_id_emb" + embedding_dim: 32 + num_buckets: 10000000 + data_type: "FP16" + } + } + features { + raw_feature { + feature_name: "action_timestamp" + expression: "user:action_timestamp" + } + } + features { + raw_feature { + feature_name: "action_weight" + expression: "user:action_weight" + } + } + features { + raw_feature { + feature_name: "watch_time" + expression: "user:watch_time" + } + } + } +} +feature_configs { + sequence_feature { + sequence_name: "cand_seq" + sequence_length: 8000 + sequence_delim: "|" + features { + id_feature { + feature_name: "video_id" + expression: "item:video_id" + embedding_name: "video_id_emb" + embedding_dim: 32 + num_buckets: 10000000 + data_type: "FP16" + } + } + features { + raw_feature { + feature_name: "query_time" + expression: "user:query_time" + } + } + } +} +feature_configs { + id_feature { + feature_name: "user_id" + expression: "user:user_id" + embedding_dim: 32 + num_buckets: 10000000 + data_type: "FP16" + } +} +feature_configs { + id_feature { + feature_name: "user_active_degree" + expression: "user:user_active_degree" + embedding_dim: 32 + num_buckets: 8 + data_type: "FP16" + } +} +feature_configs { + id_feature { + feature_name: "follow_user_num_range" + expression: "user:follow_user_num_range" + embedding_dim: 32 + num_buckets: 9 + data_type: "FP16" + } +} +feature_configs { + id_feature { + feature_name: "fans_user_num_range" + expression: "user:fans_user_num_range" + embedding_dim: 32 + num_buckets: 9 + data_type: "FP16" + } +} +feature_configs { + id_feature { + feature_name: "friend_user_num_range" + expression: "user:friend_user_num_range" + embedding_dim: 32 + num_buckets: 8 + data_type: "FP16" + } +} +feature_configs { + id_feature { + feature_name: "register_days_range" + expression: "user:register_days_range" + embedding_dim: 32 + num_buckets: 8 + data_type: "FP16" + } +} +model_config { + feature_groups { + group_name: "contextual" + feature_names: "user_id" + feature_names: "user_active_degree" + feature_names: "follow_user_num_range" + feature_names: "fans_user_num_range" + feature_names: "friend_user_num_range" + feature_names: "register_days_range" + group_type: DEEP + } + feature_groups { + group_name: "uih" + feature_names: "uih_seq__video_id" + group_type: JAGGED_SEQUENCE + } + feature_groups { + group_name: "candidate" + feature_names: "cand_seq__video_id" + group_type: JAGGED_SEQUENCE + } + feature_groups { + group_name: "uih_action" + feature_names: "uih_seq__action_weight" + group_type: JAGGED_SEQUENCE + } + feature_groups { + group_name: "uih_watchtime" + feature_names: "uih_seq__watch_time" + group_type: JAGGED_SEQUENCE + } + feature_groups { + group_name: "uih_timestamp" + feature_names: "uih_seq__action_timestamp" + group_type: JAGGED_SEQUENCE + } + feature_groups { + group_name: "candidate_timestamp" + feature_names: "cand_seq__query_time" + group_type: JAGGED_SEQUENCE + } + dlrm_hstu { + hstu { + stu { + embedding_dim: 512 + num_heads: 4 + hidden_dim: 128 + attention_dim: 128 + output_dropout_ratio: 0.1 + use_group_norm: true + } + input_dropout_ratio: 0.2 + attn_num_layers: 3 + positional_encoder { + num_position_buckets: 8192 + num_time_buckets: 2048 + use_time_encoding: true + } + input_preprocessor { + contextual_preprocessor { + action_encoder { + simple_action_encoder { + action_embedding_dim: 8 + action_weights: [1, 2, 4, 8, 16, 32, 64, 128] + } + } + action_mlp { + simple_mlp { + hidden_dim: 256 + } + } + content_encoder { + slice_content_encoder {} + } + content_mlp { + simple_mlp { + hidden_dim: 256 + } + } + } + } + output_postprocessor { + timestamp_layernorm_postprocessor { + time_duration_period_units: [3600, 86400] + time_duration_units_per_period: [24, 7] + } + } + } + fusion_mtl_tower { + mlp { + hidden_units: 512 + activation: "nn.SiLU" + use_ln: true + } + task_configs { + task_name: "is_click" + label_name: "cand_seq__action_weight" + task_bitmask: 1 + losses { + binary_cross_entropy {} + } + metrics { + auc {} + } + } + task_configs { + task_name: "is_like" + label_name: "cand_seq__action_weight" + task_bitmask: 2 + num_class: 2 + losses { + jrc_loss { + session_name: "user_id" + } + } + metrics { + grouped_auc { + grouping_key: "user_id" + } + } + } + task_configs { + task_name: "is_follow" + label_name: "cand_seq__action_weight" + task_bitmask: 4 + losses { + binary_cross_entropy {} + } + metrics { + auc {} + } + } + task_configs { + task_name: "is_comment" + label_name: "cand_seq__action_weight" + task_bitmask: 8 + losses { + binary_cross_entropy {} + } + metrics { + auc {} + } + } + task_configs { + task_name: "is_forward" + label_name: "cand_seq__action_weight" + task_bitmask: 16 + losses { + binary_cross_entropy {} + } + metrics { + auc {} + } + } + task_configs { + task_name: "is_hate" + label_name: "cand_seq__action_weight" + task_bitmask: 32 + losses { + binary_cross_entropy {} + } + metrics { + auc {} + } + } + task_configs { + task_name: "long_view" + label_name: "cand_seq__action_weight" + task_bitmask: 64 + losses { + binary_cross_entropy {} + } + metrics { + auc {} + } + } + task_configs { + task_name: "is_profile_enter" + label_name: "cand_seq__action_weight" + task_bitmask: 128 + losses { + binary_cross_entropy {} + } + metrics { + auc {} + } + } + task_configs { + task_name: "watchtime" + label_name: "cand_seq__watch_time" + losses { + l2_loss {} + } + metrics { + mean_absolute_error {} + } + } + }, + max_seq_len: 4096 + } + kernel: CUTLASS +} diff --git a/tzrec/tests/rank_integration_test.py b/tzrec/tests/rank_integration_test.py index 7170736f..d043fb77 100644 --- a/tzrec/tests/rank_integration_test.py +++ b/tzrec/tests/rank_integration_test.py @@ -956,6 +956,40 @@ def test_rank_dlrm_hstu_train_eval_export(self): os.path.exists(os.path.join(self.test_dir, "export/aoti_model.pt2")) ) + @unittest.skipIf(*gpu_unavailable) + def test_rank_dlrm_hstu_cutlass_train_eval_export(self): + self.success = utils.test_train_eval( + "tzrec/tests/configs/dlrm_hstu_cutlass_kuairand_1k.config", self.test_dir + ) + if self.success: + self.success = utils.test_eval( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + if self.success: + self.success = utils.test_export( + os.path.join(self.test_dir, "pipeline.config"), + self.test_dir, + env_str="ENABLE_AOT=1", + ) + predict_output_path = os.path.join(self.test_dir, "predict_result") + if self.success: + self.success = utils.test_predict( + os.path.join(self.test_dir, "export"), + predict_input_path="data/test/kuairand-1k-eval-c4096-s100.parquet", + predict_output_path=predict_output_path, + reserved_columns="user_id,cand_seq__video_id", + output_columns="", + test_dir=self.test_dir, + # The cutlass custom op path is not safe to call concurrently + # through an AOT-Inductor compiled model yet (hstu_attn_cuda's + # pybind11 binding does not release the GIL and the AOTI + # runtime then deadlocks between the two predict forward + # worker threads). Restrict predict to a single worker until + # the underlying multi-threading issue is addressed upstream. + predict_threads=1, + ) + self.assertTrue(self.success) + @unittest.skipIf( gpu_unavailable[0] or not dynamicemb_util.has_dynamicemb, "dynamicemb not available.", diff --git a/tzrec/utils/export_util.py b/tzrec/utils/export_util.py index e7021929..b0ac5b7e 100644 --- a/tzrec/utils/export_util.py +++ b/tzrec/utils/export_util.py @@ -192,20 +192,32 @@ def export_model_normal( model.eval() data = batch.to_dict(sparse_dtype=torch.int64) - if acc_utils.is_trt(): + mixed_precision = acc_utils.resolve_mixed_precision(pipeline_config) + autocast_dtype = acc_utils.mixed_precision_to_dtype(mixed_precision) + if acc_utils.is_trt() or acc_utils.is_aot(): data = OrderedDict(sorted(data.items())) - result = model(data, "cuda:0") - result_info = {k: (v.size(), v.dtype) for k, v in result.items()} - logger.info(f"Model Outputs: {result_info}") - sparse, dense, _ = split_model(data, model, save_dir) - export_model_trt(sparse, dense, data, save_dir) - elif acc_utils.is_aot(): - data = OrderedDict(sorted(data.items())) - result = model(data, "cuda:0") + with torch.amp.autocast( + device_type="cuda", + dtype=autocast_dtype, + enabled=autocast_dtype is not None, + ): + result = model(data, "cuda:0") result_info = {k: (v.size(), v.dtype) for k, v in result.items()} logger.info(f"Model Outputs: {result_info}") sparse, dense, meta_info = split_model(data, model, save_dir) - export_model_aot(sparse, dense, data, meta_info, save_dir) + if acc_utils.is_trt(): + export_model_trt( + sparse, dense, data, save_dir, mixed_precision=mixed_precision + ) + else: + export_model_aot( + sparse, + dense, + data, + meta_info, + save_dir, + mixed_precision=mixed_precision, + ) else: result = model(data) result_info = {k: (v.size(), v.dtype) for k, v in result.items()}