Skip to content

Commit 59ab765

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 5177c60 commit 59ab765

File tree

5 files changed

+11
-10
lines changed

5 files changed

+11
-10
lines changed

tests/pytorch/debug/run_distributed.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747

4848
fp8_available = is_fp8_available()
4949

50+
5051
def _cmp_dist(ground_truth, output, parallel_mode):
5152
if parallel_mode == "column" and torch.cuda.get_device_capability() == (12, 0):
5253
# SM120: distributed column-parallel path may show a single-element

tests/pytorch/test_custom_recipe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def test_custom_recipe_grouped_linear_sanity():
121121
out_features = 64
122122
# Use 16-aligned splits on SM120 to satisfy FP8 GEMM leading-dimension requirements in backward.
123123
is_sm120 = torch.cuda.get_device_capability() == (12, 0)
124-
if is_sm120:
124+
if is_sm120:
125125
split_m = 16
126126
batch = num_gemms * split_m
127127
m_splits = [split_m] * num_gemms

transformer_engine/common/cast/dispatch/gated.cuh

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,10 @@ void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_outp
4848
case NVTE_DELAYED_TENSOR_SCALING: {
4949
//const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
5050
// sm120 shared memory capabilities are much smaller than sm100, so we disable TMA kernels on sm120
51-
// KL: It is possible that for fwd, the limits are not exceeded for sm120. To be investigated -
51+
// KL: It is possible that for fwd, the limits are not exceeded for sm120. To be investigated -
5252
// are there any forward only tests we'd like to keep enabled on sm120?
53-
const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100() && !is_supported_by_CC_120();
53+
const bool use_tma_kernels =
54+
(cols % 32 == 0) && is_supported_by_CC_100() && !is_supported_by_CC_120();
5455
if (use_tma_kernels) {
5556
Tensor dummy_grad_tensor;
5657
fp8::cast_gated_tma</*IS_BWD=*/false, ParamOP, ActOP, nullptr>(input, dummy_grad_tensor,
@@ -143,7 +144,8 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte
143144
case NVTE_DELAYED_TENSOR_SCALING: {
144145
//const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
145146
// sm120 shared memory capabilities are much smaller than sm100, so we disable TMA kernels on sm120
146-
const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100() && !is_supported_by_CC_120();
147+
const bool use_tma_kernels =
148+
(cols % 32 == 0) && is_supported_by_CC_100() && !is_supported_by_CC_120();
147149
if (use_tma_kernels) {
148150
fp8::cast_gated_tma</*IS_BWD=*/true, ParamOP, ActOP, DActOP>(gated_input, grad, output, p,
149151
stream);

transformer_engine/common/gemm/cublaslt_grouped_gemm.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,7 @@ inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) {
302302
inline void check_grouped_gemm_requirements(const char *api_name) {
303303
const int current_device = transformer_engine::cuda::current_device();
304304
const int sm_arch = transformer_engine::cuda::sm_arch(current_device);
305-
NVTE_CHECK(sm_arch >= 100, api_name,
306-
" requires Blackwell (SM100) or newer architecture.");
305+
NVTE_CHECK(sm_arch >= 100, api_name, " requires Blackwell (SM100) or newer architecture.");
307306
NVTE_CHECK(sm_arch != 120, api_name,
308307
" is currently unsupported on SM120. Grouped cuBLASLt GEMM heuristic selection "
309308
"returns CUBLAS_STATUS_NOT_SUPPORTED on this architecture (even with relaxed hints)");

transformer_engine/pytorch/csrc/extensions/cast.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ inline bool is_sm120_device() {
8989
return device_prop.major == 12 && device_prop.minor == 0;
9090
}
9191

92-
9392
// helper functions for NVFP4 grouped quantization (cuda graph safe with shapes stored in device without D2H copy)
9493
void group_quantize_nvfp4_impl(const GroupedTensorWrapper &grouped_input_tensor,
9594
GroupedTensorWrapper &grouped_output_tensor,
@@ -1192,9 +1191,9 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input,
11921191
auto rht_output_t = allocateTorchTensor(cols, rows, input_list[i].dtype());
11931192
rht_output_t_tensors.push_back(rht_output_t);
11941193
TensorWrapper rht_output_t_cpp;
1195-
rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input_list[i].dtype(),
1196-
std::vector<size_t>{static_cast<size_t>(cols),
1197-
static_cast<size_t>(rows)});
1194+
rht_output_t_cpp.set_rowwise_data(
1195+
rht_output_t.data_ptr(), input_list[i].dtype(),
1196+
std::vector<size_t>{static_cast<size_t>(cols), static_cast<size_t>(rows)});
11981197
nvte_hadamard_transform(input_list[i].data(), rht_output_t_cpp.data(), 0,
11991198
quantizer.rht_matrix_random_sign_mask_t, stream);
12001199
nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose_list[i].data(),

0 commit comments

Comments
 (0)