diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py
index 5bc2faa007..04a0376019 100644
--- a/tests/pytorch/test_grouped_tensor.py
+++ b/tests/pytorch/test_grouped_tensor.py
@@ -410,6 +410,27 @@ def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]], output_dbias
         expected_dbias = torch.stack([t.sum(dim=0) for t in input_tensors])
         assert torch.allclose(dbias, expected_dbias)
 
+    @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
+    def test_bgrad_group_quantize_zero_size_tensor(self) -> None:
+        """Test bgrad_group_quantize handles zero-row input without error."""
+        num_tensors = 3
+        last_dim = 1024
+        grouped_input = torch.empty(0, last_dim, dtype=torch.bfloat16, device="cuda")
+
+        quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
+        quantizer.set_usage(rowwise=True, columnwise=False)
+        first_dims = torch.zeros(num_tensors, dtype=torch.int64, device="cuda")
+
+        grouped_output, dbias = tex.bgrad_group_quantize(
+            grouped_input,
+            quantizer,
+            num_tensors,
+            first_dims,
+        )
+
+        assert dbias.shape == (num_tensors, last_dim)
+        assert torch.all(dbias == 0)
+
     @pytest.mark.parametrize("output_dbias", [False, True])
     @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
     def test_group_quantize_cudagraph_capturable(self, output_dbias: bool) -> None:
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp
index f150e90507..b689a1c1b4 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -247,8 +247,7 @@ py::object bgrad_group_quantize(const at::Tensor &tensor, py::handle quantizer,
   const auto logical_first_dim = logical_shape[0];
   const auto logical_last_dim = logical_shape[1];
 
-  NVTE_CHECK(logical_first_dim > 0 && logical_last_dim > 0,
-             "bgrad_group_quantize: empty input tensor is not supported.");
+  bool empty_input_buffer = logical_first_dim == 0 || logical_last_dim == 0;
 
   NVTE_CHECK(detail::IsMXFP8Quantizers(quantizer.ptr()),
              "bgrad_group_quantize: only MXFP8 quantizer is supported.");
@@ -264,6 +263,14 @@ py::object bgrad_group_quantize(const at::Tensor &tensor, py::handle quantizer,
       py::reinterpret_borrow<py::object>(quantizer), first_dims, logical_first_dim,
       logical_last_dim);
 
+  if (empty_input_buffer) {
+    at::Tensor dbias_torch =
+        at::zeros({static_cast<int64_t>(num_tensors), static_cast<int64_t>(logical_last_dim)},
+                  tensor.options());
+    return py::make_tuple(py::reinterpret_borrow<py::object>(grouped_output_py),
+                          py::cast(std::move(dbias_torch)));
+  }
+
   const std::vector<size_t> dbias_logical_shape = {num_tensors, logical_last_dim};
   GroupedTensorWrapper grouped_dbias(num_tensors, dbias_logical_shape, NVTE_DELAYED_TENSOR_SCALING);
   at::Tensor dbias_torch =