Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
e624365
[feat] add CUTLASS kernel backend for HSTU attention
tiankongdeguiji Apr 3, 2026
410210f
[fix] handle CUTLASS kernel unsupported param combinations
tiankongdeguiji Apr 3, 2026
adb3c60
[feat] add hstu_attn wheel to extra requirements
tiankongdeguiji Apr 3, 2026
322aa91
[feat] add hstu_attn wheels for cp310/cp312
tiankongdeguiji Apr 3, 2026
23dfe50
[docs] move CUTLASS usage guide from FAQ to dlrm_hstu.md
tiankongdeguiji Apr 3, 2026
6fb33eb
[docs] simplify CUTLASS docs into kernel description
tiankongdeguiji Apr 3, 2026
ab97406
[docs] use repo.html style install for hstu_attn like dynamicemb
tiankongdeguiji Apr 3, 2026
06645a1
[docs] add cu126 option for hstu_attn install
tiankongdeguiji Apr 3, 2026
e114ab6
[feat] add logging warning when CUTLASS falls back to Triton
tiankongdeguiji Apr 3, 2026
6028d32
[feat] add logging warning in cutlass_cached_hstu_mha fallback
tiankongdeguiji Apr 3, 2026
d3d079d
[feat] add CUTLASS provider to hstu_attention_bench.py
tiankongdeguiji Apr 3, 2026
785befc
[fix] move CUTLASS fallback logic to dispatch layer, add export test
tiankongdeguiji Apr 3, 2026
07d27e8
[chore] bump version to 1.1.7
tiankongdeguiji Apr 3, 2026
97e3e25
[fix] register cutlass_hstu_mha as torch.library custom_op for AOT ex…
tiankongdeguiji Apr 7, 2026
c437893
[fix] make CUTLASS dispatch FX-safe and handle fp32 inputs during export
tiankongdeguiji Apr 7, 2026
590bb4f
[refactor] use AutocastWrapper to propagate mixed_precision through e…
tiankongdeguiji Apr 7, 2026
9956957
[fix] register cutlass_hstu_mha custom op at aot_utils import time
tiankongdeguiji Apr 7, 2026
9f39c32
[fix] avoid AOTI multi-thread predict deadlock in CUTLASS custom op
tiankongdeguiji Apr 8, 2026
a22f36b
[refactor] drop sparse-side autocast wrap, add export_config.mixed_pr…
tiankongdeguiji Apr 8, 2026
f7d020b
[cleanup] dedup mixed_precision handling, extract helpers, unify TRT/…
tiankongdeguiji Apr 8, 2026
27420b0
Merge remote-tracking branch 'origin/master' into feat/cutlass-hstu-attn
tiankongdeguiji Apr 8, 2026
c4898fe
[chore] use tzrec.oss-accelerate URL for hstu_attn wheel
tiankongdeguiji Apr 8, 2026
f8bd059
[feat] CUTLASS kernel falls back to TRITON for non-attention sub-ops
tiankongdeguiji Apr 8, 2026
ef3e24e
[refactor] CUTLASS fallback handled per-op at each op entry, drop sub…
tiankongdeguiji Apr 8, 2026
55b33e4
[cleanup] address review nits in cutlass_hstu_attention / hstu_attention
tiankongdeguiji Apr 8, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/source/models/dlrm_hstu.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,11 @@ model_config {
- metrics: 任务指标
- max_seq_len: 最大序列长度

- kernel: 算子实现,可选TRITON/PYTORCH,TRITON通常比PYTORCH快2-3x,节省2-3x显存
- kernel: 算子实现,可选TRITON/PYTORCH/CUTLASS

- TRITON: 基于Triton的实现,通常比PYTORCH快2-3x,节省2-3x显存
- CUTLASS: 基于CUTLASS的CUDA融合算子实现,需安装hstu_attn包(DEVICE可选cu126/cu129:`pip install hstu_attn-0.1.0+bea6b4b.${DEVICE} -f https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/hstu/${DEVICE}/repo.html`),要求`attention_dim`等于`hidden_dim`,支持Ampere/Ada/Hopper GPU
- PYTORCH: 纯PyTorch实现,兼容性最好

### MTGR Style 配置方式

Expand Down
3 changes: 3 additions & 0 deletions requirements/extra.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
dynamicemb @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/dynamicemb/cu129/dynamicemb-0.0.1%2B20260331.bea6b4b.cu129-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
dynamicemb @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/dynamicemb/cu129/dynamicemb-0.0.1%2B20260331.bea6b4b.cu129-cp311-cp311-linux_x86_64.whl ; python_version=="3.11"
dynamicemb @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/dynamicemb/cu129/dynamicemb-0.0.1%2B20260331.bea6b4b.cu129-cp312-cp312-linux_x86_64.whl ; python_version=="3.12"
hstu_attn @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/hstu/cu129/hstu_attn-0.1.0%2Bbea6b4b.cu12.9-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
hstu_attn @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/hstu/cu129/hstu_attn-0.1.0%2Bbea6b4b.cu12.9-cp311-cp311-linux_x86_64.whl ; python_version=="3.11"
hstu_attn @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/hstu/cu129/hstu_attn-0.1.0%2Bbea6b4b.cu12.9-cp312-cp312-linux_x86_64.whl ; python_version=="3.12"
torch_fx_tool @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/rtp/torch_fx_tool-0.0.1%2B20251201.8c109c4-py3-none-any.whl
2 changes: 1 addition & 1 deletion tzrec/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ class Kernel(Enum):

TRITON = "TRITON"
PYTORCH = "PYTORCH"
CUDA = "CUDA"
CUTLASS = "CUTLASS"
10 changes: 10 additions & 0 deletions tzrec/ops/_cuda/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
177 changes: 177 additions & 0 deletions tzrec/ops/_cuda/cutlass_hstu_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Optional

import torch
from hstu_attn import hstu_attn_varlen_func

logger = logging.getLogger(__name__)

_triton_fallback_warned = False
_cached_fallback_warned = False


def _needs_triton_fallback(
max_attn_len: int,
contextual_seq_len: int,
num_targets: Optional[torch.Tensor],
) -> bool:
"""Check if we need to fall back to triton.

The CUTLASS kernel does not support combining local window attention
(max_attn_len > 0) with context or target masking.
"""
global _triton_fallback_warned
has_local_window = max_attn_len > 0
has_context_or_target = contextual_seq_len > 0 or num_targets is not None
needs_fallback = has_local_window and has_context_or_target
if needs_fallback and not _triton_fallback_warned:
logger.warning(
"CUTLASS kernel does not support combining local window attention "
"(max_attn_len > 0) with context/target masking, "
"falling back to Triton kernel."
)
_triton_fallback_warned = True
return needs_fallback


@torch.fx.wrap
def cutlass_hstu_mha(
    max_seq_len: int,
    alpha: float,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_offsets: torch.Tensor,
    causal: bool = True,
    num_targets: Optional[torch.Tensor] = None,
    max_attn_len: int = 0,
    contextual_seq_len: int = 0,
) -> torch.Tensor:
    """CUTLASS-based HSTU multi-head attention.

    Args:
        max_seq_len: maximum sequence length in the batch.
        alpha: scaling factor for attention scores.
        q: query tensor of shape (total, nheads, attn_dim).
        k: key tensor of shape (total, nheads, attn_dim).
        v: value tensor of shape (total, nheads, hidden_dim).
        seq_offsets: cumulative sequence offsets of shape (batch_size + 1,).
        causal: whether to apply causal masking.
        num_targets: number of target tokens per batch element.
        max_attn_len: maximum attention window length (0 means unlimited).
        contextual_seq_len: number of contextual tokens per sequence.

    Returns:
        output tensor of shape (total, nheads, hidden_dim).
    """
    # Parameter combinations the CUTLASS kernel cannot express are routed
    # to the Triton implementation (a warning is logged once per process).
    if _needs_triton_fallback(max_attn_len, contextual_seq_len, num_targets):
        from tzrec.ops._triton.triton_hstu_attention import triton_hstu_mha

        return triton_hstu_mha(
            max_seq_len=max_seq_len,
            alpha=alpha,
            q=q,
            k=k,
            v=v,
            seq_offsets=seq_offsets,
            causal=causal,
            num_targets=num_targets,
            max_attn_len=max_attn_len,
            contextual_seq_len=contextual_seq_len,
        )

    # torch._assert instead of a bare assert: this constraint depends on the
    # user's model config and must survive `python -O`; it also matches the
    # runtime-validation style used elsewhere in the codebase.
    torch._assert(
        q.shape[2] == v.shape[2],
        f"CUTLASS hstu_attn requires attention_dim == hidden_dim, "
        f"got q.shape[2]={q.shape[2]} != v.shape[2]={v.shape[2]}",
    )

    # Defensive check before narrowing to int32: if the total token count
    # exceeds 2^31 - 1 the cast would silently wrap to negative offsets and
    # cause out-of-bounds CUDA memory accesses.
    torch._assert(
        seq_offsets[-1] <= torch.iinfo(torch.int32).max,
        "seq_offsets values exceed int32 range",
    )
    cu_seqlens = seq_offsets.to(torch.int32)

    # window_size follows the flash-attention convention:
    # (left, right) with -1 meaning unlimited in that direction.
    if causal:
        if max_attn_len > 0:
            window_size = (max_attn_len, 0)
        else:
            window_size = (-1, 0)
    else:
        window_size = (-1, -1)

    # The kernel takes per-batch context counts; expand the scalar config
    # value into an int32 tensor when contextual tokens are present.
    num_contexts = None
    if contextual_seq_len > 0:
        batch_size = seq_offsets.size(0) - 1
        num_contexts = torch.full(
            (batch_size,),
            contextual_seq_len,
            dtype=torch.int32,
            device=q.device,
        )

    if num_targets is not None:
        num_targets = num_targets.to(torch.int32)

    return hstu_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens,
        cu_seqlens,
        max_seq_len,
        max_seq_len,
        num_contexts=num_contexts,
        num_targets=num_targets,
        window_size=window_size,
        alpha=alpha,
        scaling_seqlen=max_seq_len,
    )


