diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py
index 72df4240eb8..d64612e02af 100644
--- a/backends/qualcomm/_passes/__init__.py
+++ b/backends/qualcomm/_passes/__init__.py
@@ -26,6 +26,7 @@
 from .decompose_log_variants import DecomposeLogVariants
 from .decompose_maxpool3d import DecomposeMaxPool3d
 from .decompose_minmaxdim import DecomposeMinMaxDim
+from .decompose_pad import DecomposePad
 from .decompose_reciprocal import DecomposeReciprocal
 from .decompose_remainder import DecomposeRemainder
 from .decompose_roll import DecomposeRoll
@@ -80,6 +81,7 @@
     DecomposeLogVariants,
     DecomposeMaxPool3d,
     DecomposeMinMaxDim,
+    DecomposePad,
     DecomposeReciprocal,
     DecomposeRemainder,
     DecomposeRoll,
diff --git a/backends/qualcomm/_passes/decompose_pad.py b/backends/qualcomm/_passes/decompose_pad.py
new file mode 100644
index 00000000000..ebfa8fd1a8f
--- /dev/null
+++ b/backends/qualcomm/_passes/decompose_pad.py
@@ -0,0 +1,57 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class DecomposePad(ExportPass):
+    """
+    Convert aten.pad.default with non-constant modes to specific pad ops.
+    After torch.export, nn.ReflectionPad2d becomes aten.pad.default with mode='reflect'.
+    This pass converts it to aten.reflection_pad2d.default which the QNN pad builder handles directly.
+
+    Supported:
+    - mode='reflect', 4 padding values -> reflection_pad2d (QNN MIRROR_REFLECT, max rank 4).
+
+    Not supported by QNN (max rank 4 for non-constant schemes):
+    - mode='reflect', 6 padding values (3d) -> reflection_pad3d (QNN MIRROR_REFLECT max rank is 4)
+    - mode='replicate' -> QNN EDGE scheme produces incorrect results for FP32 inputs for replication_pad2d
+
+    Note: reflection_pad1d is handled by PyTorch's built-in decomposition of aten.pad.default (mode='reflect', 2 padding values)
+    -> reflection_pad1d, combined with the skip decomp table entry for reflection_pad1d.
+    """
+
+    _PAD_TARGETS = {
+        torch.ops.aten.pad.default,
+        exir_ops.edge.aten.pad.default,
+    }
+
+    _PAD_OPS = {
+        ("reflect", 4, False): torch.ops.aten.reflection_pad2d.default,
+        ("reflect", 4, True): exir_ops.edge.aten.reflection_pad2d.default,
+    }
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        for node in list(graph.nodes):
+            if node.op != "call_function" or node.target not in self._PAD_TARGETS:
+                continue
+            mode = node.args[2] if len(node.args) > 2 else "constant"
+
+            padding = node.args[1]
+            is_edge = isinstance(node.target, EdgeOpOverload)
+            target_op = self._PAD_OPS.get((mode, len(padding), is_edge))
+            if target_op is None:
+                continue
+
+            node.target = target_op
+            node.args = (node.args[0], list(padding))
+
+        graph_module.recompile()
+        return PassResult(graph_module, True)
diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py
index 71187ad454d..9422051addd 100644
--- a/backends/qualcomm/_passes/layout_transform.py
+++ b/backends/qualcomm/_passes/layout_transform.py
@@ -113,6 +113,8 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.neg.default,
         exir_ops.edge.aten.pow.Tensor_Scalar,
         exir_ops.edge.aten.prelu.default,
+        exir_ops.edge.aten.reflection_pad1d.default,
+        exir_ops.edge.aten.reflection_pad2d.default,
         exir_ops.edge.aten.repeat.default,
         exir_ops.edge.aten.relu.default,
         exir_ops.edge.aten.round.default,
diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py
index ca7faa244cf..71b36cd746a 100644
--- a/backends/qualcomm/_passes/qnn_pass_manager.py
+++ b/backends/qualcomm/_passes/qnn_pass_manager.py
@@ -31,6 +31,7 @@
     DecomposeLogVariants,
     DecomposeMaxPool3d,
     DecomposeMinMaxDim,
+    DecomposePad,
     DecomposeReciprocal,
     DecomposeRemainder,
     DecomposeRoll,
@@ -107,6 +108,7 @@ def get_capture_program_passes():
         (DecomposeLogVariants, True),
         (DecomposeMaxPool3d, True),
         (DecomposeMinMaxDim, True),
+        (DecomposePad, True),
         (DecomposeRemainder, True),
         (DecomposeTrunc, True),
         (ExpandBroadcastTensorShape, True),
@@ -227,6 +229,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeBinaryAlpha())
         self.add_pass(DecomposeCDist())
         self.add_pass(DecomposeMaxPool3d(quantization_capture=True))
+        self.add_pass(DecomposePad())
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeRoll())
         self.add_pass(DecomposeSilu())
@@ -254,6 +257,7 @@ def transform_for_export_pipeline(
     ):
         self.add_pass(DecomposeBinaryAlpha())
         self.add_pass(DecomposeCDist())
