Fix ConvertMmToBmmPass for quantized (int8/int16) mm ops (#18974)
apullin wants to merge 1 commit into pytorch:main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18974
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 Awaiting Approval, 2 New Failures, 5 Cancelled Jobs, 2 Pending, 1 Unrelated Failure as of commit 82b6c7e with merge base 9207001.
AWAITING APPROVAL - The following workflows need approval before CI can run.
NEW FAILURES - The following jobs have failed.
CANCELLED JOBS - The following jobs were cancelled. Please retry.
BROKEN TRUNK - The following jobs failed but were also failing on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
digantdesai left a comment:
Review automatically exported from Phabricator review in Meta.
This PR needs a
Force-pushed 316e474 to 7802809 (Compare)
Force-pushed b4a1625 to 5439a12 (Compare)
Force-pushed 5439a12 to a28c6cc (Compare)
Force-pushed a28c6cc to 82b6c7e (Compare)
Summary:
This diff is experimental, but appears to address incomplete support for integer (INT) pathways for BMM. TBD.
The pass converts rank-2 `mm` to rank-3 `bmm` (required by the TOSA spec) via an unsqueeze/bmm/squeeze sequence. Previously it called `super().call()` to re-trace the graph on FakeTensors for shape propagation, but `aten.bmm` rejects int8/int16 FakeTensors, causing failures for any quantized `mm` op.
Since mm→bmm is a pure shape transformation (adding a batch dim of 1), we can set the output metadata directly: unsqueeze the mm node's FakeTensor for the bmm node, and reuse the original for the squeeze. There is no need to re-execute the op.
Reviewed By: digantdesai
Differential Revision: D99857137
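
The two ideas in the summary can be sketched in plain PyTorch. This is a hypothetical illustration, not the actual pass code: `mm_as_bmm` shows the unsqueeze/bmm/squeeze rewrite, and `derive_bmm_meta` shows how the bmm node's output value can be derived from the mm node's existing (FakeTensor-style) value by shape manipulation alone, sidestepping the `aten.bmm` int8/int16 restriction because `bmm` is never actually invoked.

```python
# Hypothetical sketch of the mm -> bmm rewrite described in the summary.
# Names (mm_as_bmm, derive_bmm_meta) are illustrative, not from the pass.
import torch


def mm_as_bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Rank-2 mm expressed as rank-3 bmm: add a batch dim of 1,
    # run bmm, then drop the batch dim again.
    out = torch.bmm(a.unsqueeze(0), b.unsqueeze(0))
    return out.squeeze(0)


def derive_bmm_meta(mm_val: torch.Tensor) -> torch.Tensor:
    # Shape propagation without executing bmm: the bmm node's output
    # value is just the mm node's value with a leading batch dim of 1.
    # This works for int8/int16 values, where running aten.bmm would fail.
    return mm_val.unsqueeze(0)


# Numerics of the rewrite match plain mm for float inputs.
a, b = torch.randn(3, 4), torch.randn(4, 5)
assert torch.allclose(mm_as_bmm(a, b), torch.mm(a, b))

# Metadata can still be derived for an int8 mm output, no bmm call needed.
mm_val_int8 = torch.empty(3, 5, dtype=torch.int8)
assert derive_bmm_meta(mm_val_int8).shape == (1, 3, 5)
```

The squeeze node then simply reuses the original mm FakeTensor as its metadata, since squeezing the unsqueezed value round-trips back to the rank-2 shape.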