@torch.fx.wrap
def cutlass_cached_hstu_mha(
    max_seq_len: int,
    alpha: float,
    delta_q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_offsets: torch.Tensor,
    num_targets: Optional[torch.Tensor] = None,
    max_attn_len: int = 0,
    contextual_seq_len: int = 0,
    enable_tma: bool = False,
) -> torch.Tensor:
    """Cached HSTU attention for delta queries.

    The CUTLASS kernel has no support for the delta-query pattern with
    separate q/k sequence lengths, so this always delegates to the Triton
    implementation, logging a warning once per process.
    """
    global _cached_fallback_warned
    if not _cached_fallback_warned:
        _cached_fallback_warned = True
        logger.warning(
            "CUTLASS kernel does not support cached/delta attention, "
            "falling back to Triton kernel."
        )

    from tzrec.ops._triton.triton_hstu_attention import triton_cached_hstu_mha

    return triton_cached_hstu_mha(
        max_seq_len=max_seq_len,
        alpha=alpha,
        delta_q=delta_q,
        k=k,
        v=v,
        seq_offsets=seq_offsets,
        num_targets=num_targets,
        max_attn_len=max_attn_len,
        contextual_seq_len=contextual_seq_len,
        enable_tma=enable_tma,
    )
10 changes: 9 additions & 1 deletion tzrec/ops/benchmarks/hstu_attention_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def _get_kernel(provider: str) -> Kernel:
return Kernel.TRITON
elif provider == "pytorch":
return Kernel.PYTORCH
elif provider == "cutlass":
return Kernel.CUTLASS
else:
raise ValueError(f"Unknown provider {provider}")

