Commit e0bae00

Port softmax ops to libtorch stable ABI
Proof of concept for migrating pybind11 functions to the PyTorch stable ABI. Ports all 8 scaled softmax functions:

- Add stable_common.h with stable ABI helpers (tensor allocation, TensorWrapper construction, CUDA stream, dtype converters)
- Add registration.cpp with STABLE_TORCH_LIBRARY schema definitions
- Rewrite softmax.cpp: at::Tensor -> torch::stable::Tensor, use stable allocation and stream APIs, TORCH_BOX() for impl registration
- Remove softmax registrations from pybind.cpp
- Update Python callers to use torch.ops.transformer_engine

The pattern is mechanical (API translation, no logic changes) and establishes the template for porting the remaining ~70 Category A functions that have no py::handle/py::object dependencies.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 9d77dcb commit e0bae00
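
For context, the port pattern the message describes looks roughly like the sketch below. This is a hedged illustration rather than the commit's literal softmax.cpp: the stable_common helper names (allocate_like, make_nvte_tensor, get_current_stream) are hypothetical stand-ins for the stable_common.h utilities, and the stable-ABI surface assumed is that of torch >= 2.10.

// Sketch of the port pattern (hypothetical stable_common helper names).
#include <torch/csrc/stable/library.h>   // STABLE_TORCH_LIBRARY_IMPL, TORCH_BOX
#include <torch/csrc/stable/tensor.h>    // torch::stable::Tensor

#include <transformer_engine/softmax.h>  // nvte_scaled_softmax_forward

namespace transformer_engine::pytorch {

// Before: at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor);
// After: the same function in stable-ABI types. The schema's `float`
// arrives as a C++ double in the unboxed signature.
torch::stable::Tensor scaled_softmax_forward(torch::stable::Tensor input, double scale_factor) {
  // Allocate the output through the stable allocation helpers instead of
  // at::empty_like (allocate_like is a hypothetical stable_common.h utility).
  torch::stable::Tensor softmax_results = stable_common::allocate_like(input);

  // Wrap both tensors as NVTETensor and launch on the current CUDA stream,
  // fetched via the stable stream API rather than at::cuda::getCurrentCUDAStream.
  auto te_input = stable_common::make_nvte_tensor(input);
  auto te_output = stable_common::make_nvte_tensor(softmax_results);
  nvte_scaled_softmax_forward(te_input.data(), te_output.data(),
                              static_cast<float>(scale_factor),
                              stable_common::get_current_stream(input));
  return softmax_results;
}

}  // namespace transformer_engine::pytorch

// TORCH_BOX adapts the typed function to the boxed stable calling convention;
// the matching schema lives in registration.cpp (shown further down).
STABLE_TORCH_LIBRARY_IMPL(transformer_engine, CUDA, m) {
  m.impl("scaled_softmax_forward",
         TORCH_BOX(&transformer_engine::pytorch::scaled_softmax_forward));
}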

File tree: 10 files changed, +331 −211 lines

build_tools/pytorch.py

Lines changed: 9 additions & 1 deletion
@@ -14,7 +14,15 @@

 def install_requirements() -> List[str]:
     """Install dependencies for TE/PyTorch extensions."""
-    return ["torch>=2.1", "einops", "onnxscript", "onnx", "packaging", "pydantic", "nvdlfw-inspect"]
+    return [
+        "torch>=2.10",
+        "einops",
+        "onnxscript",
+        "onnx",
+        "packaging",
+        "pydantic",
+        "nvdlfw-inspect",
+    ]


 def test_requirements() -> List[str]:

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 # See LICENSE for license information.

 [build-system]
-requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"]
+requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.10", "jax>=0.5.0", "flax>=0.7.1"]

 # Use legacy backend to import local packages in setup.py
 build-backend = "setuptools.build_meta:__legacy__"

transformer_engine/pytorch/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 from transformer_engine.common import load_framework_extension
 from transformer_engine.pytorch.torch_version import torch_version

-assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}."
+assert torch_version() >= (2, 10), f"Minimum torch version 2.10 required. Found {torch_version()}."

 load_framework_extension("torch")
 from transformer_engine.pytorch.module import LayerNormLinear

transformer_engine/pytorch/attention/dot_product_attention/softmax.py

Lines changed: 10 additions & 9 deletions
@@ -7,9 +7,10 @@
 from typing import Callable, Tuple, Union, Optional
 import torch
 from torch import nn
-import transformer_engine_torch as tex
 from transformer_engine.pytorch.export import is_in_onnx_export_mode

+_ops = torch.ops.transformer_engine
+

 THREADS_PER_WARP = 32
 THREADS_PER_BLOCK = 128

@@ -47,7 +48,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
     def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
         """ScaledUpperTriangMaskedSoftmax fwd"""
         scale_t = torch.tensor([scale])
-        softmax_results = tex.scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0])
+        softmax_results = _ops.scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0])

         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results

@@ -56,7 +57,7 @@ def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
     def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
         """ScaledUpperTriangMaskedSoftmax bwd"""
         softmax_results, scale_t = ctx.saved_tensors
-        input_grads = tex.scaled_upper_triang_masked_softmax_backward(
+        input_grads = _ops.scaled_upper_triang_masked_softmax_backward(
             output_grads, softmax_results, scale_t[0]
         )

@@ -75,15 +76,15 @@ class ScaledAlignedCausalMaskedSoftmax(torch.autograd.Function):
     def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
         """ScaledAlignedCausalMaskedSoftmax fwd"""
         scale_t = torch.tensor([scale])
-        softmax_results = tex.scaled_aligned_causal_masked_softmax_forward(inputs, scale_t[0])
+        softmax_results = _ops.scaled_aligned_causal_masked_softmax_forward(inputs, scale_t[0])
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results

     @staticmethod
     def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
         """ScaledAlignedCausalMaskedSoftmax bwd"""
         softmax_results, scale_t = ctx.saved_tensors
-        input_grads = tex.scaled_aligned_causal_masked_softmax_backward(
+        input_grads = _ops.scaled_aligned_causal_masked_softmax_backward(
             output_grads, softmax_results, scale_t[0]
         )

@@ -103,7 +104,7 @@ def forward(ctx, inputs: torch.Tensor, mask: torch.Tensor, scale: float) -> torch.Tensor:
         """ScaledMaskedSoftmax fwd"""
         scale_t = torch.tensor([scale])

-        softmax_results = tex.scaled_masked_softmax_forward(inputs, mask, scale_t[0])
+        softmax_results = _ops.scaled_masked_softmax_forward(inputs, mask, scale_t[0])
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results

@@ -112,7 +113,7 @@ def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
         """ScaledMaskedSoftmax bwd"""
         softmax_results, scale_t = ctx.saved_tensors

-        input_grads = tex.scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0])
+        input_grads = _ops.scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0])
         return input_grads, None, None

@@ -128,7 +129,7 @@ def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
         """ScaledSoftmax fwd"""
         scale_t = torch.tensor([scale])

-        softmax_results = tex.scaled_softmax_forward(inputs, scale_t[0])
+        softmax_results = _ops.scaled_softmax_forward(inputs, scale_t[0])
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results

@@ -137,7 +138,7 @@ def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
         """ScaledSoftmax bwd"""
         softmax_results, scale_t = ctx.saved_tensors

-        input_grads = tex.scaled_softmax_backward(output_grads, softmax_results, scale_t[0])
+        input_grads = _ops.scaled_softmax_backward(output_grads, softmax_results, scale_t[0])
         return input_grads, None, None
transformer_engine/pytorch/csrc/extensions.h

Lines changed: 0 additions & 26 deletions
@@ -349,32 +349,6 @@ py::object dropout_bwd(const at::Tensor &grad_output, const at::Tensor &mask,
                        const float dropout_probability,
                        std::optional<at::Tensor> grad_input = std::nullopt);

-/***************************************************************************************************
- * Softmax
- **************************************************************************************************/
-
-at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor);
-
-at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
-                                   float scale_factor);
-
-at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor);
-
-at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
-                                          float scale_factor);
-
-at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor);
-
-at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
-                                                       at::Tensor softmax_results_,
-                                                       float scale_factor);
-
-at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor);
-
-at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads_,
-                                                         at::Tensor softmax_results_,
-                                                         float scale_factor);
-
 /***************************************************************************************************
  * FP8 recipe
  **************************************************************************************************/

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 0 additions & 26 deletions
@@ -232,32 +232,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("moe_unpermute_bwd", transformer_engine::pytorch::moe_unpermute_bwd, "MOE unpermute BWD",
         py::call_guard<py::gil_scoped_release>());

-  // Softmax functions
-  m.def("scaled_softmax_forward", &transformer_engine::pytorch::scaled_softmax_forward,
-        "Scaled Softmax FWD", py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_softmax_backward", &transformer_engine::pytorch::scaled_softmax_backward,
-        "Scaled Softmax BWD", py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_masked_softmax_forward",
-        &transformer_engine::pytorch::scaled_masked_softmax_forward, "Scaled Masked Softmax FWD",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_masked_softmax_backward",
-        &transformer_engine::pytorch::scaled_masked_softmax_backward, "Scaled Masked Softmax BWD",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_upper_triang_masked_softmax_forward",
-        &transformer_engine::pytorch::scaled_upper_triang_masked_softmax_forward,
-        "Scaled Upper-Triangular Masked Softmax FWD", py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_upper_triang_masked_softmax_backward",
-        &transformer_engine::pytorch::scaled_upper_triang_masked_softmax_backward,
-        "Scaled Upper-Triangular Masked Softmax BWD", py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_aligned_causal_masked_softmax_forward",
-        &transformer_engine::pytorch::scaled_aligned_causal_masked_softmax_forward,
-        "Scaled Bottom-Right Corner Aligned Masked Softmax FWD",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_aligned_causal_masked_softmax_backward",
-        &transformer_engine::pytorch::scaled_aligned_causal_masked_softmax_backward,
-        "Scaled Bottom-Right Corner Aligned Masked Softmax BWD",
-        py::call_guard<py::gil_scoped_release>());
-
   // Other granular functions
   m.def("layernorm_fwd", &transformer_engine::pytorch::layernorm_fwd, "LayerNorm", py::arg("input"),
         py::arg("weight"), py::arg("bias"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"),
transformer_engine/pytorch/csrc/extensions/registration.cpp

Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+/*************************************************************************
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include "../stable_common.h"
+
+// This file defines the transformer_engine library namespace.
+// All other stable ABI files use STABLE_TORCH_LIBRARY_FRAGMENT to add schemas
+// and STABLE_TORCH_LIBRARY_IMPL to add implementations.
+STABLE_TORCH_LIBRARY(transformer_engine, m) {
+  // Softmax ops
+  m.def("scaled_softmax_forward(Tensor input, float scale_factor) -> Tensor");
+  m.def(
+      "scaled_softmax_backward(Tensor output_grad, Tensor softmax_results, float scale_factor) -> "
+      "Tensor");
+  m.def("scaled_masked_softmax_forward(Tensor input, Tensor mask, float scale_factor) -> Tensor");
+  m.def(
+      "scaled_masked_softmax_backward(Tensor output_grad, Tensor softmax_results, float "
+      "scale_factor) -> Tensor");
+  m.def("scaled_upper_triang_masked_softmax_forward(Tensor input, float scale_factor) -> Tensor");
+  m.def(
+      "scaled_upper_triang_masked_softmax_backward(Tensor output_grads, Tensor softmax_results, "
+      "float scale_factor) -> Tensor");
+  m.def("scaled_aligned_causal_masked_softmax_forward(Tensor input, float scale_factor) -> Tensor");
+  m.def(
+      "scaled_aligned_causal_masked_softmax_backward(Tensor output_grad, Tensor softmax_results, "
+      "float scale_factor) -> Tensor");
+}
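
As the file comment above notes, later ports would leave this file untouched: each newly migrated translation unit can contribute its own schemas and kernels. A minimal sketch of that follow-on pattern, with a deliberately hypothetical op name (nothing below is part of this commit):

// Hypothetical follow-on port of one of the remaining Category A functions.
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>

torch::stable::Tensor some_category_a_op(torch::stable::Tensor input, double scale) {
  // ... stable-ABI implementation, following the softmax template ...
  return input;
}

// Add the schema to the existing transformer_engine namespace...
STABLE_TORCH_LIBRARY_FRAGMENT(transformer_engine, m) {
  m.def("some_category_a_op(Tensor input, float scale) -> Tensor");
}

// ...and register the kernel for the CUDA dispatch key.
STABLE_TORCH_LIBRARY_IMPL(transformer_engine, CUDA, m) {
  m.impl("some_category_a_op", TORCH_BOX(&some_category_a_op));
}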
