Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions tests/pytorch/test_grouped_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,27 @@ def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]], output_dbias
expected_dbias = torch.stack([t.sum(dim=0) for t in input_tensors])
assert torch.allclose(dbias, expected_dbias)

@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
def test_bgrad_group_quantize_zero_size_tensor(self) -> None:
    """Test bgrad_group_quantize handles zero-row input without error.

    Exercises the empty-input fast path: a grouped input with zero total
    rows split across ``num_tensors`` groups must still produce a valid
    (non-None) grouped output and an all-zero dbias of shape
    ``(num_tensors, last_dim)``.
    """
    num_tensors = 3
    last_dim = 1024
    # Zero rows in total — every group is empty.
    grouped_input = torch.empty(0, last_dim, dtype=torch.bfloat16, device="cuda")

    quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
    quantizer.set_usage(rowwise=True, columnwise=False)
    # Per-group first dimensions: all zero, matching the empty buffer.
    first_dims = torch.zeros(num_tensors, dtype=torch.int64, device="cuda")

    grouped_output, dbias = tex.bgrad_group_quantize(
        grouped_input,
        quantizer,
        num_tensors,
        first_dims,
    )

    # The empty path must still honor the output contract: a real grouped
    # output object, not None.
    assert grouped_output is not None

    # dbias is the column-wise gradient sum; with zero rows it is all zeros.
    assert dbias.shape == (num_tensors, last_dim)
    assert torch.all(dbias == 0)
Comment on lines +431 to +432
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 grouped_output not validated in zero-size test

The test verifies dbias shape and values but makes no assertions on grouped_output. At minimum, checking that grouped_output is not None (and ideally that its underlying data tensor has shape (0, last_dim)) would ensure the output contract is correct for the empty path, not just the dbias.


@pytest.mark.parametrize("output_dbias", [False, True])
@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
def test_group_quantize_cudagraph_capturable(self, output_dbias: bool) -> None:
Expand Down
11 changes: 9 additions & 2 deletions transformer_engine/pytorch/csrc/extensions/cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,7 @@ py::object bgrad_group_quantize(const at::Tensor &tensor, py::handle quantizer,
const auto logical_first_dim = logical_shape[0];
const auto logical_last_dim = logical_shape[1];

NVTE_CHECK(logical_first_dim > 0 && logical_last_dim > 0,
"bgrad_group_quantize: empty input tensor is not supported.");
bool empty_input_buffer = logical_first_dim == 0 || logical_last_dim == 0;

NVTE_CHECK(detail::IsMXFP8Quantizers(quantizer.ptr()),
"bgrad_group_quantize: only MXFP8 quantizer is supported.");
Expand All @@ -264,6 +263,14 @@ py::object bgrad_group_quantize(const at::Tensor &tensor, py::handle quantizer,
py::reinterpret_borrow<py::object>(quantizer), first_dims, logical_first_dim,
logical_last_dim);

if (empty_input_buffer) {
at::Tensor dbias_torch =
at::zeros({static_cast<int64_t>(num_tensors), static_cast<int64_t>(logical_last_dim)},
tensor.options());
return py::make_tuple(py::reinterpret_borrow<py::object>(grouped_output_py),
py::cast(std::move(dbias_torch)));
}
Comment on lines +266 to +272
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Early return placed after expensive setup calls

The empty_input_buffer early return fires only after GroupedTensorWrapper construction, set_rowwise_data, and create_grouped_tensor have already been called on the empty tensor. If any of those operations have hidden allocations or non-trivial paths on zero-size input, the early return provides no protection. Consider moving the guard to immediately after the shape is known (before GroupedTensorWrapper grouped_input_tensor) and constructing a minimal empty grouped_output_py separately, so the fast path truly skips kernel-setup work.


const std::vector<size_t> dbias_logical_shape = {num_tensors, logical_last_dim};
GroupedTensorWrapper grouped_dbias(num_tensors, dbias_logical_shape, NVTE_DELAYED_TENSOR_SCALING);
at::Tensor dbias_torch =
Expand Down
Loading