Fix zero input shape for bgrad_group_quantize #2854
```diff
@@ -247,8 +247,7 @@ py::object bgrad_group_quantize(const at::Tensor &tensor, py::handle quantizer,
   const auto logical_first_dim = logical_shape[0];
   const auto logical_last_dim = logical_shape[1];

-  NVTE_CHECK(logical_first_dim > 0 && logical_last_dim > 0,
-             "bgrad_group_quantize: empty input tensor is not supported.");
+  bool empty_input_buffer = logical_first_dim == 0 || logical_last_dim == 0;

   NVTE_CHECK(detail::IsMXFP8Quantizers(quantizer.ptr()),
              "bgrad_group_quantize: only MXFP8 quantizer is supported.");
@@ -264,6 +263,14 @@ py::object bgrad_group_quantize(const at::Tensor &tensor, py::handle quantizer,
       py::reinterpret_borrow<py::object>(quantizer), first_dims, logical_first_dim,
       logical_last_dim);

+  if (empty_input_buffer) {
+    at::Tensor dbias_torch =
+        at::zeros({static_cast<int64_t>(num_tensors), static_cast<int64_t>(logical_last_dim)},
+                  tensor.options());
+    return py::make_tuple(py::reinterpret_borrow<py::object>(grouped_output_py),
+                          py::cast(std::move(dbias_torch)));
+  }
```
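The early return added above can be sketched in pure Python. This is a hypothetical illustration of the control flow only, not the real binding: the function name, the list-based `dbias`, and the empty-list stand-in for the grouped output are all assumptions mirroring the names in the C++ diff.

```python
def bgrad_group_quantize_sketch(logical_shape, num_tensors):
    """Pure-Python sketch of the empty-input early return (illustration only)."""
    logical_first_dim, logical_last_dim = logical_shape

    # Replaces the old NVTE_CHECK that rejected empty inputs outright.
    empty_input_buffer = logical_first_dim == 0 or logical_last_dim == 0

    # Stand-in for the quantized grouped output object (hypothetical).
    grouped_output = []

    if empty_input_buffer:
        # dbias is all zeros with shape (num_tensors, logical_last_dim),
        # matching at::zeros(...) in the C++ change.
        dbias = [[0.0] * logical_last_dim for _ in range(num_tensors)]
        return grouped_output, dbias

    raise NotImplementedError("non-empty path is handled by the CUDA kernels")

out, dbias = bgrad_group_quantize_sketch((0, 4), num_tensors=3)
print(len(dbias), len(dbias[0]))  # 3 4
```

The key design point is that the empty path still returns the same `(grouped_output, dbias)` tuple shape as the normal path, so callers need no special-casing.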
**Contributor** commented on lines +266 to +272:
```diff
   const std::vector<size_t> dbias_logical_shape = {num_tensors, logical_last_dim};
   GroupedTensorWrapper grouped_dbias(num_tensors, dbias_logical_shape, NVTE_DELAYED_TENSOR_SCALING);
   at::Tensor dbias_torch =
```
The test verifies the `dbias` shape and values but makes no assertions on `grouped_output`. At minimum, checking that `grouped_output` is not `None` (and ideally that its underlying data tensor has shape `(0, last_dim)`) would ensure the output contract is correct for the empty path, not just the dbias.
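The reviewer's suggested assertions can be sketched as a small self-contained check. `FakeGroupedOutput`, its `data_shape` attribute, and `assert_empty_path_contract` are hypothetical stand-ins for the real objects returned by the binding; this only illustrates what the test would assert.

```python
class FakeGroupedOutput:
    """Hypothetical stand-in for the real grouped output object."""
    def __init__(self, data_shape):
        self.data_shape = data_shape

def assert_empty_path_contract(grouped_output, dbias, num_tensors, last_dim):
    # Reviewer's point: assert on grouped_output too, not only dbias.
    assert grouped_output is not None
    # Ideally the underlying data tensor has shape (0, last_dim).
    assert grouped_output.data_shape == (0, last_dim)
    # Existing checks: dbias shape and all-zero values.
    assert len(dbias) == num_tensors
    assert all(len(row) == last_dim for row in dbias)
    assert all(v == 0.0 for row in dbias for v in row)

out = FakeGroupedOutput((0, 4))
dbias = [[0.0] * 4 for _ in range(3)]
assert_empty_path_contract(out, dbias, num_tensors=3, last_dim=4)
print("empty-path contract OK")
```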