diff --git a/exir/passes/memory_format_ops_pass.py b/exir/passes/memory_format_ops_pass.py index 421f30960b6..13468dfd8d8 100644 --- a/exir/passes/memory_format_ops_pass.py +++ b/exir/passes/memory_format_ops_pass.py @@ -6,6 +6,7 @@ import copy import logging +from typing import List, Optional import torch from executorch.exir.dialects.edge._ops import EdgeOpOverload @@ -40,14 +41,19 @@ def call_operator(self, op, args, kwargs, meta): # new kwargs with dim_order, and no memory_format for the new op nkwargs = dict(copy.deepcopy(kwargs)) # orig kwargs are immutable - # get the target memory format for the EdgeOp - mem_format = nkwargs.pop("memory_format", torch.contiguous_format) + # Get the target memory format for the EdgeOp, defaulting to + # preserve_format (clone() with no memory_format kwarg preserves + # the input's layout instead of forcing contiguous). + mem_format = nkwargs.pop("memory_format", torch.preserve_format) - # can always get the shape, assuming rank is specialized + # Get input tensor and ndim + input_tensor: Optional[torch.Tensor] = None if isinstance(args[0], ProxyValue) and args[0].is_tensor(): - ndim = args[0].to_tensor().dim() + input_tensor = args[0].to_tensor() + ndim = input_tensor.dim() elif isinstance(args[0], torch.Tensor): - ndim = args[0].dim() + input_tensor = args[0] + ndim = input_tensor.dim() elif isinstance(args[0], torch.fx.immutable_collections.immutable_list): ndim = len(args[0]) else: @@ -55,7 +61,21 @@ def call_operator(self, op, args, kwargs, meta): 0 ), f"Expecting a Tensor, a ProxyValue, or a Sequence, but got {type(args[0])}" - nkwargs["dim_order"] = get_dim_order(mem_format, ndim) + # Derive dim_order based on memory format + dim_order: List[int] + if mem_format in (None, torch.preserve_format): + # preserve_format: inherit dim_order from input tensor + if input_tensor is not None: + dim_order = [int(d) for d in input_tensor.dim_order()] + else: + # Fallback to contiguous if no single input tensor is available + #
(e.g. list inputs like torch.stack). + dim_order = list(range(ndim)) + else: + # Explicit memory format (contiguous_format, channels_last, etc.) + dim_order = get_dim_order(mem_format, ndim) + + nkwargs["dim_order"] = dim_order logger.debug( f"{op.__name__} = rank: {ndim}, memory_format: {mem_format}." f" {DimOrderOpsMap[op].__name__} = dim_order: {nkwargs['dim_order']}" diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 6d4fbd37107..f683384f8f9 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -2502,6 +2502,161 @@ def test_convert_constant_dim_order_to_contiguous(self): ) +class TestMemoryFormatOpsPassPreserveFormat(unittest.TestCase): + """ + Tests for MemoryFormatOpsPass preserve_format semantics. + """ + + def test_clone_no_kwarg_preserves_channels_last_dim_order(self) -> None: + """ + Verify that clone() on a channels-last input with no memory_format kwarg + produces a _clone_dim_order node with channels-last dim_order (0,2,3,1). + """ + + class ConvClone(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1) + + def forward(self, x): + return self.conv(x).clone() + + model = ConvClone().to(memory_format=torch.channels_last) + x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last) + + # Run the model and verify that the output tensor preserves channels-last + # layout when no memory_format kwarg is provided.
with torch.no_grad(): + y = model(x) + self.assertTrue( + y.is_contiguous(memory_format=torch.channels_last), + f"clone() without memory_format kwarg should preserve channels-last layout, got strides {y.stride()}", + ) + + ep = torch.export.export(model, (x,)) + edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False)) + + # Find the _clone_dim_order node and check its dim_order + found_clone = False + for node in edge.exported_program().graph_module.graph.nodes: + if node.op == "call_function" and "_clone_dim_order" in str(node.target): + found_clone = True + spec = node.meta.get("val") + self.assertIsNotNone(spec, "Clone node should have meta['val']") + dim_order = tuple(spec.dim_order()) + self.assertEqual( + dim_order, + (0, 2, 3, 1), + f"Clone should preserve channels-last dim_order, got {dim_order}", + ) + break + + self.assertTrue(found_clone, "Should find a _clone_dim_order node in the graph") + + def test_clone_contiguous_format_kwarg_stays_contiguous(self) -> None: + """ + Regression guard: explicit contiguous_format should produce contiguous dim_order. + + Note: When clone(memory_format=contiguous_format) is called on a channels-last + input, this is a layout-transforming operation. After export, this typically + lowers to _to_dim_order_copy (not _clone_dim_order) because it changes the + memory layout. We check for both node types to be robust. + """ + + class CloneContiguousModel(torch.nn.Module): + def forward(self, x): + return x.clone(memory_format=torch.contiguous_format) + + model = CloneContiguousModel() + x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last) + + # Run the model and verify that the explicit contiguous_format kwarg + # produces a contiguous output layout (not channels-last).
with torch.no_grad(): + y = model(x) + self.assertTrue( + y.is_contiguous(), + f"clone(memory_format=contiguous_format) should produce contiguous layout, got strides {y.stride()}", + ) + self.assertFalse( + y.is_contiguous(memory_format=torch.channels_last), + "clone(memory_format=contiguous_format) should not preserve channels-last layout", + ) + + ep = torch.export.export(model, (x,)) + edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False)) + + # Find the dim_order copy node and check its dim_order. + # This may be _to_dim_order_copy (layout transform) or _clone_dim_order. + found_copy = False + for node in edge.exported_program().graph_module.graph.nodes: + if node.op == "call_function" and ( + "_clone_dim_order" in str(node.target) + or "_to_dim_order_copy" in str(node.target) + ): + found_copy = True + spec = node.meta.get("val") + self.assertIsNotNone(spec, "Copy node should have meta['val']") + dim_order = tuple(spec.dim_order()) + self.assertEqual( + dim_order, + (0, 1, 2, 3), + f"Explicit contiguous clone should have contiguous dim_order, got {dim_order}", + ) + break + + self.assertTrue( + found_copy, "Should find a _clone_dim_order or _to_dim_order_copy node" + ) + + def test_to_copy_no_kwarg_preserves_channels_last_dim_order(self) -> None: + """ + Verify that tensor.to(dtype=...) with no memory_format kwarg preserves + the input's dim_order (preserve_format semantics). + + This tests the _to_copy.default path in MemoryFormatOpsPass. + """ + + class ToCopyModel(torch.nn.Module): + def forward(self, x): + # .to(dtype=...) with no memory_format → preserve_format semantics + return x.to(dtype=torch.float32) + + model = ToCopyModel() + x = torch.randn(1, 3, 8, 8, dtype=torch.float16).to( + memory_format=torch.channels_last + ) + + # Run the model and verify that tensor.to(dtype=...) with no memory_format + # kwarg preserves channels-last layout on the output tensor.
+ with torch.no_grad(): + y = model(x) + self.assertTrue( + y.is_contiguous(memory_format=torch.channels_last), + f"to(dtype=...) without memory_format kwarg should preserve channels-last layout, got strides {y.stride()}", + ) + + ep = torch.export.export(model, (x,)) + edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False)) + + # Find the _to_dim_order_copy node and verify it preserves channels-last + found_copy = False + for node in edge.exported_program().graph_module.graph.nodes: + if node.op == "call_function" and "_to_dim_order_copy" in str(node.target): + found_copy = True + spec = node.meta.get("val") + self.assertIsNotNone(spec, "Copy node should have meta['val']") + dim_order = tuple(spec.dim_order()) + self.assertEqual( + dim_order, + (0, 2, 3, 1), + f"to(dtype=...) should preserve channels-last dim_order, got {dim_order}", + ) + break + + self.assertTrue(found_copy, "Should find a _to_dim_order_copy node") + + class TestCSEPass(unittest.TestCase): """Tests for Common Subexpression Elimination pass."""