Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions backends/arm/_passes/mm_to_bmm_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,15 @@ def call(self, graph_module: torch.fx.GraphModule):
op="call_function", target=exir_ops.edge.aten.mm.default
)
for node in node_list:
mm_fake = get_first_fake_tensor(node)

# Unsqueeze input tensors to rank 3
for input_node in node.args:
if not isinstance(input_node, Node):
continue

shape = get_first_fake_tensor(input_node).shape
rank = len(shape)
input_fake = get_first_fake_tensor(input_node)
rank = len(input_fake.shape)
if rank != 2:
raise RuntimeError(f"Input tensor has rank {rank}, must be 2")

Expand All @@ -65,6 +67,7 @@ def call(self, graph_module: torch.fx.GraphModule):
input_node, # Input is node's original input
0,
)
unsqueeze_before.meta["val"] = input_fake.unsqueeze(0)
node.replace_input_with(input_node, unsqueeze_before)

# Replace mm node with bmm
Expand All @@ -76,10 +79,15 @@ def call(self, graph_module: torch.fx.GraphModule):
inherit_qparams=True,
)
bmm_node.args = node.args
# Manually set output meta: same as mm but with batch dim.
# This avoids re-executing bmm on FakeTensors, which fails
# for quantized (int8/int16) inputs since aten.bmm only
# supports float32 FakeTensor propagation.
bmm_node.meta["val"] = mm_fake.unsqueeze(0)
node.replace_all_uses_with(bmm_node)
graph.erase_node(node)

# Unsqueeze output tensor to rank 3
# Squeeze output tensor back to rank 2
with graph.inserting_after(bmm_node):
squeeze_after = create_node(
graph,
Expand All @@ -91,6 +99,7 @@ def call(self, graph_module: torch.fx.GraphModule):
bmm_node,
[0],
)
squeeze_after.meta["val"] = mm_fake
original_users = [
user for user in bmm_node.users if user != squeeze_after
]
Expand All @@ -101,6 +110,5 @@ def call(self, graph_module: torch.fx.GraphModule):

if modified_graph:
graph_module.recompile()
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, modified_graph)
Loading