+        self.add_pass(DecomposePad())
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeRoll())
         self.add_pass(DecomposeThreshold())
diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py
index 24835702ce4..548d91219ef 100755
--- a/backends/qualcomm/_passes/utils.py
+++ b/backends/qualcomm/_passes/utils.py
@@ -71,6 +71,7 @@ def get_passes_dependency_for_capture_program():
         DecomposeLinalgVectorNorm,
         DecomposeLogVariants,
         DecomposeMaxPool3d,
+        DecomposePad,
         DecomposeRemainder,
         DecomposeTrunc,
         ExpandBroadcastTensorShape,
@@ -102,6 +103,7 @@
         DecomposeLinalgVectorNorm: [RemoveRedundancy],
         DecomposeLogVariants: [RemoveRedundancy],
         DecomposeMaxPool3d: [RemoveRedundancy],
+        DecomposePad: [RemoveRedundancy],
         DecomposeRemainder: [RemoveRedundancy],
         DecomposeTrunc: [RemoveRedundancy],
         ExpandBroadcastTensorShape: [FoldQDQ],
diff --git a/backends/qualcomm/builders/op_pad.py b/backends/qualcomm/builders/op_pad.py
index 2302d6d55a0..28a3dd54d8f 100644
--- a/backends/qualcomm/builders/op_pad.py
+++ b/backends/qualcomm/builders/op_pad.py
@@ -18,7 +18,17 @@
 
 @register_node_visitor
 class Pad(NodeVisitor):
-    target = ["aten.constant_pad_nd.default"]
+    target = [
+        "aten.constant_pad_nd.default",
+        "aten.reflection_pad1d.default",
+        "aten.reflection_pad2d.default",
+    ]
+
+    _SCHEME_MAP = {
+        "aten.constant_pad_nd.default": OpPad.Scheme.CONSTANT,
+        "aten.reflection_pad1d.default": OpPad.Scheme.MIRROR_REFLECT,
+        "aten.reflection_pad2d.default": OpPad.Scheme.MIRROR_REFLECT,
+    }
 
     def __init__(self, *args) -> None:
         super().__init__(*args)
