Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
e624365
[feat] add CUTLASS kernel backend for HSTU attention
tiankongdeguiji Apr 3, 2026
410210f
[fix] handle CUTLASS kernel unsupported param combinations
tiankongdeguiji Apr 3, 2026
adb3c60
[feat] add hstu_attn wheel to extra requirements
tiankongdeguiji Apr 3, 2026
322aa91
[feat] add hstu_attn wheels for cp310/cp312
tiankongdeguiji Apr 3, 2026
23dfe50
[docs] move CUTLASS usage guide from FAQ to dlrm_hstu.md
tiankongdeguiji Apr 3, 2026
6fb33eb
[docs] simplify CUTLASS docs into kernel description
tiankongdeguiji Apr 3, 2026
ab97406
[docs] use repo.html style install for hstu_attn like dynamicemb
tiankongdeguiji Apr 3, 2026
06645a1
[docs] add cu126 option for hstu_attn install
tiankongdeguiji Apr 3, 2026
e114ab6
[feat] add logging warning when CUTLASS falls back to Triton
tiankongdeguiji Apr 3, 2026
6028d32
[feat] add logging warning in cutlass_cached_hstu_mha fallback
tiankongdeguiji Apr 3, 2026
d3d079d
[feat] add CUTLASS provider to hstu_attention_bench.py
tiankongdeguiji Apr 3, 2026
785befc
[fix] move CUTLASS fallback logic to dispatch layer, add export test
tiankongdeguiji Apr 3, 2026
07d27e8
[chore] bump version to 1.1.7
tiankongdeguiji Apr 3, 2026
97e3e25
[fix] register cutlass_hstu_mha as torch.library custom_op for AOT ex…
tiankongdeguiji Apr 7, 2026
c437893
[fix] make CUTLASS dispatch FX-safe and handle fp32 inputs during export
tiankongdeguiji Apr 7, 2026
590bb4f
[refactor] use AutocastWrapper to propagate mixed_precision through e…
tiankongdeguiji Apr 7, 2026
9956957
[fix] register cutlass_hstu_mha custom op at aot_utils import time
tiankongdeguiji Apr 7, 2026
9f39c32
[fix] avoid AOTI multi-thread predict deadlock in CUTLASS custom op
tiankongdeguiji Apr 8, 2026
a22f36b
[refactor] drop sparse-side autocast wrap, add export_config.mixed_pr…
tiankongdeguiji Apr 8, 2026
f7d020b
[cleanup] dedup mixed_precision handling, extract helpers, unify TRT/…
tiankongdeguiji Apr 8, 2026
27420b0
Merge remote-tracking branch 'origin/master' into feat/cutlass-hstu-attn
tiankongdeguiji Apr 8, 2026
c4898fe
[chore] use tzrec.oss-accelerate URL for hstu_attn wheel
tiankongdeguiji Apr 8, 2026
f8bd059
[feat] CUTLASS kernel falls back to TRITON for non-attention sub-ops
tiankongdeguiji Apr 8, 2026
ef3e24e
[refactor] CUTLASS fallback handled per-op at each op entry, drop sub…
tiankongdeguiji Apr 8, 2026
55b33e4
[cleanup] address review nits in cutlass_hstu_attention / hstu_attention
tiankongdeguiji Apr 8, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/source/models/dlrm_hstu.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,11 @@ model_config {
- metrics: 任务指标
- max_seq_len: 最大序列长度

- kernel: 算子实现,可选TRITON/PYTORCH,TRITON通常比PYTORCH快2-3x,节省2-3x显存
- kernel: 算子实现,可选TRITON/PYTORCH/CUTLASS

- TRITON: 基于Triton的实现,通常比PYTORCH快2-3x,节省2-3x显存
- CUTLASS: 基于CUTLASS的CUDA融合算子实现,需安装hstu_attn包(DEVICE可选cu126/cu129:`pip install hstu_attn-0.1.0+bea6b4b.${DEVICE} -f https://tzrec.oss-accelerate.aliyuncs.com/third_party/hstu/${DEVICE}/repo.html`),要求`attention_dim`等于`hidden_dim`,支持Ampere/Ada/Hopper GPU
- PYTORCH: 纯PyTorch实现,兼容性最好

### MTGR Style 配置方式

Expand Down
3 changes: 3 additions & 0 deletions requirements/extra.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
dynamicemb @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/dynamicemb/cu129/dynamicemb-0.0.1%2B20260407.97b80bf.cu129-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
dynamicemb @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/dynamicemb/cu129/dynamicemb-0.0.1%2B20260407.97b80bf.cu129-cp311-cp311-linux_x86_64.whl ; python_version=="3.11"
dynamicemb @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/dynamicemb/cu129/dynamicemb-0.0.1%2B20260407.97b80bf.cu129-cp312-cp312-linux_x86_64.whl ; python_version=="3.12"
hstu_attn @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/hstu/cu129/hstu_attn-0.1.0%2Bbea6b4b.cu12.9-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
hstu_attn @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/hstu/cu129/hstu_attn-0.1.0%2Bbea6b4b.cu12.9-cp311-cp311-linux_x86_64.whl ; python_version=="3.11"
hstu_attn @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/hstu/cu129/hstu_attn-0.1.0%2Bbea6b4b.cu12.9-cp312-cp312-linux_x86_64.whl ; python_version=="3.12"
torch_fx_tool @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/rtp/torch_fx_tool-0.0.1%2B20251201.8c109c4-py3-none-any.whl
32 changes: 29 additions & 3 deletions tzrec/acc/aot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,26 @@


import os
from typing import Any, Dict
from typing import Any, Dict, Optional

import torch
from torch import nn

from tzrec.models.model import CombinedModelWrapper
from tzrec.models.model import CombinedModelWrapper, DenseAutocastWrapper
from tzrec.utils.fx_util import symbolic_trace
from tzrec.utils.logging_util import logger

# Eagerly register custom ops referenced by AOT-packaged models so that
# torch._inductor.aoti_load_package() can resolve them by name. AOT packages
# reference ops via their qualified name (e.g. ``tzrec::cutlass_hstu_mha_fwd``)
# and PyTorch only knows about an op once its registering module has been
# imported. Wrap in try/except so this stays optional for environments
# without the corresponding native dependencies installed.
try:
from tzrec.ops._cuda import cutlass_hstu_attention # noqa: F401
except ImportError:
pass


def load_model_aot(model_path: str, device: torch.device) -> CombinedModelWrapper:
"""Load AOTInductor model.
Expand Down Expand Up @@ -49,6 +60,7 @@ def export_model_aot(
data: Dict[str, torch.Tensor],
meta_info: Dict[str, Any],
save_dir: str,
mixed_precision: Optional[str] = None,
) -> str:
"""Export AOTInductor model.

Expand All @@ -58,6 +70,12 @@ def export_model_aot(
data (Dict[str, torch.Tensor]): the test data
meta_info (Dict[str, Any]): split meta info
save_dir (str): model save dir
mixed_precision (Optional[str]): "BF16", "FP16", or None. When set,
the dense sub-graph is wrapped in a DenseAutocastWrapper so that
torch.export captures the autocast region as a wrap_with_autocast
Higher Order Op. The sparse sub-graph is left untouched because
it is only embedding lookups, which don't benefit from AMP and
which would complicate torch.jit.script compilation.
"""
sparse_output, _ = sparse_model(data, "cuda:0")
sparse_model_traced = symbolic_trace(sparse_model)
Expand Down Expand Up @@ -86,12 +104,20 @@ def export_model_aot(

logger.info("dynamic shapes=%s" % dynamic_shapes)

# Wrap the dense module so torch.export captures the autocast region
# as a `wrap_with_autocast` HOP that AOT Inductor lowers correctly.
dense_to_export: nn.Module = dense_model
if mixed_precision:
dense_to_export = DenseAutocastWrapper(dense_model, mixed_precision)

# pre_hook requires running arbitrary code at runtime
with torch._inductor.config.patch(
{"unsafe_ignore_unsupported_triton_autotune_args": True}
):
exported_pg = torch.export.export(
dense_model, args=(sparse_output,), dynamic_shapes=(dynamic_shapes,)
dense_to_export,
args=(sparse_output,),
dynamic_shapes=(dynamic_shapes,),
)
# AssertScalar codegen is not correct.
with torch._inductor.config.patch(
Expand Down
12 changes: 10 additions & 2 deletions tzrec/acc/trt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torch.profiler import ProfilerActivity, profile, record_function

from tzrec.acc.utils import get_max_export_batch_size, is_debug_trt
from tzrec.models.model import CombinedModelWrapper
from tzrec.models.model import CombinedModelWrapper, DenseAutocastWrapper
from tzrec.utils.fx_util import symbolic_trace
from tzrec.utils.logging_util import logger

Expand Down Expand Up @@ -104,6 +104,7 @@ def export_model_trt(
dense_model: nn.Module,
data: Dict[str, torch.Tensor],
save_dir: str,
mixed_precision: Optional[str] = None,
) -> None:
"""Export trt model.

Expand All @@ -112,6 +113,10 @@ def export_model_trt(
dense_model (nn.Module): the dense part
data (Dict[str, torch.Tensor]): the test data
save_dir (str): model save dir
mixed_precision (Optional[str]): "BF16", "FP16", or None. When set,
only the dense sub-graph is wrapped in a DenseAutocastWrapper
before torch.export. The sparse sub-graph is only embedding
lookups and doesn't benefit from AMP.
"""
emb_ebc, _ = sparse_model(data, "cuda:0")
sparse_model_traced = symbolic_trace(sparse_model)
Expand Down Expand Up @@ -153,8 +158,11 @@ def export_model_trt(
for i, k in enumerate(key_list):
dynamic_shapes[dense_arg_name].update({k: dynamic_shapes_list[i]})

dense_to_export: nn.Module = dense_layer
if mixed_precision:
dense_to_export = DenseAutocastWrapper(dense_layer, mixed_precision)
exp_program = torch.export.export(
dense_layer, (emb_ebc,), dynamic_shapes=dynamic_shapes
dense_to_export, (emb_ebc,), dynamic_shapes=dynamic_shapes
)
dense_layer_trt = trt_convert(exp_program, (emb_ebc,))
# logger.info("dense trt res: %s", dense_layer_trt(emb_ebc))
Expand Down
38 changes: 37 additions & 1 deletion tzrec/acc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import os
from typing import Dict
from typing import Dict, Optional

import torch

from tzrec.protos.pipeline_pb2 import EasyRecConfig
from tzrec.protos.train_pb2 import TrainConfig


Expand Down Expand Up @@ -165,6 +166,41 @@ def ec_quant_dtype() -> torch.dtype:
return _quant_str_to_dtype[quant_dtype_str]


_MIXED_PRECISION_TO_DTYPE: Dict[str, torch.dtype] = {
"BF16": torch.bfloat16,
"FP16": torch.float16,
}


def mixed_precision_to_dtype(mixed_precision: Optional[str]) -> Optional[torch.dtype]:
"""Convert a TrainConfig.mixed_precision string to a torch dtype.

Returns ``None`` when ``mixed_precision`` is ``None`` or empty.
Raises ``ValueError`` on unknown values so typos fail loudly.
"""
if not mixed_precision:
return None
if mixed_precision not in _MIXED_PRECISION_TO_DTYPE:
raise ValueError(
f"Unknown mixed_precision: {mixed_precision}, "
f"available types: {list(_MIXED_PRECISION_TO_DTYPE.keys())}"
)
return _MIXED_PRECISION_TO_DTYPE[mixed_precision]


def resolve_mixed_precision(pipeline_config: EasyRecConfig) -> str:
    """Resolve the mixed_precision mode for export/inference.

    Precedence: ``export_config.mixed_precision`` (when set) overrides
    ``train_config.mixed_precision``. Empty string means no AMP.

    Args:
        pipeline_config (EasyRecConfig): the full pipeline config.

    Returns:
        the effective mixed_precision string ("BF16"/"FP16"/"").
    """
    export_mp = ""
    if pipeline_config.HasField("export_config"):
        export_mp = pipeline_config.export_config.mixed_precision
    # Fall back to the training-time setting when export does not override it.
    return export_mp or pipeline_config.train_config.mixed_precision


def write_mapping_file_for_input_tile(
state_dict: Dict[str, torch.Tensor], remap_file_path: str
) -> None:
Expand Down
76 changes: 55 additions & 21 deletions tzrec/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from collections import OrderedDict
from queue import Queue
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Final, Iterable, List, Optional, Tuple

import torch
import torchmetrics
Expand All @@ -22,6 +22,7 @@
EmbeddingCollectionInterface,
)

from tzrec.acc import utils as acc_utils
from tzrec.constant import TARGET_REPEAT_INTERLEAVE_KEY
from tzrec.datasets.data_parser import DataParser
from tzrec.datasets.utils import Batch
Expand Down Expand Up @@ -236,16 +237,7 @@ def __init__(
self._device_type = "cpu"
if device is not None:
self._device_type = device.type
if mixed_precision is None or len(mixed_precision) == 0:
self._mixed_dtype = None
elif mixed_precision == "FP16":
self._mixed_dtype = torch.float16
elif mixed_precision == "BF16":
self._mixed_dtype = torch.bfloat16
else:
raise ValueError(
f"mixed_precision should be FP16 or BF16, but got [{mixed_precision}]"
)
self._mixed_dtype = acc_utils.mixed_precision_to_dtype(mixed_precision)
self.pareto = None
if (
hasattr(self.model, "_use_pareto_loss_weight")
Expand Down Expand Up @@ -300,16 +292,7 @@ def __init__(
self._device_type = "cpu"
if device is not None:
self._device_type = device.type
if mixed_precision is None or len(mixed_precision) == 0:
self._mixed_dtype = None
elif mixed_precision == "FP16":
self._mixed_dtype = torch.float16
elif mixed_precision == "BF16":
self._mixed_dtype = torch.bfloat16
else:
raise ValueError(
f"mixed_precision should be FP16 or BF16, but got [{mixed_precision}]"
)
self._mixed_dtype = acc_utils.mixed_precision_to_dtype(mixed_precision)
self._output_cols = output_cols

def forward(
Expand Down Expand Up @@ -389,6 +372,57 @@ def forward(
return self.model.predict(batch)


class DenseAutocastWrapper(nn.Module):
    """Wraps a dense-side module in a torch.autocast context for torch.export.

    Used to wrap the dense GraphModule from split_model before passing to
    torch.export.export. torch.export captures the `with torch.autocast(...)`
    region as a `wrap_with_autocast` Higher Order Op in the exported graph,
    which AOT Inductor lowers to proper dtype casts.

    Only the dense path needs this wrapper: the sparse path is embedding
    lookups which don't have autocast rules and don't benefit from AMP, and
    the CUTLASS attention op lives in the dense sub-graph after the
    sparse/dense split.

    The forward takes a single dict argument matching the dense_gm signature.

    Note on the ``_mixed_dtype_id: Final[int]`` encoding: ``torch.jit.script``
    only honors ``torch.autocast(dtype=...)`` when ``dtype`` is a compile-time
    literal. We therefore store an integer flag (``Final[int]``) set in
    ``__init__`` and branch on it in ``forward`` so each branch's
    ``torch.autocast(dtype=torch.bfloat16|torch.float16)`` is a literal.
    Storing the dtype directly as an attribute would not be script-friendly.

    Args:
        inner (nn.Module): inner dense module to wrap.
        mixed_precision (Optional[str]): one of "BF16", "FP16", or None.
    """

    _mixed_dtype_id: Final[int]

    def __init__(self, inner: nn.Module, mixed_precision: Optional[str] = None) -> None:
        super().__init__()
        self.inner = inner
        # Encode the precision mode as an int: 0 = off, 1 = BF16, 2 = FP16.
        # Unknown / empty strings fall back to 0 (no autocast).
        self._mixed_dtype_id = {"BF16": 1, "FP16": 2}.get(mixed_precision or "", 0)

    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Forward through inner module under an autocast context."""
        # Fast path: no mixed precision configured, run the inner module as-is.
        if self._mixed_dtype_id == 0:
            return self.inner(x)
        if self._mixed_dtype_id == 1:
            # dtype must stay a literal here for torch.jit.script (see class note).
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                return self.inner(x)
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            return self.inner(x)


class CombinedModelWrapper(nn.Module):
"""Model inference wrapper for model combined with sparse and dense part.

Expand Down
2 changes: 1 addition & 1 deletion tzrec/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ class Kernel(Enum):

TRITON = "TRITON"
PYTORCH = "PYTORCH"
CUDA = "CUDA"
CUTLASS = "CUTLASS"
10 changes: 10 additions & 0 deletions tzrec/ops/_cuda/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading
Loading