Expand Down Expand Up @@ -93,6 +95,7 @@ def _flops(
@click.option("--bench-backward", type=bool, default=True)
@click.option("--bench-forward", type=bool, default=True)
@click.option("--bench-pytorch", type=bool, default=False)
@click.option("--bench-cutlass", type=bool, default=False)
@click.option("--report-flops", type=bool, default=False)
@click.option("--return-result", type=bool, default=False)
@click.option("--max-attn-len", type=int, default=0)
Expand All @@ -112,6 +115,7 @@ def main( # noqa: C901
bench_backward: bool,
bench_forward: bool,
bench_pytorch: bool,
bench_cutlass: bool,
report_flops: bool,
return_result: bool,
max_attn_len: int,
Expand All @@ -132,6 +136,10 @@ def main( # noqa: C901
line_vals = ["triton"]
line_names = ["Triton"]
styles = [("red", "-")]
if bench_cutlass:
line_vals.append("cutlass")
line_names.append("CUTLASS")
styles.append(("blue", "-"))
if bench_pytorch:
line_vals.append("pytorch")
line_names.append("PyTorch")
Expand Down Expand Up @@ -252,7 +260,7 @@ def _bench_hstu_attention(
q = q.requires_grad_(True)
k = k.requires_grad_(True)
v = v.requires_grad_(True)
assert provider in ["triton", "pytorch"]
assert provider in ["triton", "pytorch", "cutlass"]
if has_delta_q:
fn = lambda: delta_hstu_mha( # noqa E731
max_seq_len=seq_len,
Expand Down
31 changes: 24 additions & 7 deletions tzrec/ops/hstu_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,15 @@ def hstu_mha(
torch._assert(v.shape[1] == H, "wrong v shape[1]")
torch._assert(causal, "only support causal attention")

if kernel in [Kernel.TRITON]:
if not is_fx_tracing() and kernel == Kernel.TRITON:
if kernel in [Kernel.TRITON, Kernel.CUTLASS]:
if not is_fx_tracing():
torch._assert(q.is_cuda, "q must be CUDA tensor")
torch._assert(k.is_cuda, "k must be CUDA tensor")
torch._assert(v.is_cuda, "v must be CUDA tensor")
torch._assert(seq_offsets.is_cuda, "seq_offsets must be CUDA tensor")
torch._assert(dropout_pr < 1e-6, "dropout for triton path not implemented")
torch._assert(
dropout_pr < 1e-6, "dropout for triton/cutlass not implemented"
)
torch._assert(
min_full_attn_seq_len == 0, "min_full_attn_seq_len not implemented"
)
Expand All @@ -69,7 +71,22 @@ def hstu_mha(
v = switch_to_contiguous_if_needed(v)
seq_offsets = seq_offsets.contiguous()

if kernel == Kernel.TRITON:
if kernel == Kernel.CUTLASS:
from tzrec.ops._cuda.cutlass_hstu_attention import cutlass_hstu_mha

return cutlass_hstu_mha(
max_seq_len=max_seq_len,
alpha=alpha,
q=q,
k=k,
v=v,
seq_offsets=seq_offsets,
causal=causal,
num_targets=num_targets,
max_attn_len=max_attn_len,
contextual_seq_len=contextual_seq_len,
)
elif kernel == Kernel.TRITON:
from tzrec.ops._triton.triton_hstu_attention import triton_hstu_mha

return triton_hstu_mha(
Expand Down Expand Up @@ -129,8 +146,8 @@ def delta_hstu_mha(
torch._assert(k.shape[2] == D, "wrong k shape[2]")
torch._assert(v.dim() == 3, "v must be 3-D")
torch._assert(v.shape[1] == H, "wrong v shape[1]")
if kernel in [Kernel.TRITON]:
if not is_fx_tracing() and kernel == Kernel.TRITON:
if kernel in [Kernel.TRITON, Kernel.CUTLASS]:
if not is_fx_tracing():
torch._assert(delta_q.is_cuda, "q must be CUDA tensor")
torch._assert(seq_offsets.is_cuda, "seq_offsets must be CUDA tensor")
if num_targets is not None:
Expand All @@ -140,7 +157,7 @@ def delta_hstu_mha(
k = switch_to_contiguous_if_needed(k)
v = switch_to_contiguous_if_needed(v)

if kernel == Kernel.TRITON:
if kernel in [Kernel.TRITON, Kernel.CUTLASS]:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: delta_hstu_mha bypasses the CUTLASS wrapper entirely

When kernel == Kernel.CUTLASS, this dispatches directly to triton_cached_hstu_mha, completely skipping the cutlass_cached_hstu_mha wrapper defined in cutlass_hstu_attention.py. This means:

  • Users selecting CUTLASS get zero feedback that delta attention silently falls back to Triton (the warning in cutlass_cached_hstu_mha is dead code)
  • The @torch.fx.wrap on cutlass_cached_hstu_mha is never invoked, which may affect FX tracing correctness
  • If CUTLASS later adds delta support, this dispatch won't route to it

Should add a separate CUTLASS branch here (like hstu_mha does) that routes through cutlass_cached_hstu_mha:

if kernel == Kernel.CUTLASS:
    from tzrec.ops._cuda.cutlass_hstu_attention import cutlass_cached_hstu_mha
    return cutlass_cached_hstu_mha(...)
elif kernel == Kernel.TRITON:
    from tzrec.ops._triton.triton_hstu_attention import triton_cached_hstu_mha
    return triton_cached_hstu_mha(...)

from tzrec.ops._triton.triton_hstu_attention import triton_cached_hstu_mha

return triton_cached_hstu_mha(
Expand Down
63 changes: 63 additions & 0 deletions tzrec/ops/hstu_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,69 @@ def test_cache(
real_delta_out,
)

@unittest.skipIf(*gpu_unavailable)
# pyre-ignore
@given(
    batch_size=st.integers(4, 8),
    heads=st.integers(1, 4),
    max_uih_len=st.sampled_from([20, 100, 128, 256]),
    max_targets=st.sampled_from([20, 512]),
    attn_dim=st.sampled_from([32, 64, 128]),
    causal=st.sampled_from([True]),
    has_multiple_targets=st.sampled_from([True, False]),
    dtype=st.sampled_from(get_test_dtypes([torch.bfloat16])),
    has_max_attn_len=st.sampled_from([True, False]),
    contextual_seq_len=st.sampled_from([0, 10]),
)
@settings(
    verbosity=Verbosity.verbose,
    max_examples=20,
    deadline=None,
)
# pyre-ignore[2]
def test_attn_cutlass(self, *args, **kwargs) -> None:
    """Property-test the CUTLASS kernel against the PYTORCH reference."""
    # CUTLASS requires attention_dim == hidden_dim, so the sampled attn_dim
    # is reused for both dimensions.
    hidden_dim = kwargs.pop("attn_dim")
    test_attn(
        *args,
        **kwargs,
        attn_dim=hidden_dim,
        hidden_dim=hidden_dim,
        test_backward=True,
        ref_kernel=Kernel.PYTORCH,
        real_kernel=Kernel.CUTLASS,
    )

@unittest.skipIf(*gpu_unavailable)
# pyre-ignore
@given(
    batch_size=st.integers(4, 8),
    heads=st.integers(1, 4),
    max_uih_len=st.sampled_from([100, 128, 256]),
    max_targets=st.sampled_from([20, 512]),
    delta_size=st.sampled_from([20, 512]),
    attn_dim=st.sampled_from([32, 64, 128]),
    has_multiple_targets=st.sampled_from([True, False]),
    dtype=st.sampled_from(get_test_dtypes([torch.bfloat16])),
    has_max_attn_len=st.sampled_from([False, True]),
    contextual_seq_len=st.sampled_from([0, 10]),
)
@settings(
    verbosity=Verbosity.verbose,
    max_examples=20,
    deadline=None,
)
# pyre-ignore[2]
def test_delta_attn_cutlass(self, *args, **kwargs) -> None:
    """Property-test delta-query attention with the CUTLASS kernel selected.

    NOTE(review): cutlass_cached_hstu_mha documents that the CUTLASS kernel
    does not support the delta-query pattern, so this path presumably ends up
    on the Triton fallback — confirm against the dispatch in hstu_attention.py.
    """
    # CUTLASS requires attention_dim == hidden_dim, so the sampled attn_dim
    # is reused for both dimensions.
    hidden_dim = kwargs.pop("attn_dim")
    test_delta_attn(
        *args,
        **kwargs,
        attn_dim=hidden_dim,
        hidden_dim=hidden_dim,
        ref_kernel=Kernel.PYTORCH,
        real_kernel=Kernel.CUTLASS,
    )


# Allow running this test module directly (e.g. `python hstu_attention_test.py`).
if __name__ == "__main__":
    unittest.main()
Loading
Loading