Use torch.Tag.pointwise to auto-discover permutable elementwise ops (#19457)#19457
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19457
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ❌ 3 New Failures, 1 Pending, 4 Unrelated FailuresAs of commit 24dc09f with merge base f1062a7 ( NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@mcremon-meta has exported this pull request. If you are a Meta employee, you can view the originating Diff in D104612850. |
This PR needs a
|
…19457) Summary: Instead of manually listing each elementwise op in `RemovePermutesAroundElementwiseOps.permutable_ops`, use the `torch.Tag.pointwise` tag to automatically accept any ATen op tagged as pointwise. This makes the pass forward-compatible with new elementwise ops added to PyTorch without requiring manual updates. A new static method `_is_pointwise(target)` inspects the op's tags at runtime and is used as a fallback in both `is_node_permutable` and the main subgraph-discovery loop. The explicit `permutable_ops` set is retained for: - Non-pointwise ops that need special dimension-argument handling (`cat`, `mean`, `sum`, `slice_copy`) - `quantize_per_tensor`/`dequantize_per_tensor` ops (not part of ATen) Differential Revision: D104612850
7401610 to
24dc09f
Compare
|
cc @AdrianLundell, I don't anticipate any required changes for Arm's flow, the pass should just be able to catch more cases! |
AdrianLundell
left a comment
There was a problem hiding this comment.
Great work! This could probably be used in the view-fusing pass as well if you feel like adding that at the same time.
Summary:
Instead of manually listing each elementwise op in
RemovePermutesAroundElementwiseOps.permutable_ops, use thetorch.Tag.pointwisetag to automatically accept any ATen op tagged as pointwise. This makes the pass forward-compatible with new elementwise ops added to PyTorch without requiring manual updates.A new static method
_is_pointwise(target)inspects the op's tags at runtime and is used as a fallback in bothis_node_permutableand the main subgraph-discovery loop.The explicit
permutable_opsset is retained for:cat,mean,sum,slice_copy)quantize_per_tensor/dequantize_per_tensorops (not part of ATen)Differential Revision: D104612850