[feat] add CUTLASS kernel backend for HSTU attention#465
[feat] add CUTLASS kernel backend for HSTU attention#465tiankongdeguiji merged 25 commits intoalibaba:masterfrom
Conversation
Integrate the hstu_attn CUTLASS-based fused attention kernel as a new Kernel.CUTLASS backend option for DlrmHSTU. This provides an alternative high-performance attention path using CUTLASS kernels that support Ampere, Ada, and Hopper GPUs. - Rename Kernel.CUDA to Kernel.CUTLASS in ops enum and proto - Add cutlass_hstu_mha wrapper adapting hstu_attn_varlen_func - Add CUTLASS dispatch in hstu_mha() and delta_hstu_mha() - Add unit tests and integration test with CUTLASS config - Update docs with CUTLASS backend usage and FAQ Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
When local window attention (max_attn_len > 0) is combined with context or target masking, fall back to Triton since the CUTLASS kernel does not support this combination. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
| assert q.shape[2] == v.shape[2], ( | ||
| f"CUTLASS hstu_attn requires attention_dim == hidden_dim, " | ||
| f"got q.shape[2]={q.shape[2]} != v.shape[2]={v.shape[2]}" | ||
| ) |
There was a problem hiding this comment.
Nit: bare assert is stripped under python -O
This constraint depends on user model config (attention_dim vs hidden_dim), so it should survive optimized mode. The rest of the codebase uses torch._assert for runtime validation. Consider switching for consistency:
| assert q.shape[2] == v.shape[2], ( | |
| f"CUTLASS hstu_attn requires attention_dim == hidden_dim, " | |
| f"got q.shape[2]={q.shape[2]} != v.shape[2]={v.shape[2]}" | |
| ) | |
| torch._assert(q.shape[2] == v.shape[2], | |
| f"CUTLASS hstu_attn requires attention_dim == hidden_dim, " | |
| f"got q.shape[2]={q.shape[2]} != v.shape[2]={v.shape[2]}" | |
| ) |
Also worth validating this earlier at config parse time so users get a clear error before any GPU work begins.
| f"got q.shape[2]={q.shape[2]} != v.shape[2]={v.shape[2]}" | ||
| ) | ||
|
|
||
| cu_seqlens = seq_offsets.to(torch.int32) |
There was a problem hiding this comment.
Minor: int32 cast could silently wrap on very large cumulative offsets
If seq_offsets[-1] exceeds 2^31 - 1 (~2.1B tokens), the .to(torch.int32) silently wraps, producing negative offsets → out-of-bounds CUDA memory access. Unlikely in typical recommendation workloads, but a defensive check would prevent silent corruption:
if not is_fx_tracing():
torch._assert(seq_offsets[-1] <= torch.iinfo(torch.int32).max,
"seq_offsets values exceed int32 range")| @unittest.skipIf(*gpu_unavailable) | ||
| def test_rank_dlrm_hstu_cutlass_train_eval_export(self): | ||
| self.success = utils.test_train_eval( | ||
| "tzrec/tests/configs/dlrm_hstu_cutlass_kuairand_1k.config", self.test_dir |
There was a problem hiding this comment.
Test name says "export" but doesn't test export
Other HSTU integration tests (e.g., the AOT variant above) call test_export after train+eval. Since CUTLASS introduces a new @torch.fx.wrap function, verifying export/tracing works correctly seems especially valuable. Either add the export step or rename to test_rank_dlrm_hstu_cutlass_train_eval.
tzrec/ops/hstu_attention.py
Outdated
| v = switch_to_contiguous_if_needed(v) | ||
|
|
||
| if kernel == Kernel.TRITON: | ||
| if kernel in [Kernel.TRITON, Kernel.CUTLASS]: |
There was a problem hiding this comment.
Bug: delta_hstu_mha bypasses the CUTLASS wrapper entirely
When kernel == Kernel.CUTLASS, this dispatches directly to triton_cached_hstu_mha, completely skipping the cutlass_cached_hstu_mha wrapper defined in cutlass_hstu_attention.py. This means:
- Users selecting CUTLASS get zero feedback that delta attention silently falls back to Triton (the warning in
cutlass_cached_hstu_mhais dead code) - The
@torch.fx.wraponcutlass_cached_hstu_mhais never invoked, which may affect FX tracing correctness - If CUTLASS later adds delta support, this dispatch won't route to it
Should add a separate CUTLASS branch here (like hstu_mha does) that routes through cutlass_cached_hstu_mha:
if kernel == Kernel.CUTLASS:
from tzrec.ops._cuda.cutlass_hstu_attention import cutlass_cached_hstu_mha
return cutlass_cached_hstu_mha(...)
elif kernel == Kernel.TRITON:
from tzrec.ops._triton.triton_hstu_attention import triton_cached_hstu_mha
return triton_cached_hstu_mha(...)
Code Review SummaryNice work integrating the CUTLASS kernel backend! The wrapper design with graceful Triton fallback is well thought out, and the hypothesis-based tests provide good coverage. A few issues to address: Must Fix
Should Fix
Nice to Have
🤖 Generated with Claude Code |
- Move local-window+context/target fallback from cutlass_hstu_attention.py to hstu_attention.py dispatch layer so sort_by_length and enable_tma are preserved when falling back to Triton - Move cached/delta fallback to hstu_attention.py dispatch layer - Remove cutlass_cached_hstu_mha (no longer needed) - Add test_export and test_predict to CUTLASS integration test to verify FX tracing / AOT export with @torch.fx.wrap - Replace bare assert with raise ValueError for python -O safety Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…port The previous @torch.fx.wrap only worked for FX tracing. torch.export (used by AOT Inductor) uses torchdynamo which bypassed the FX leaf hint and tried to trace into hstu_attn_cuda.varlen_fwd C++ extension, causing AOT export to fail. Fix by splitting into forward/backward torch.library custom_ops with register_autograd wiring, so torch.export treats the whole call as an opaque op while backward still works for training. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Two issues preventing CUTLASS test_export from passing: 1. FX symbolic_trace broke on switch_to_contiguous_if_needed() control flow inside the shared TRITON/CUTLASS preprocessing block. Move contiguous/assertion preprocessing to the TRITON-only path; CUTLASS now dispatches directly since cutlass_hstu_mha is @torch.fx.wrap'd and handles its own contiguous internally. 2. The CUTLASS kernel only supports fp16/bf16, but during the export eager run tensors flow as fp32 (autocast doesn't propagate through the traced graph). Add unconditional .to(bfloat16) casts before the CUTLASS call in hstu_mha — unconditional casts are FX-traceable, get baked into the traced graph, and are no-ops at runtime when inputs are already bf16. Cast the output back to the original dtype so downstream layers see matching dtypes. Validated locally: integration test now reaches the AOT Inductor compile stage (past the actual model execution), which is the furthest a local environment can get without the libcuda.so stub. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…xport Replace the targeted .to(bfloat16) workaround in hstu_mha with a proper, framework-level AutocastWrapper applied around the FX-traced sparse/dense sub-graphs in tzrec/acc/aot_utils.py and tzrec/acc/trt_utils.py. Why the previous explicit cast was a workaround: train_config.mixed_precision was being configured to BF16 but never reached the export path, because torch.amp.autocast is a runtime dispatcher feature that FX symbolic_trace cannot capture. Wrapping the raw model with autocast and then tracing simply loses the autocast context before the scripted/exported artifact is produced. The fix: wrap the FX-traced sub-graphs AFTER tracing, just before torch.jit.script / torch.export.export. This way: - For sparse path (torch.jit.script): AutocastWrapper.forward has a `with torch.autocast(dtype=<literal>)` block branched on a Final[int] compile-time flag. TorchScript honors the block because the dtype is a literal in each branch, so autocast is active at scripted-execution time. - For dense path (torch.export + AOT Inductor): DenseAutocastWrapper puts the `with torch.autocast(...)` inside nn.Module.forward. torch.export captures it as a `wrap_with_autocast` Higher Order Op, which AOT Inductor lowers to proper dtype casts in the compiled artifact. - For the initial eager run in export_util.py (before any tracing), a plain torch.amp.autocast context manager around the model() call handles the eager path. This means the CUTLASS kernel (and any future dtype-sensitive kernel) gets bf16 q/k/v naturally because the upstream linear layers run under autocast and produce bf16 outputs. The explicit per-kernel cast is no longer needed. mixed_precision is threaded from pipeline_config.train_config.mixed_precision → export_model_normal → export_model_aot / export_model_trt → wrappers. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
AOT-packaged models reference custom ops by their qualified name (e.g. ``tzrec::cutlass_hstu_mha_fwd``). PyTorch only knows about an op once its registering Python module has been imported. During export, the module is imported lazily via the dispatch in hstu_mha, which is sufficient for the export phase. But during predict, the aoti_load_package call happens without the dispatch ever running, so PyTorch raises ``RuntimeError: Could not find schema for tzrec::cutlass_hstu_mha_fwd``. Fix by eagerly importing the cutlass module from aot_utils.py (wrapped in try/except so the import is optional for environments without the hstu_attn native dependency). aot_utils is imported by both tzrec/main.py export and predict paths, so the custom op gets registered before aoti_load_package runs. Verified locally: the integration test's export + predict now reach the inference phase — model loads cleanly, first forward call produces valid predictions. Any remaining local-only flakiness is env-specific (stub libcuda.so vs real driver under worker threads). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Reproduced the CI predict timeout inside the CI Docker image
(mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/tzrec-devel:1.1)
and used py-spy to dump the hung predict workers. Stack trace showed two
``_forward_loop`` threads stuck inside the AOT-Inductor compiled model
``__call__``, one of them blocked at the return of ``_cutlass_hstu_mha_fwd``
through this dispatch chain:
__call__ (torch/export/pt2_archive/_package.py)
autograd_impl (torch/_library/autograd.py)
forward_no_grad (torch/_library/autograd.py)
backend_impl (torch/_library/custom_ops.py)
inner (torch/_compile.py) # @torch.compiler.disable
wrapped_fn (torch/_library/custom_ops.py)
_cutlass_hstu_mha_fwd
The other thread was blocked at the entry of ``__call__``. The combination
of ``@torch.library.custom_op``'s built-in ``@torch.compiler.disable``
wrapper plus the autograd_impl dispatch wrapper plus AOTI's compiled-model
dispatch deadlocks under concurrent calls from the predict ``_forward_loop``
worker threads. Triton's path doesn't hit this because ``triton_op`` is a
different code path that doesn't add those wrappers.
Switch the cutlass op registration from ``@torch.library.custom_op`` to
the lower-level ``torch.library.Library`` / ``torch.library.define`` /
``Library.impl`` API. This avoids the ``@torch.compiler.disable`` wrapper
and the autograd_impl wrapper layer entirely. Verified locally that the
op now has only ``CUDA`` and ``Meta`` dispatch keys (no ``Autograd``,
no ``AutogradCUDA``).
Backward is still wired up via a Python ``torch.autograd.Function`` that
calls ``torch.ops.tzrec.cutlass_hstu_mha_fwd`` in forward and
``torch.ops.tzrec.cutlass_hstu_mha_bwd`` in backward. The top-level
``cutlass_hstu_mha`` dispatcher routes through the autograd Function only
when ``torch.is_grad_enabled()`` and an input ``requires_grad``; under
no_grad / inference / FX-traced graphs it calls the bare op directly,
which keeps the inference dispatch a single layer deep and removes the
deadlock.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ecision The sparse sub-graph produced by split_model is purely embedding lookups, which don't have autocast rules and don't benefit from AMP. Wrapping sparse in an AutocastWrapper before jit.script added compile-time dtype constant handling complexity (Final[int] flag, literal-dtype branches) for zero runtime benefit. Drop AutocastWrapper entirely and keep only DenseAutocastWrapper for the dense sub-graph, which is where the CUTLASS attention op actually lives after the sparse/dense split. Also add an ``export_config.mixed_precision`` field so users can opt into AMP for export/inference independently of training. Resolution order in export_util: export_config.mixed_precision if set, else train_config.mixed_precision. Changes: - protos/export.proto: add optional ``mixed_precision`` field - models/model.py: remove AutocastWrapper (keep DenseAutocastWrapper) - utils/export_util.py: resolve mixed_precision from export_config first, then fall back to train_config - acc/aot_utils.py: don't wrap sparse_model_traced before jit.script, drop the sparse-side eager autocast context - acc/trt_utils.py: same cleanup for the TRT path - tests/rank_integration_test.py: pin ``predict_threads=1`` in the cutlass integration test. hstu_attn_cuda's pybind11 binding does not release the GIL and the AOTI runtime deadlocks between the two predict forward worker threads; single-worker predict sidesteps this. Root cause to be addressed upstream in a follow-up. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…AOT branches Code review passes flagged several duplications in the previous mixed_precision refactor. Addresses them: - Add ``mixed_precision_to_dtype(str) -> Optional[torch.dtype]`` helper next to the existing quant-dtype helpers in ``tzrec/acc/utils.py``. Replaces the same if/elif block that lived in ``TrainWrapper``, ``PredictWrapper`` (both in ``tzrec/models/model.py``) and the previous ``export_model_normal`` resolution. - Add ``resolve_mixed_precision(pipeline_config) -> str`` helper in the same file. Encapsulates the ``export_config.mixed_precision or train_config.mixed_precision`` precedence rule. Removes the inline HasField/fallback block from ``export_util.py`` and a four-line narrative comment that just described what the code was doing. - Unify the ``acc_utils.is_trt()`` / ``acc_utils.is_aot()`` branches in ``export_model_normal``: the autocast eager-run, result logging, and ``split_model`` call were identical. Hoist them above the branch and keep only the terminal ``export_model_trt`` / ``export_model_aot`` call site-specific. - Expand the ``DenseAutocastWrapper`` docstring to explain why it stores an integer flag (``Final[int]``) instead of a ``torch.dtype`` attribute — ``torch.jit.script`` only honors a literal dtype in the ``torch.autocast`` call, so we need the branched dispatch. - Update the stale ``AutocastWrapper`` comment in ``tzrec/ops/hstu_attention.py`` to reference ``DenseAutocastWrapper`` (the wrapper was renamed away in the previous commit). Verified locally that imports are clean (no circular dep between ``tzrec.models.model`` and ``tzrec.acc.utils``), the helper handles None/empty/valid/invalid inputs correctly, and the ``resolve_mixed_precision`` precedence chain works as intended. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Match the URL convention used by other third-party wheels in requirements/extra.txt (dynamicemb, torch_fx_tool). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Kernel.CUTLASS only applies to the HSTU attention op itself (via the cutlass_hstu_mha custom op). All other ops inside the STU layer — layer_norm, addmm for the qkv projection, silu, and compute_output — have no CUTLASS implementation and were previously falling through to the PyTorch backend, which is much slower than the existing Triton implementations. Introduce a sub_op_kernel(kernel) helper in tzrec/ops/hstu_compute.py that maps Kernel.CUTLASS -> Kernel.TRITON and passes every other kernel through. Use it at the three call sites where the HSTU layer dispatches non-attention sub-ops: - hstu_preprocess_and_attention: for the hstu_compute_uqvk call (the attention call still receives the original kernel) - STULayer.forward: for the hstu_compute_output call - STULayer.cached_forward: for both hstu_compute_uqvk and hstu_compute_output (delta_hstu_mha already falls back internally) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…_op_kernel helper
Previous approach routed the fallback through a sub_op_kernel(kernel)
helper that callers had to remember to apply at the three call sites
in STULayer / hstu_preprocess_and_attention. That's fragile: new ops
added to the STU layer or new call sites would silently fall through
to the PyTorch backend whenever kernel == CUTLASS.
Instead, let each op's entry function handle the CUTLASS fallback
itself. At the top of every non-attention op we add:
if kernel == Kernel.CUTLASS:
kernel = Kernel.TRITON
This is local, obvious, and cannot be forgotten at call sites. It
also lets any direct caller (positional_encoder, preprocessors,
content_encoder, action_encoder, norm modules, etc.) use
`kernel=self.kernel()` with no special-casing — CUTLASS will
naturally run through Triton inside each op.
Ops updated (all take `kernel: Kernel` as a public parameter):
- tzrec/ops/mm.py: addmm
- tzrec/ops/layer_norm.py: layer_norm, rms_norm, swish_layer_norm
- tzrec/ops/position.py: add_positional_embeddings,
add_timestamp_positional_embeddings
- tzrec/ops/jagged_tensors.py: concat_2D_jagged, split_2D_jagged,
jagged_dense_bmm_broadcast_add
- tzrec/ops/hstu_compute.py: hstu_compute_uqvk, hstu_compute_output
(hstu_preprocess_and_attention stays as an orchestrator: its
inner hstu_compute_uqvk and hstu_mha calls already route CUTLASS
correctly via the per-op entries)
- tzrec/ops/hstu_attention.py: delta_hstu_mha's explicit
CUTLASS->TRITON fallback block replaced with the standard
two-liner at the top
hstu_mha keeps its CUTLASS branch: it is the only op with an actual
CUTLASS implementation (cutlass_hstu_mha).
Reverts the sub_op_kernel helper added in the previous commit and
restores stu.py's call sites to pass ``kernel=self.kernel()`` directly.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
| seq_offsets = seq_offsets.contiguous() | ||
| cu_seqlens = seq_offsets.to(torch.int32) |
There was a problem hiding this comment.
perf/correctness: The .contiguous() on seq_offsets (line 311) is wasted — the .to(torch.int32) on the next line always returns a new contiguous tensor regardless.
More importantly, the int32 cast silently overflows if cumulative offsets exceed INT32_MAX. Consider adding a bounds check:
| seq_offsets = seq_offsets.contiguous() | |
| cu_seqlens = seq_offsets.to(torch.int32) | |
| cu_seqlens = seq_offsets.to(torch.int32) |
And optionally guard with:
if seq_offsets.max() > torch.iinfo(torch.int32).max:
raise ValueError(f"cu_seqlens exceed int32 range: max={seq_offsets.max().item()}")| None, # last_page_lens | ||
| None, # cu_seqlens_t | ||
| ) | ||
| return out[:, :, :head_dim].reshape(-1, num_heads, head_dim).contiguous() |
There was a problem hiding this comment.
perf: .reshape() on a non-contiguous input (from the slice) already returns a new contiguous tensor, so the trailing .contiguous() is a no-op. If the shape isn't actually changing (N is already the first dim), this can be simplified:
| return out[:, :, :head_dim].reshape(-1, num_heads, head_dim).contiguous() | |
| return out[:, :, :head_dim].contiguous() |
| window_size_right: int, | ||
| alpha: float, | ||
| ) -> torch.Tensor: | ||
| return torch.empty_like(q) |
There was a problem hiding this comment.
correctness: The meta function returns torch.empty_like(q), which happens to be correct today because cutlass_hstu_mha enforces q.shape[2] == v.shape[2]. However, semantically the output shape derives from v's hidden_dim, not q's attn_dim. If the constraint is ever relaxed, this meta would silently return the wrong shape during tracing.
Consider using torch.empty_like(v) (or at minimum mirroring the q.shape[2] == v.shape[2] assertion here) to make the contract explicit.
tzrec/ops/hstu_attention.py
Outdated
| ) | ||
| from tzrec.ops.utils import switch_to_contiguous_if_needed | ||
|
|
||
| logger = logging.getLogger(__name__) |
There was a problem hiding this comment.
nit: The rest of the codebase uses from tzrec.utils.logging_util import logger. Using logging.getLogger(__name__) here creates a separate logger that won't inherit project-level formatting/handlers.
| logger = logging.getLogger(__name__) | |
| from tzrec.utils.logging_util import logger |
| if kernel == Kernel.CUTLASS: | ||
| kernel = Kernel.TRITON |
There was a problem hiding this comment.
test coverage / naming: delta_hstu_mha unconditionally falls back CUTLASS→Triton here, which means test_delta_attn_cutlass in the test file is actually testing Triton-vs-PyTorch equivalence — no CUTLASS code path is exercised. Consider either:
- Renaming the test to
test_delta_attn_cutlass_fallbackto make this clear, or - Adding a
logger.warninghere (matching the pattern inhstu_mha) so the fallback is observable.
Code Review SummaryWell-structured PR. The low-level Inline comments posted (6)
Additional observations (not posted inline)
|
- cutlass_hstu_attention.py (reshape return): drop the redundant trailing ``.contiguous()`` after ``.reshape(-1, num_heads, head_dim)``. When ``out[:, :, :head_dim]`` is non-contiguous (attn_dim < out_dim) ``.reshape`` already allocates a contiguous copy; when the slice is a no-op the reshape is a view of an already-contiguous tensor. The extra ``.contiguous()`` was always a no-op. - cutlass_hstu_attention.py (meta impl): change the fake/meta kernel for ``cutlass_hstu_mha_fwd`` from ``empty_like(q)`` to ``empty_like(v)``. The output shape is ``(total, nheads, hidden_dim)`` which matches v; keying the fake off v keeps the meta correct if the current ``attention_dim == hidden_dim`` constraint is ever relaxed. - cutlass_hstu_attention.py (seq_offsets prep): drop the redundant ``seq_offsets.contiguous()`` before ``.to(torch.int32)`` — the dtype cast always allocates a fresh contiguous tensor. Document the int32 cu_seqlens limit in the docstring so users hitting > 2**31 tokens get an actionable reference (no runtime ``.item()``/GPU-sync guard, which would either break torch.export or add per-call latency). - hstu_attention.py: use ``tzrec.utils.logging_util.logger`` instead of a per-module ``logging.getLogger(__name__)`` to match the project convention. Drop the now-unused ``import logging``. - hstu_attention_test.py: remove ``test_delta_attn_cutlass``. The name was misleading: ``delta_hstu_mha`` has no CUTLASS implementation, it falls back to Triton at the op entry, so the test was silently exercising the Triton path. ``test_delta_attn_triton`` already covers that path. Replaced with a NOTE comment explaining why. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Summary
hstu_attnCUTLASS-based fused attention kernel as a newKernel.CUTLASSbackend for DlrmHSTUKernel.CUDAenum toKernel.CUTLASSin both Python and protobuf definitionscutlass_hstu_mhawrapper intzrec/ops/_cuda/that adapts TorchEasyRec's interface tohstu_attn_varlen_funcattention_dim == hidden_dim; cached inference falls back to TritonTest plan
pytest tzrec/ops/hstu_attention_test.py -k "cutlass" -vpytest tzrec/tests/rank_integration_test.py -k "dlrm_hstu_cutlass" -v🤖 Generated with Claude Code