2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/__init__.py
@@ -26,6 +26,7 @@
from .decompose_log_variants import DecomposeLogVariants
from .decompose_maxpool3d import DecomposeMaxPool3d
from .decompose_minmaxdim import DecomposeMinMaxDim
from .decompose_pad import DecomposePad
from .decompose_reciprocal import DecomposeReciprocal
from .decompose_remainder import DecomposeRemainder
from .decompose_roll import DecomposeRoll
@@ -80,6 +81,7 @@
DecomposeLogVariants,
DecomposeMaxPool3d,
DecomposeMinMaxDim,
DecomposePad,
DecomposeReciprocal,
DecomposeRemainder,
DecomposeRoll,
57 changes: 57 additions & 0 deletions backends/qualcomm/_passes/decompose_pad.py
@@ -0,0 +1,57 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, PassResult


class DecomposePad(ExportPass):
"""
Convert aten.pad.default with non-constant modes to specific pad ops.
After torch.export, nn.ReflectionPad2d becomes aten.pad.default with mode='reflect'.
This pass converts it to aten.reflection_pad2d.default which the QNN pad builder handles directly.

Supported:
- mode='reflect', 4 padding values -> reflection_pad2d (QNN MIRROR_REFLECT, max rank 4).

Not supported by QNN (max rank 4 for non-constant schemes):
- mode='reflect', 6 padding values (3d) -> reflection_pad3d (QNN MIRROR_REFLECT max rank is 4)
- mode='replicate' -> the QNN EDGE scheme produces incorrect results for replication_pad2d with FP32 inputs

Note: reflection_pad1d needs no mapping here. PyTorch's built-in decomposition of aten.pad.default (mode='reflect', 2 padding values)
yields reflection_pad1d, and the skip decomp table entry for reflection_pad1d keeps that op intact.
"""

_PAD_TARGETS = {
torch.ops.aten.pad.default,
exir_ops.edge.aten.pad.default,
}

_PAD_OPS = {
("reflect", 4, False): torch.ops.aten.reflection_pad2d.default,
("reflect", 4, True): exir_ops.edge.aten.reflection_pad2d.default,
}

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph = graph_module.graph
for node in list(graph.nodes):
if node.op != "call_function" or node.target not in self._PAD_TARGETS:
continue
mode = node.args[2] if len(node.args) > 2 else "constant"

padding = node.args[1]
is_edge = isinstance(node.target, EdgeOpOverload)
target_op = self._PAD_OPS.get((mode, len(padding), is_edge))
if target_op is None:
continue

node.target = target_op
node.args = (node.args[0], list(padding))

graph_module.recompile()
return PassResult(graph_module, True)
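
Reviewer note: a minimal sketch of the pass in action, assuming executorch is installed and DecomposePad is importable from the package path registered in __init__.py above (the module mirrors ReflectionPad2d from tests/models.py below):

import torch
from executorch.backends.qualcomm._passes import DecomposePad

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pad = torch.nn.ReflectionPad2d(2)

    def forward(self, x):
        return self.pad(x)

# torch.export captures nn.ReflectionPad2d as aten.pad.default with mode='reflect'.
gm = torch.export.export(Model(), (torch.randn(1, 3, 8, 8),)).module()
DecomposePad()(gm)
# The pad node is now retargeted to aten.reflection_pad2d.default, which the
# QNN pad builder (op_pad.py) consumes directly.
print([n.target for n in gm.graph.nodes if n.op == "call_function"])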
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/layout_transform.py
@@ -113,6 +113,8 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.neg.default,
exir_ops.edge.aten.pow.Tensor_Scalar,
exir_ops.edge.aten.prelu.default,
exir_ops.edge.aten.reflection_pad1d.default,
exir_ops.edge.aten.reflection_pad2d.default,
exir_ops.edge.aten.repeat.default,
exir_ops.edge.aten.relu.default,
exir_ops.edge.aten.round.default,
4 changes: 4 additions & 0 deletions backends/qualcomm/_passes/qnn_pass_manager.py
@@ -31,6 +31,7 @@
DecomposeLogVariants,
DecomposeMaxPool3d,
DecomposeMinMaxDim,
DecomposePad,
DecomposeReciprocal,
DecomposeRemainder,
DecomposeRoll,
@@ -107,6 +108,7 @@ def get_capture_program_passes():
(DecomposeLogVariants, True),
(DecomposeMaxPool3d, True),
(DecomposeMinMaxDim, True),
(DecomposePad, True),
(DecomposeRemainder, True),
(DecomposeTrunc, True),
(ExpandBroadcastTensorShape, True),
@@ -227,6 +229,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(DecomposeBinaryAlpha())
self.add_pass(DecomposeCDist())
self.add_pass(DecomposeMaxPool3d(quantization_capture=True))
self.add_pass(DecomposePad())
self.add_pass(DecomposeScaledDotProductAttention())
self.add_pass(DecomposeRoll())
self.add_pass(DecomposeSilu())
@@ -254,6 +257,7 @@ def transform_for_export_pipeline(
):
self.add_pass(DecomposeBinaryAlpha())
self.add_pass(DecomposeCDist())
self.add_pass(DecomposePad())
self.add_pass(DecomposeScaledDotProductAttention())
self.add_pass(DecomposeRoll())
self.add_pass(DecomposeThreshold())
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/utils.py
@@ -71,6 +71,7 @@ def get_passes_dependency_for_capture_program():
DecomposeLinalgVectorNorm,
DecomposeLogVariants,
DecomposeMaxPool3d,
DecomposePad,
DecomposeRemainder,
DecomposeTrunc,
ExpandBroadcastTensorShape,
@@ -102,6 +103,7 @@ def get_passes_dependency_for_capture_program():
DecomposeLinalgVectorNorm: [RemoveRedundancy],
DecomposeLogVariants: [RemoveRedundancy],
DecomposeMaxPool3d: [RemoveRedundancy],
DecomposePad: [RemoveRedundancy],
DecomposeRemainder: [RemoveRedundancy],
DecomposeTrunc: [RemoveRedundancy],
ExpandBroadcastTensorShape: [FoldQDQ],
36 changes: 23 additions & 13 deletions backends/qualcomm/builders/op_pad.py
@@ -18,7 +18,17 @@

@register_node_visitor
class Pad(NodeVisitor):
target = ["aten.constant_pad_nd.default"]
target = [
"aten.constant_pad_nd.default",
"aten.reflection_pad1d.default",
"aten.reflection_pad2d.default",
]

_SCHEME_MAP = {
"aten.constant_pad_nd.default": OpPad.Scheme.CONSTANT,
"aten.reflection_pad1d.default": OpPad.Scheme.MIRROR_REFLECT,
"aten.reflection_pad2d.default": OpPad.Scheme.MIRROR_REFLECT,
}

def __init__(self, *args) -> None:
super().__init__(*args)
@@ -37,7 +47,6 @@ def define_node(
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)
pad_input_tensors = [pad_inp_tensor_wrapper]

output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
@@ -47,7 +56,6 @@
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)
pad_output_tensors = [output_tensor_wrapper]

pad_amount_shape = [input_tensor.dim(), 2]
# PyTorch padding starts from the last dimension
@@ -62,28 +70,30 @@

if QCOM_AXIS_ORDER in node.meta:
pad_amount = pad_amount[list(node.meta[QCOM_AXIS_ORDER])]
pad_amount_val = node.args[2]

scheme = self._SCHEME_MAP[node.target.__name__]
pad_op = PyQnnManager.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpPad.op_name,
)
pad_op.AddInputTensors(pad_input_tensors)
pad_op.AddOutputTensors(pad_output_tensors)
pad_op.AddInputTensors([pad_inp_tensor_wrapper])
pad_op.AddOutputTensors([output_tensor_wrapper])

# Set the pad scheme resolved from the source op (CONSTANT or MIRROR_REFLECT)
pad_op.AddScalarParam(
OpPad.param_scheme,
PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
{QCOM_DATA: np.uint32(OpPad.Scheme.CONSTANT)},
{QCOM_DATA: np.uint32(scheme)},
)

pad_op.AddScalarParam(
OpPad.param_pad_constant_value,
QNN_TENSOR_TYPE_MAP[type(pad_amount_val)],
{QCOM_DATA: pad_amount_val},
)
# pad_constant_value is only applicable to the CONSTANT scheme
if scheme == OpPad.Scheme.CONSTANT:
pad_amount_val = node.args[2]
pad_op.AddScalarParam(
OpPad.param_pad_constant_value,
QNN_TENSOR_TYPE_MAP[type(pad_amount_val)],
{QCOM_DATA: pad_amount_val},
)

pad_op.AddTensorParam(
OpPad.param_pad_amount,
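
Reviewer note: the pad_amount construction is collapsed above, but the index mapping the builder relies on can be sketched standalone (hypothetical helper, not the builder code). PyTorch pad values list (before, after) pairs starting from the last dimension, while QNN's pad_amount is a [rank, 2] array ordered from the first dimension; the builder then reorders rows by QCOM_AXIS_ORDER when a layout transform applies:

import numpy as np

def torch_pad_to_qnn(padding, rank):
    # padding = [lastdim_before, lastdim_after, nextdim_before, nextdim_after, ...]
    pad_amount = np.zeros((rank, 2), dtype=np.uint32)
    for i in range(len(padding) // 2):
        pad_amount[rank - 1 - i] = padding[2 * i : 2 * i + 2]
    return pad_amount

# ReflectionPad2d((1, 2, 3, 1)) on an NCHW tensor pads W by (1, 2) and H by (3, 1):
print(torch_pad_to_qnn([1, 2, 3, 1], rank=4))
# [[0 0]
#  [0 0]
#  [3 1]
#  [1 2]]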
2 changes: 2 additions & 0 deletions backends/qualcomm/partition/utils.py
@@ -64,6 +64,8 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
torch.ops.aten.pixel_shuffle.default,
torch.ops.aten.pixel_unshuffle.default,
torch.ops.aten.prelu.default,
torch.ops.aten.reflection_pad1d.default,
torch.ops.aten.reflection_pad2d.default,
torch.ops.aten.rms_norm.default,
torch.ops.aten._safe_softmax.default,
torch.ops.aten.stack.default,
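
Reviewer note: these skip-decomp entries keep the reflection pads as single nodes through decomposition so the partitioner and the pad builder can claim them. The effect can be sketched with stock torch.export APIs (assumes PyTorch >= 2.5 for torch.export.default_decompositions(); illustrative only, not the partitioner's actual call path):

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.pad(x, (2, 2), mode="reflect")

ep = torch.export.export(M(), (torch.randn(1, 3, 8),))
decomp_table = torch.export.default_decompositions()
# Mirror the table entries above: drop the built-in decomposition so the op
# survives as one node instead of slice/cat primitives.
del decomp_table[torch.ops.aten.reflection_pad1d.default]
ep = ep.run_decompositions(decomp_table)
# aten.pad.default (mode='reflect', 2 values) decomposes to reflection_pad1d
# and stops there, matching the note in decompose_pad.py.
print(ep.graph_module.graph)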
9 changes: 8 additions & 1 deletion backends/qualcomm/quantizer/annotators/htp_rules.py
@@ -1171,7 +1171,14 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_binary(node, quantization_config)


@register_annotator([torch.ops.aten.pad.default], QnnConstants.OpPad.op_name)
@register_annotator(
[
torch.ops.aten.pad.default,
torch.ops.aten.reflection_pad1d.default,
torch.ops.aten.reflection_pad2d.default,
],
QnnConstants.OpPad.op_name,
)
class Pad(GeneralOpDef):
pass

9 changes: 8 additions & 1 deletion backends/qualcomm/quantizer/annotators/lpai_rules.py
@@ -673,7 +673,14 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_binary(node, quantization_config)


@register_annotator([torch.ops.aten.pad.default], QnnConstants.OpPad.op_name)
@register_annotator(
[
torch.ops.aten.pad.default,
torch.ops.aten.reflection_pad1d.default,
torch.ops.aten.reflection_pad2d.default,
],
QnnConstants.OpPad.op_name,
)
class Pad(GeneralOpDef):
pass

27 changes: 27 additions & 0 deletions backends/qualcomm/tests/models.py
@@ -1863,6 +1863,33 @@ def forward(self, x):
return torch.reciprocal(x)


class ReflectionPad1d(torch.nn.Module):
def __init__(self):
super().__init__()
self.pad = torch.nn.ReflectionPad1d(2)

def forward(self, x):
return self.pad(x)


class ReflectionPad2d(torch.nn.Module):
def __init__(self):
super().__init__()
self.pad = torch.nn.ReflectionPad2d(2)

def forward(self, x):
return self.pad(x)


class ReflectionPad2dAsymmetric(torch.nn.Module):
def __init__(self):
super().__init__()
self.pad = torch.nn.ReflectionPad2d((1, 2, 3, 1))

def forward(self, x):
return self.pad(x)


class Relu(torch.nn.Module):
def __init__(self):
super().__init__()
58 changes: 58 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -1756,6 +1756,38 @@ def test_qnn_backend_reciprocal(self):
sample_input = (torch.randn([2, 2, 2, 2]),)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_reflection_pad1d(self):
module = ReflectionPad1d() # noqa: F405
sample_inputs = [
(torch.randn(1, 3, 8),),
(torch.randn(1, 3, 8, dtype=torch.float16),),
]
for i, sample_input in enumerate(sample_inputs):
with self.subTest(i=i):
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_reflection_pad2d(self):
test_comb = [
{
QCOM_MODULE: [
ReflectionPad2d(), # noqa: F405
ReflectionPad2dAsymmetric(), # noqa: F405
],
QCOM_SAMPLE_INPUTS: [
(torch.randn(1, 3, 8, 8),),
(torch.randn(1, 3, 8, 8, dtype=torch.float16),),
],
},
]

index = 0
for comb in test_comb:
for module in comb[QCOM_MODULE]:
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
with self.subTest(i=index):
index += 1
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_relu(self):
module = Relu() # noqa: F405
sample_input = (torch.randn([2, 5, 1, 3]),)
@@ -4243,6 +4275,32 @@ def test_qnn_backend_reciprocal(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_reflection_pad1d(self):
module = ReflectionPad1d() # noqa: F405
sample_input = (torch.randn(1, 3, 8),)
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_reflection_pad2d(self):
test_comb = [
{
QCOM_MODULE: [
ReflectionPad2d(), # noqa: F405
ReflectionPad2dAsymmetric(), # noqa: F405
],
QCOM_SAMPLE_INPUTS: [(torch.randn(1, 3, 8, 8),)],
},
]

index = 0
for comb in test_comb:
for module in comb[QCOM_MODULE]:
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
with self.subTest(i=index):
index += 1
qdq_module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(qdq_module, sample_input)

def test_qnn_backend_relu(self):
module = Relu() # noqa: F405
sample_input = (torch.randn([2, 5, 1, 3]),)