@@ -37,7 +47,6 @@ def define_node(
             PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
-        pad_input_tensors = [pad_inp_tensor_wrapper]
 
         output_tensor = self.get_tensor(node, node)
         output_tensor_wrapper = self.define_tensor(
@@ -47,7 +56,6 @@ def define_node(
             PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
-        pad_output_tensors = [output_tensor_wrapper]
 
         pad_amount_shape = [input_tensor.dim(), 2]
         # pytorch padding start from the last index
@@ -62,28 +70,30 @@ def define_node(
         if QCOM_AXIS_ORDER in node.meta:
             pad_amount = pad_amount[list(node.meta[QCOM_AXIS_ORDER])]
 
-        pad_amount_val = node.args[2]
+        scheme = self._SCHEME_MAP[node.target.__name__]
 
         pad_op = PyQnnManager.PyQnnOpWrapper(
             node.name,
             QNN_OP_PACKAGE_NAME_QTI_AISW,
             OpPad.op_name,
         )
-        pad_op.AddInputTensors(pad_input_tensors)
-        pad_op.AddOutputTensors(pad_output_tensors)
+        pad_op.AddInputTensors([pad_inp_tensor_wrapper])
+        pad_op.AddOutputTensors([output_tensor_wrapper])
 
-        # For now, we only support constant (0) padding due to torch implementation
         pad_op.AddScalarParam(
             OpPad.param_scheme,
             PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
-            {QCOM_DATA: np.uint32(OpPad.Scheme.CONSTANT)},
+            {QCOM_DATA: np.uint32(scheme)},
         )
 
-        pad_op.AddScalarParam(
-            OpPad.param_pad_constant_value,
-            QNN_TENSOR_TYPE_MAP[type(pad_amount_val)],
-            {QCOM_DATA: pad_amount_val},
-        )
+        # pad_constant_value is only applicable for the CONSTANT scheme (the fill
+        # value used by the PAD op); other schemes derive values from the input.
+        if scheme == OpPad.Scheme.CONSTANT:
+            pad_amount_val = node.args[2]
+            pad_op.AddScalarParam(
+                OpPad.param_pad_constant_value,
+                QNN_TENSOR_TYPE_MAP[type(pad_amount_val)],
+                {QCOM_DATA: pad_amount_val},
+            )
 
         pad_op.AddTensorParam(
             OpPad.param_pad_amount,
diff --git a/backends/qualcomm/partition/utils.py b/backends/qualcomm/partition/utils.py
index bcfe41c2bbe..a83444a56b2 100644
--- a/backends/qualcomm/partition/utils.py
+++ b/backends/qualcomm/partition/utils.py
@@ -64,6 +64,8 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
         torch.ops.aten.pixel_shuffle.default,
         torch.ops.aten.pixel_unshuffle.default,
         torch.ops.aten.prelu.default,
+        torch.ops.aten.reflection_pad1d.default,
+        torch.ops.aten.reflection_pad2d.default,
         torch.ops.aten.rms_norm.default,
         torch.ops.aten._safe_softmax.default,
         torch.ops.aten.stack.default,
diff --git a/backends/qualcomm/quantizer/annotators/htp_rules.py b/backends/qualcomm/quantizer/annotators/htp_rules.py
index ce01fceca80..cd65d02c752 100644
--- a/backends/qualcomm/quantizer/annotators/htp_rules.py
+++ b/backends/qualcomm/quantizer/annotators/htp_rules.py
@@ -1171,7 +1171,14 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
         annotate_binary(node, quantization_config)
 
 
-@register_annotator([torch.ops.aten.pad.default], QnnConstants.OpPad.op_name)
+@register_annotator(
+    [
+        torch.ops.aten.pad.default,
+        torch.ops.aten.reflection_pad1d.default,
+        torch.ops.aten.reflection_pad2d.default,
+    ],
+    QnnConstants.OpPad.op_name,
+)
 class Pad(GeneralOpDef):
     pass
 
diff --git a/backends/qualcomm/quantizer/annotators/lpai_rules.py b/backends/qualcomm/quantizer/annotators/lpai_rules.py
index dc6d727f0d4..60cebfcc5c0 100644
--- a/backends/qualcomm/quantizer/annotators/lpai_rules.py
+++ b/backends/qualcomm/quantizer/annotators/lpai_rules.py
@@ -673,7 +673,14 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
         annotate_binary(node, quantization_config)
 
 
-@register_annotator([torch.ops.aten.pad.default], QnnConstants.OpPad.op_name)
+@register_annotator(
+    [
+        torch.ops.aten.pad.default,
+        torch.ops.aten.reflection_pad1d.default,
+        torch.ops.aten.reflection_pad2d.default,
+    ],
+    QnnConstants.OpPad.op_name,
+)
 class Pad(GeneralOpDef):
     pass
 
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 053e5d26455..05745e8bd7b 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -1863,6 +1863,33 @@ def forward(self, x):
         return torch.reciprocal(x)
 
 
+class ReflectionPad1d(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.pad = torch.nn.ReflectionPad1d(2)
+
+    def forward(self, x):
+        return self.pad(x)
+
+
+class ReflectionPad2d(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.pad = torch.nn.ReflectionPad2d(2)
+
+    def forward(self, x):
+        return self.pad(x)
+
+
+class ReflectionPad2dAsymmetric(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.pad = torch.nn.ReflectionPad2d((1, 2, 3, 1))
+
+    def forward(self, x):
+        return self.pad(x)
+
+
 class Relu(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index bcfeb13c0bf..a256a9855eb 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -1756,6 +1756,38 @@ def test_qnn_backend_reciprocal(self):
         sample_input = (torch.randn([2, 2, 2, 2]),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_reflection_pad1d(self):
+        module = ReflectionPad1d()  # noqa: F405
+        sample_inputs = [
+            (torch.randn(1, 3, 8),),
+            (torch.randn(1, 3, 8, dtype=torch.float16),),
+        ]
+        for i, sample_input in enumerate(sample_inputs):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)
+
+    def test_qnn_backend_reflection_pad2d(self):
+        test_comb = [
+            {
+                QCOM_MODULE: [
+                    ReflectionPad2d(),  # noqa: F405
+                    ReflectionPad2dAsymmetric(),  # noqa: F405
+                ],
+                QCOM_SAMPLE_INPUTS: [
+                    (torch.randn(1, 3, 8, 8),),
+                    (torch.randn(1, 3, 8, 8, dtype=torch.float16),),
+                ],
+            },
+        ]
+
+        index = 0
+        for comb in test_comb:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
+                    with self.subTest(i=index):
+                        index += 1
+                        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_relu(self):
         module = Relu()  # noqa: F405
         sample_input = (torch.randn([2, 5, 1, 3]),)
@@ -4243,6 +4275,32 @@ def test_qnn_backend_reciprocal(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_reflection_pad1d(self):
+        module = ReflectionPad1d()  # noqa: F405
+        sample_input = (torch.randn(1, 3, 8),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
+    def test_qnn_backend_reflection_pad2d(self):
+        test_comb = [
+            {
+                QCOM_MODULE: [
+                    ReflectionPad2d(),  # noqa: F405
+                    ReflectionPad2dAsymmetric(),  # noqa: F405
+                ],
+                QCOM_SAMPLE_INPUTS: [(torch.randn(1, 3, 8, 8),)],
+            },
+        ]
+
+        index = 0
+        for comb in test_comb:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
+                    with self.subTest(i=index):
+                        index += 1
+                        qdq_module = self.get_qdq_module(module, sample_input)
+                        self.lower_module_and_test_output(qdq_module, sample_input)
+
     def test_qnn_backend_relu(self):
         module = Relu()  # noqa: F405
         sample_input = (torch.randn([2, 5, 1, 3]),)