-
Notifications
You must be signed in to change notification settings - Fork 63
[feat] add CUTLASS kernel backend for HSTU attention #465
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
e624365
410210f
adb3c60
322aa91
23dfe50
6fb33eb
ab97406
06645a1
e114ab6
6028d32
d3d079d
785befc
07d27e8
97e3e25
c437893
590bb4f
9956957
9f39c32
a22f36b
f7d020b
27420b0
c4898fe
f8bd059
ef3e24e
55b33e4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,7 @@ | ||
| dynamicemb @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/dynamicemb/cu129/dynamicemb-0.0.1%2B20260331.bea6b4b.cu129-cp310-cp310-linux_x86_64.whl ; python_version=="3.10" | ||
| dynamicemb @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/dynamicemb/cu129/dynamicemb-0.0.1%2B20260331.bea6b4b.cu129-cp311-cp311-linux_x86_64.whl ; python_version=="3.11" | ||
| dynamicemb @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/dynamicemb/cu129/dynamicemb-0.0.1%2B20260331.bea6b4b.cu129-cp312-cp312-linux_x86_64.whl ; python_version=="3.12" | ||
| hstu_attn @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/hstu/cu129/hstu_attn-0.1.0%2Bbea6b4b.cu12.9-cp310-cp310-linux_x86_64.whl ; python_version=="3.10" | ||
| hstu_attn @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/hstu/cu129/hstu_attn-0.1.0%2Bbea6b4b.cu12.9-cp311-cp311-linux_x86_64.whl ; python_version=="3.11" | ||
| hstu_attn @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/hstu/cu129/hstu_attn-0.1.0%2Bbea6b4b.cu12.9-cp312-cp312-linux_x86_64.whl ; python_version=="3.12" | ||
| torch_fx_tool @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/rtp/torch_fx_tool-0.0.1%2B20251201.8c109c4-py3-none-any.whl |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,4 +18,4 @@ class Kernel(Enum): | |
|
|
||
| TRITON = "TRITON" | ||
| PYTORCH = "PYTORCH" | ||
| CUDA = "CUDA" | ||
| CUTLASS = "CUTLASS" | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| # Copyright (c) 2025, Alibaba Group; | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,177 @@ | ||||||||||||||||||
| # Copyright (c) 2025, Alibaba Group; | ||||||||||||||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||
| # you may not use this file except in compliance with the License. | ||||||||||||||||||
| # You may obtain a copy of the License at | ||||||||||||||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||
|
|
||||||||||||||||||
| import logging | ||||||||||||||||||
| from typing import Optional | ||||||||||||||||||
|
|
||||||||||||||||||
| import torch | ||||||||||||||||||
| from hstu_attn import hstu_attn_varlen_func | ||||||||||||||||||
|
|
||||||||||||||||||
| logger = logging.getLogger(__name__) | ||||||||||||||||||
|
|
||||||||||||||||||
| _triton_fallback_warned = False | ||||||||||||||||||
| _cached_fallback_warned = False | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def _needs_triton_fallback( | ||||||||||||||||||
| max_attn_len: int, | ||||||||||||||||||
| contextual_seq_len: int, | ||||||||||||||||||
| num_targets: Optional[torch.Tensor], | ||||||||||||||||||
| ) -> bool: | ||||||||||||||||||
| """Check if we need to fall back to triton. | ||||||||||||||||||
|
|
||||||||||||||||||
| The CUTLASS kernel does not support combining local window attention | ||||||||||||||||||
| (max_attn_len > 0) with context or target masking. | ||||||||||||||||||
| """ | ||||||||||||||||||
| global _triton_fallback_warned | ||||||||||||||||||
| has_local_window = max_attn_len > 0 | ||||||||||||||||||
| has_context_or_target = contextual_seq_len > 0 or num_targets is not None | ||||||||||||||||||
| needs_fallback = has_local_window and has_context_or_target | ||||||||||||||||||
| if needs_fallback and not _triton_fallback_warned: | ||||||||||||||||||
| logger.warning( | ||||||||||||||||||
| "CUTLASS kernel does not support combining local window attention " | ||||||||||||||||||
| "(max_attn_len > 0) with context/target masking, " | ||||||||||||||||||
| "falling back to Triton kernel." | ||||||||||||||||||
| ) | ||||||||||||||||||
| _triton_fallback_warned = True | ||||||||||||||||||
| return needs_fallback | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| @torch.fx.wrap | ||||||||||||||||||
| def cutlass_hstu_mha( | ||||||||||||||||||
| max_seq_len: int, | ||||||||||||||||||
| alpha: float, | ||||||||||||||||||
| q: torch.Tensor, | ||||||||||||||||||
| k: torch.Tensor, | ||||||||||||||||||
| v: torch.Tensor, | ||||||||||||||||||
| seq_offsets: torch.Tensor, | ||||||||||||||||||
| causal: bool = True, | ||||||||||||||||||
| num_targets: Optional[torch.Tensor] = None, | ||||||||||||||||||
| max_attn_len: int = 0, | ||||||||||||||||||
| contextual_seq_len: int = 0, | ||||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||||
| """CUTLASS-based HSTU multi-head attention. | ||||||||||||||||||
|
|
||||||||||||||||||
| Args: | ||||||||||||||||||
| max_seq_len: maximum sequence length in the batch. | ||||||||||||||||||
| alpha: scaling factor for attention scores. | ||||||||||||||||||
| q: query tensor of shape (total, nheads, attn_dim). | ||||||||||||||||||
| k: key tensor of shape (total, nheads, attn_dim). | ||||||||||||||||||
| v: value tensor of shape (total, nheads, hidden_dim). | ||||||||||||||||||
| seq_offsets: cumulative sequence offsets of shape (batch_size + 1,). | ||||||||||||||||||
| causal: whether to apply causal masking. | ||||||||||||||||||
| num_targets: number of target tokens per batch element. | ||||||||||||||||||
| max_attn_len: maximum attention window length (0 means unlimited). | ||||||||||||||||||
| contextual_seq_len: number of contextual tokens per sequence. | ||||||||||||||||||
|
|
||||||||||||||||||
| Returns: | ||||||||||||||||||
| output tensor of shape (total, nheads, hidden_dim). | ||||||||||||||||||
| """ | ||||||||||||||||||
| if _needs_triton_fallback(max_attn_len, contextual_seq_len, num_targets): | ||||||||||||||||||
| from tzrec.ops._triton.triton_hstu_attention import triton_hstu_mha | ||||||||||||||||||
|
|
||||||||||||||||||
| return triton_hstu_mha( | ||||||||||||||||||
| max_seq_len=max_seq_len, | ||||||||||||||||||
| alpha=alpha, | ||||||||||||||||||
| q=q, | ||||||||||||||||||
| k=k, | ||||||||||||||||||
| v=v, | ||||||||||||||||||
| seq_offsets=seq_offsets, | ||||||||||||||||||
| causal=causal, | ||||||||||||||||||
| num_targets=num_targets, | ||||||||||||||||||
| max_attn_len=max_attn_len, | ||||||||||||||||||
| contextual_seq_len=contextual_seq_len, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| assert q.shape[2] == v.shape[2], ( | ||||||||||||||||||
| f"CUTLASS hstu_attn requires attention_dim == hidden_dim, " | ||||||||||||||||||
| f"got q.shape[2]={q.shape[2]} != v.shape[2]={v.shape[2]}" | ||||||||||||||||||
| ) | ||||||||||||||||||
|
||||||||||||||||||
| assert q.shape[2] == v.shape[2], ( | |
| f"CUTLASS hstu_attn requires attention_dim == hidden_dim, " | |
| f"got q.shape[2]={q.shape[2]} != v.shape[2]={v.shape[2]}" | |
| ) | |
| torch._assert(q.shape[2] == v.shape[2], | |
| f"CUTLASS hstu_attn requires attention_dim == hidden_dim, " | |
| f"got q.shape[2]={q.shape[2]} != v.shape[2]={v.shape[2]}" | |
| ) |
Also worth validating this earlier at config parse time so users get a clear error before any GPU work begins.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor: int32 cast could silently wrap on very large cumulative offsets
If seq_offsets[-1] exceeds 2^31 - 1 (~2.1B tokens), the .to(torch.int32) silently wraps, producing negative offsets → out-of-bounds CUDA memory access. Unlikely in typical recommendation workloads, but a defensive check would prevent silent corruption:
if not is_fx_tracing():
torch._assert(seq_offsets[-1] <= torch.iinfo(torch.int32).max,
"seq_offsets values exceed int32 range")| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -54,13 +54,15 @@ def hstu_mha( | |
| torch._assert(v.shape[1] == H, "wrong v shape[1]") | ||
| torch._assert(causal, "only support causal attention") | ||
|
|
||
| if kernel in [Kernel.TRITON]: | ||
| if not is_fx_tracing() and kernel == Kernel.TRITON: | ||
| if kernel in [Kernel.TRITON, Kernel.CUTLASS]: | ||
| if not is_fx_tracing(): | ||
| torch._assert(q.is_cuda, "q must be CUDA tensor") | ||
| torch._assert(k.is_cuda, "k must be CUDA tensor") | ||
| torch._assert(v.is_cuda, "v must be CUDA tensor") | ||
| torch._assert(seq_offsets.is_cuda, "seq_offsets must be CUDA tensor") | ||
| torch._assert(dropout_pr < 1e-6, "dropout for triton path not implemented") | ||
| torch._assert( | ||
| dropout_pr < 1e-6, "dropout for triton/cutlass not implemented" | ||
| ) | ||
| torch._assert( | ||
| min_full_attn_seq_len == 0, "min_full_attn_seq_len not implemented" | ||
| ) | ||
|
|
@@ -69,7 +71,22 @@ def hstu_mha( | |
| v = switch_to_contiguous_if_needed(v) | ||
| seq_offsets = seq_offsets.contiguous() | ||
|
|
||
| if kernel == Kernel.TRITON: | ||
| if kernel == Kernel.CUTLASS: | ||
| from tzrec.ops._cuda.cutlass_hstu_attention import cutlass_hstu_mha | ||
|
|
||
| return cutlass_hstu_mha( | ||
| max_seq_len=max_seq_len, | ||
| alpha=alpha, | ||
| q=q, | ||
| k=k, | ||
| v=v, | ||
| seq_offsets=seq_offsets, | ||
| causal=causal, | ||
| num_targets=num_targets, | ||
| max_attn_len=max_attn_len, | ||
| contextual_seq_len=contextual_seq_len, | ||
| ) | ||
| elif kernel == Kernel.TRITON: | ||
| from tzrec.ops._triton.triton_hstu_attention import triton_hstu_mha | ||
|
|
||
| return triton_hstu_mha( | ||
|
|
@@ -129,8 +146,8 @@ def delta_hstu_mha( | |
| torch._assert(k.shape[2] == D, "wrong k shape[2]") | ||
| torch._assert(v.dim() == 3, "v must be 3-D") | ||
| torch._assert(v.shape[1] == H, "wrong v shape[1]") | ||
| if kernel in [Kernel.TRITON]: | ||
| if not is_fx_tracing() and kernel == Kernel.TRITON: | ||
| if kernel in [Kernel.TRITON, Kernel.CUTLASS]: | ||
| if not is_fx_tracing(): | ||
| torch._assert(delta_q.is_cuda, "q must be CUDA tensor") | ||
| torch._assert(seq_offsets.is_cuda, "seq_offsets must be CUDA tensor") | ||
| if num_targets is not None: | ||
|
|
@@ -140,7 +157,7 @@ def delta_hstu_mha( | |
| k = switch_to_contiguous_if_needed(k) | ||
| v = switch_to_contiguous_if_needed(v) | ||
|
|
||
| if kernel == Kernel.TRITON: | ||
| if kernel in [Kernel.TRITON, Kernel.CUTLASS]: | ||
|
||
| from tzrec.ops._triton.triton_hstu_attention import triton_cached_hstu_mha | ||
|
|
||
| return triton_cached_hstu_mha( | ||
|
|
||
[Note: the remainder of the pull-request diff failed to load when this page was captured, so the rest of the changes are not shown here.]