Fix zero input shape for bgrad_group_quantize #2854
Conversation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
/te-ci pytorch
Greptile Summary: This PR fixes a crash in bgrad_group_quantize when the input tensor has a zero-size dimension.

Confidence Score: 5/5 — Safe to merge. All remaining findings are P2 style/coverage suggestions with no correctness impact. The fix is minimal and targeted: it replaces a hard error with a graceful zero-return for empty input. The logic is correct, the dbias dtype/device match the input, and a new test covers the fixed case. Only P2 observations remain (early-return placement and an incomplete output assertion in the test), neither of which affects correctness. No files require special attention.
| Filename | Overview |
|---|---|
| transformer_engine/pytorch/csrc/extensions/cast.cpp | Removes the NVTE_CHECK that rejected zero-dimension inputs and instead adds an early-return path for empty tensors, returning a zero-filled dbias. The early return is placed after GroupedTensorWrapper construction and create_grouped_tensor, meaning those calls still execute on empty input. |
| tests/pytorch/test_grouped_tensor.py | Adds a new test that passes a (0, 1024) input to bgrad_group_quantize; verifies dbias shape and that all values are zero. Does not validate grouped_output shape/content in the zero-size case. |
Flowchart
```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[bgrad_group_quantize called] --> B[Validate tensor is 2D]
    B --> C[Extract logical_first_dim / logical_last_dim]
    C --> D{"empty_input_buffer?<br/>first_dim==0 or last_dim==0"}
    D -->|No - original path| E[Check MXFP8 quantizer]
    E --> F[Build GroupedTensorWrapper for input]
    F --> G["create_grouped_tensor → grouped_output_py"]
    G --> H[Build GroupedTensorWrapper for dbias]
    H --> I[nvte_group_quantize_dbias workspace query]
    I --> J[nvte_group_quantize_dbias execute]
    J --> K[Return grouped_output_py + dbias_torch]
    D -->|Yes - new early-return path| E2[Check MXFP8 quantizer]
    E2 --> F2[Build GroupedTensorWrapper for input]
    F2 --> G2["create_grouped_tensor → grouped_output_py"]
    G2 --> L["Return grouped_output_py +<br/>zeros dbias shape num_tensors×last_dim"]
```
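The new early-return branch in the flowchart can be modeled with a small, torch-free Python sketch. All names and the list-based dbias below are illustrative stand-ins, not the actual Transformer Engine API:

```python
# Toy model of the empty-input control flow (hypothetical names):
# for a (0, last_dim) input, skip the quantize-dbias kernel entirely and
# return a zero-filled dbias of shape (num_tensors, last_dim).

def bgrad_group_quantize_sketch(rows, last_dim, num_tensors):
    empty_input_buffer = rows == 0 or last_dim == 0
    grouped_output = []  # stand-in for grouped_output_py
    if empty_input_buffer:
        # New early-return path: zero dbias, no kernel launch.
        dbias = [[0.0] * last_dim for _ in range(num_tensors)]
        return grouped_output, dbias
    raise NotImplementedError("non-empty path dispatches to the CUDA kernel")

out, dbias = bgrad_group_quantize_sketch(rows=0, last_dim=1024, num_tensors=4)
print(len(dbias), len(dbias[0]))  # 4 1024
```

The point of the sketch is only the shape contract: an empty input still yields a well-formed (num_tensors, last_dim) dbias of zeros.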
Reviews (1): Last reviewed commit: "Merge branch 'main' into bug_fix_zero_te..."
```cpp
if (empty_input_buffer) {
  at::Tensor dbias_torch =
      at::zeros({static_cast<int64_t>(num_tensors), static_cast<int64_t>(logical_last_dim)},
                tensor.options());
  return py::make_tuple(py::reinterpret_borrow<py::object>(grouped_output_py),
                        py::cast(std::move(dbias_torch)));
}
```
Early return placed after expensive setup calls
The empty_input_buffer early return fires only after GroupedTensorWrapper construction, set_rowwise_data, and create_grouped_tensor have already been called on the empty tensor. If any of those operations have hidden allocations or non-trivial paths on zero-size input, the early return provides no protection. Consider moving the guard to immediately after the shape is known (before GroupedTensorWrapper grouped_input_tensor) and constructing a minimal empty grouped_output_py separately, so the fast path truly skips kernel-setup work.
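A minimal sketch of the suggested reordering, in toy Python rather than the real cast.cpp code (all names below are hypothetical stand-ins for GroupedTensorWrapper construction and create_grouped_tensor):

```python
# Hypothetical sketch of the reviewer's suggestion: test the empty shape
# before any setup work, so zero-size input never reaches wrapper
# construction or kernel dispatch.

def build_setup(shape):
    # Stand-in for GroupedTensorWrapper / create_grouped_tensor setup.
    raise RuntimeError("setup should never run for empty input")

def quantize_dbias_guarded(shape, num_tensors):
    first_dim, last_dim = shape
    if first_dim == 0 or last_dim == 0:
        # Suggested placement: guard fires before build_setup is touched.
        empty_output = []  # minimal empty grouped output, built separately
        dbias = [[0.0] * last_dim for _ in range(num_tensors)]
        return empty_output, dbias
    return build_setup(shape)  # non-empty path would do the real work

out, dbias = quantize_dbias_guarded((0, 1024), num_tensors=4)
assert out == [] and len(dbias) == 4  # empty input skipped setup entirely
```

The trade-off is that the early path must then build its minimal empty grouped output itself instead of reusing the one produced by create_grouped_tensor.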
```python
assert dbias.shape == (num_tensors, last_dim)
assert torch.all(dbias == 0)
```
grouped_output not validated in zero-size test
The test verifies dbias shape and values but makes no assertions on grouped_output. At minimum, checking that grouped_output is not None (and ideally that its underlying data tensor has shape (0, last_dim)) would ensure the output contract is correct for the empty path, not just the dbias.
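The fuller assertion set might look like the following. This is illustrative only: the grouped output is mocked here, and the real object returned by bgrad_group_quantize may expose its data shape under a different attribute:

```python
# Illustrative only: mock the empty-path outputs to show the extra
# grouped_output checks the reviewer suggests (attribute names assumed).
from types import SimpleNamespace

num_tensors, last_dim = 4, 1024
grouped_output = SimpleNamespace(shape=(0, last_dim))  # hypothetical stand-in
dbias = [[0.0] * last_dim for _ in range(num_tensors)]

# Existing checks: dbias shape and zero content.
assert (len(dbias), len(dbias[0])) == (num_tensors, last_dim)
assert all(v == 0.0 for row in dbias for v in row)

# Suggested additions: grouped_output exists and its data is (0, last_dim).
assert grouped_output is not None
assert grouped_output.shape == (0, last_dim)
```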
Description

Fixes a crash in bgrad_group_quantize when the input tensor has a zero-size dimension: the hard NVTE_CHECK that rejected zero-dimension inputs is replaced with an early-return path that yields an empty grouped output and a zero-filled dbias.

Fixes # (issue)

Type of change

- Bug fix

Changes

Please list the changes introduced in this PR:

- transformer_engine/pytorch/csrc/extensions/cast.cpp: remove the NVTE_CHECK rejecting zero-dimension inputs; add an early-return path that returns a zero-filled dbias for empty input.
- tests/pytorch/test_grouped_tensor.py: add a test passing a (0, 1024) input to bgrad_group_quantize and verifying the dbias shape and zero values.

Checklist: