From ee9a981bfe10a62f191fc0c98ff8651595718333 Mon Sep 17 00:00:00 2001 From: KevinJie Date: Sun, 19 Apr 2026 07:22:55 +0000 Subject: [PATCH 1/6] Fix AttributeError in _mark_nodes_as_annotated when node is None --- backends/qualcomm/quantizer/rules.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backends/qualcomm/quantizer/rules.py b/backends/qualcomm/quantizer/rules.py index 878acfea422..77989f8e874 100644 --- a/backends/qualcomm/quantizer/rules.py +++ b/backends/qualcomm/quantizer/rules.py @@ -29,6 +29,8 @@ def _mark_nodes_as_annotated(nodes: List[Node]): for node in nodes: + if node is None: + continue if Q_ANNOTATION_KEY not in node.meta: node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation() node.meta[Q_ANNOTATION_KEY]._annotated = True From 40404fdd974e75deec9afb0e299bdbe2cf9a04fc Mon Sep 17 00:00:00 2001 From: KevinJie Date: Mon, 20 Apr 2026 15:15:44 +0000 Subject: [PATCH 2/6] [QNN] Guard get_parameter against node=None in LayerNormVisitor Fixes AttributeError when aten.native_layer_norm has optional weight=None. Both weight and bias are guarded to handle the None case gracefully. Co-Authored-By: Claude Opus 4.6 --- backends/qualcomm/builders/op_layer_norm.py | 26 +++++++++++++-------- backends/qualcomm/builders/utils.py | 4 +++- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/backends/qualcomm/builders/op_layer_norm.py b/backends/qualcomm/builders/op_layer_norm.py index a51056eb7bb..9882c77cc7f 100644 --- a/backends/qualcomm/builders/op_layer_norm.py +++ b/backends/qualcomm/builders/op_layer_norm.py @@ -55,16 +55,22 @@ def define_node( axis_shape = [len(axis)] weight_node = self.get_node(node.args[2]) - weight_tensor = get_parameter(weight_node, self.edge_program) - weight_tensor_wrapper = self.define_tensor( - weight_node, - node, - weight_tensor, - PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, - nodes_to_wrappers, - ) - - layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper] + if weight_node is not None: + weight_tensor = get_parameter(weight_node, self.edge_program) + weight_tensor_wrapper = self.define_tensor( + weight_node, + node, + weight_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + ) + layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper] + else: + warnings.warn( + "[QNN Delegate Op Builder]: LayerNorm weight is None, skipping", + stacklevel=1, + ) + layer_norm_input_tensors = [input_tensor_wrapper] bias_node = self.get_node(node.args[3]) if bias_node is not None: diff --git a/backends/qualcomm/builders/utils.py b/backends/qualcomm/builders/utils.py index 3345f2e1fc9..acf4400f212 100755 --- a/backends/qualcomm/builders/utils.py +++ b/backends/qualcomm/builders/utils.py @@ -29,7 +29,9 @@ def is_parameter( def get_parameter( node: torch.fx.Node, edge_program: torch.export.ExportedProgram -) -> torch.Tensor: +) -> Optional[torch.Tensor]: + if node is None: + return None param = None if is_param(edge_program, node): param = get_param(edge_program, node) From 5beaa5774c7f782e2a7e5edaa3bfb56a3d6ef1e2 Mon Sep 17 00:00:00 2001 From: KevinJie Date: Wed, 22 Apr 2026 03:08:10 +0000 Subject: [PATCH 3/6] [Qualcomm] Support native_layer_norm and affine-free LayerNorm in QNN backend - add QNN layer norm support for aten.native_layer_norm.default - handle missing weight/bias by creating identity weight and zero bias - always provide bias tensor for QNN LayerNorm op - add floating-point and quantized tests for native_layer_norm - print generated pte filename after export --- 
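Note (illustrative reviewer sketch, ignored by git-am): the builder change below substitutes identity affine parameters when aten.native_layer_norm arrives with weight=None or bias=None. The numerical equivalence it relies on can be checked directly in eager PyTorch; a minimal sketch, assuming float32 inputs:

    import torch

    x = torch.randn(196, 768)
    shape = [768]
    eps = 1e-6
    # Affine-free layer norm (optional weight/bias omitted)...
    ref = torch.native_layer_norm(x, shape, None, None, eps)[0]
    # ...matches layer norm with all-ones weight and all-zeros bias, the
    # static tensors the builder feeds QNN when the affine args are absent.
    sub = torch.native_layer_norm(
        x, shape, torch.ones(shape), torch.zeros(shape), eps
    )[0]
    assert torch.allclose(ref, sub)
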
backends/qualcomm/builders/op_layer_norm.py | 84 +++++++++++++------ backends/qualcomm/export_utils.py | 1 + .../quantizer/annotators/htp_rules.py | 3 +- backends/qualcomm/tests/models.py | 20 +++++ backends/qualcomm/tests/test_qnn_delegate.py | 15 ++++ 5 files changed, 96 insertions(+), 27 deletions(-) diff --git a/backends/qualcomm/builders/op_layer_norm.py b/backends/qualcomm/builders/op_layer_norm.py index 9882c77cc7f..d4ca7e4589a 100644 --- a/backends/qualcomm/builders/op_layer_norm.py +++ b/backends/qualcomm/builders/op_layer_norm.py @@ -11,7 +11,12 @@ import numpy as np import torch -from executorch.backends.qualcomm.utils.constants import QCOM_DATA +from executorch.backends.qualcomm.utils.constants import ( + QCOM_DATA, + QCOM_QUANT_ATTRS, + QCOM_ZERO_POINT, +) +from executorch.exir.dialects._ops import ops as exir_ops from .node_visitor import NodeVisitor from .node_visitor_manager import register_node_visitor @@ -31,6 +36,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], ) -> PyQnnManager.PyQnnOpWrapper: + # args of node : ['input', 'normalized_shape', 'weight', 'bias', 'eps'] input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( @@ -54,37 +60,61 @@ def define_node( axis = [len(input_tensor.shape) - 1] axis_shape = [len(axis)] - weight_node = self.get_node(node.args[2]) - if weight_node is not None: + has_weight = len(node.args) > 2 and node.args[2] is not None + if has_weight: + weight_node = self.get_node(node.args[2]) weight_tensor = get_parameter(weight_node, self.edge_program) - weight_tensor_wrapper = self.define_tensor( - weight_node, - node, - weight_tensor, - PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, - nodes_to_wrappers, - ) - layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper] else: - warnings.warn( - "[QNN Delegate Op Builder]: LayerNorm weight is None, skipping", - stacklevel=1, + # elementwise_affine=False: use all-ones weight as identity + weight_tensor = torch.ones(normalized_shapes, dtype=torch.float32) + weight_node = torch.fx.Node( + node.graph, + node.name + "_runtime_weight", + "call_function", + exir_ops.edge.aten.tensor.default, + (), + {}, ) - layer_norm_input_tensors = [input_tensor_wrapper] + if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS): + quant_attrs = quant_attrs.copy() + quant_attrs[QCOM_ZERO_POINT] = 0 + weight_node.meta[QCOM_QUANT_ATTRS] = quant_attrs + weight_tensor_wrapper = self.define_tensor( + weight_node, + node, + weight_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + ) - bias_node = self.get_node(node.args[3]) - if bias_node is not None: + # Fake node: even when original bias is absent, QNN still needs it + has_bias = len(node.args) > 3 and node.args[3] is not None + if has_bias: + bias_node = self.get_node(node.args[3]) bias_tensor = get_parameter(bias_node, self.edge_program) - bias_tensor_wrapper = self.define_tensor( - bias_node, - node, - bias_tensor, - PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, - nodes_to_wrappers, + else: + bias_tensor = torch.zeros(normalized_shapes, dtype=torch.float32) + bias_node = torch.fx.Node( + node.graph, + node.name + "_runtime_bias", + "call_function", + exir_ops.edge.aten.tensor.default, + (), + {}, ) - layer_norm_input_tensors.append(bias_tensor_wrapper) + if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS): + quant_attrs = quant_attrs.copy() + quant_attrs[QCOM_ZERO_POINT] = 0 + 
bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
+        bias_tensor_wrapper = self.define_tensor(
+            bias_node,
+            node,
+            bias_tensor,
+            PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
+            nodes_to_wrappers,
+        )
 
-        epsilon = node.args[4]
+        epsilon = node.args[4] if len(node.args) > 4 else 1e-05
 
         output_tensor = self.get_tensor(node, node, 0)
         output_tensor_wrapper = self.define_tensor(
@@ -100,7 +130,9 @@ def define_node(
             QNN_OP_PACKAGE_NAME_QTI_AISW,
             OpLayerNorm.op_name,
         )
-        layer_norm_op.AddInputTensors(layer_norm_input_tensors)
+        layer_norm_op.AddInputTensors(
+            [input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper]
+        )
         layer_norm_op.AddOutputTensors([output_tensor_wrapper])
         layer_norm_op.AddScalarParam(
             OpLayerNorm.param_epsilon,
diff --git a/backends/qualcomm/export_utils.py b/backends/qualcomm/export_utils.py
index 2c7ab2abd02..f66bf9d5858 100644
--- a/backends/qualcomm/export_utils.py
+++ b/backends/qualcomm/export_utils.py
@@ -617,6 +617,7 @@ def build_executorch_binary(
 
     with open(pte_name, "wb") as file:
         exec_prog_mgr.write_to_file(file)
+    print(f"Successfully generated {pte_name}.")
 
     if qnn_config.compile_only:
         sys.exit(0)
diff --git a/backends/qualcomm/quantizer/annotators/htp_rules.py b/backends/qualcomm/quantizer/annotators/htp_rules.py
index ce01fceca80..5b763bf9b11 100644
--- a/backends/qualcomm/quantizer/annotators/htp_rules.py
+++ b/backends/qualcomm/quantizer/annotators/htp_rules.py
@@ -828,7 +828,8 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
 
 
 @register_annotator(
-    [torch.ops.aten.layer_norm.default], QnnConstants.OpLayerNorm.op_name
+    [torch.ops.aten.layer_norm.default, torch.ops.aten.native_layer_norm.default],
+    QnnConstants.OpLayerNorm.op_name,
 )
 class LayerNorm(GeneralOpDef):
     @staticmethod
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 053e5d26455..d5cb42d72f2 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -1388,6 +1388,26 @@ def forward(self, x):
         return self.linear(self.layer_norm(x))
 
 
+class NativeLayerNorm(torch.nn.Module):
+    def __init__(self, affine=True):
+        super().__init__()
+        self.affine = affine
+        self.weight = torch.nn.Parameter(torch.ones(768))
+        self.bias = torch.nn.Parameter(torch.zeros(768))
+        self.normalized_shape = [768]
+        self.eps = 1e-6
+
+    def forward(self, x):
+        if self.affine:
+            return torch.native_layer_norm(
+                x, self.normalized_shape, self.weight, self.bias, self.eps
+            )[0]
+        else:
+            return torch.native_layer_norm(
+                x, self.normalized_shape, None, None, self.eps
+            )[0]
+
+
 class LayerNormAdd(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 48f07da06e9..ef1a61715e8 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -1384,6 +1384,13 @@ def test_qnn_backend_layer_norm(self):
             with self.subTest(i=i):
                 self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_native_layer_norm(self):
+        modules = [NativeLayerNorm(), NativeLayerNorm(affine=False)]  # noqa: F405
+        sample_input = (torch.randn(196, 768),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_leaky_relu(self):
         torch.manual_seed(8)
         test_comb = [
@@ -3811,6 +3818,14 @@ def test_qnn_backend_layer_norm(self):
                 module = self.get_qdq_module(module, sample_input)
                 
self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_native_layer_norm(self): + modules = [NativeLayerNorm(), NativeLayerNorm(affine=False)] # noqa: F405 + sample_input = (torch.randn(196, 768),) + for i, module in enumerate(modules): + with self.subTest(i=i): + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_leaky_relu(self): test_comb = [ { From 0a2c42cc49b3bbc883712ff249715b4052935d01 Mon Sep 17 00:00:00 2001 From: KevinJie Date: Fri, 24 Apr 2026 01:41:36 +0000 Subject: [PATCH 4/6] Fix QNN LayerNorm optional arg annotation --- .../quantizer/annotators/htp_rules.py | 36 +++++++++---------- .../quantizer/annotators/lpai_rules.py | 36 +++++++++---------- 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/backends/qualcomm/quantizer/annotators/htp_rules.py b/backends/qualcomm/quantizer/annotators/htp_rules.py index 5b763bf9b11..57c135668b5 100644 --- a/backends/qualcomm/quantizer/annotators/htp_rules.py +++ b/backends/qualcomm/quantizer/annotators/htp_rules.py @@ -835,10 +835,8 @@ class LayerNorm(GeneralOpDef): @staticmethod def annotate(node: Node, quantization_config: QuantizationConfig) -> None: act_node = node.args[0] - weight_node = node.args[2] - bias_node = None - if len(node.args) > 2: - bias_node = node.args[3] + weight_node = node.args[2] if len(node.args) > 2 else None + bias_node = node.args[3] if len(node.args) > 3 else None if _is_annotated([node]): return @@ -849,20 +847,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: act_node, input_act_qspec, ) - if input_act_qspec.dtype == torch.int32: - annotate_input_qspec_map( - node, - weight_node, - get_16a16w_qnn_ptq_config().weight, - ) - else: - annotate_input_qspec_map( - node, - weight_node, - input_act_qspec, - ) - nodes_to_mark_annotated = [node, weight_node] - if bias_node: + nodes_to_mark_annotated = [node] + if isinstance(weight_node, Node): + if input_act_qspec.dtype == torch.int32: + annotate_input_qspec_map( + node, + weight_node, + get_16a16w_qnn_ptq_config().weight, + ) + else: + annotate_input_qspec_map( + node, + weight_node, + input_act_qspec, + ) + nodes_to_mark_annotated.append(weight_node) + if isinstance(bias_node, Node): annotate_input_qspec_map( node, bias_node, diff --git a/backends/qualcomm/quantizer/annotators/lpai_rules.py b/backends/qualcomm/quantizer/annotators/lpai_rules.py index dc6d727f0d4..e9b6ab2995c 100644 --- a/backends/qualcomm/quantizer/annotators/lpai_rules.py +++ b/backends/qualcomm/quantizer/annotators/lpai_rules.py @@ -475,10 +475,8 @@ class LayerNorm(GeneralOpDef): @staticmethod def annotate(node: Node, quantization_config: QuantizationConfig) -> None: act_node = node.args[0] - weight_node = node.args[2] - bias_node = None - if len(node.args) > 2: - bias_node = node.args[3] + weight_node = node.args[2] if len(node.args) > 2 else None + bias_node = node.args[3] if len(node.args) > 3 else None if _is_annotated([node]): return @@ -489,20 +487,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: act_node, input_act_qspec, ) - if input_act_qspec.dtype == torch.int32: - annotate_input_qspec_map( - node, - weight_node, - get_16a16w_qnn_ptq_config().weight, - ) - else: - annotate_input_qspec_map( - node, - weight_node, - input_act_qspec, - ) - nodes_to_mark_annotated = [node, weight_node] - if bias_node: + nodes_to_mark_annotated = [node] + if isinstance(weight_node, Node): + if input_act_qspec.dtype == torch.int32: + 
annotate_input_qspec_map( + node, + weight_node, + get_16a16w_qnn_ptq_config().weight, + ) + else: + annotate_input_qspec_map( + node, + weight_node, + input_act_qspec, + ) + nodes_to_mark_annotated.append(weight_node) + if isinstance(bias_node, Node): annotate_input_qspec_map( node, bias_node, From 31ea42641f23c40a5cc31354904bad658e4cfd7c Mon Sep 17 00:00:00 2001 From: KevinJie Date: Mon, 27 Apr 2026 17:28:30 +0000 Subject: [PATCH 5/6] Address QNN LayerNorm optional arg review --- .../qualcomm/_passes/annotate_quant_attrs.py | 4 +-- backends/qualcomm/builders/op_layer_norm.py | 36 +++++++------------ backends/qualcomm/builders/utils.py | 14 ++++---- backends/qualcomm/quantizer/rules.py | 2 -- backends/qualcomm/tests/models.py | 29 ++++----------- backends/qualcomm/tests/test_qnn_delegate.py | 27 ++++++-------- 6 files changed, 38 insertions(+), 74 deletions(-) diff --git a/backends/qualcomm/_passes/annotate_quant_attrs.py b/backends/qualcomm/_passes/annotate_quant_attrs.py index 6077d51b099..7ac7e094a97 100644 --- a/backends/qualcomm/_passes/annotate_quant_attrs.py +++ b/backends/qualcomm/_passes/annotate_quant_attrs.py @@ -8,7 +8,7 @@ import torch from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops -from executorch.backends.qualcomm.builders.utils import get_parameter +from executorch.backends.qualcomm.builders.utils import get_parameter, is_parameter from executorch.backends.qualcomm.utils.constants import ( QCOM_DTYPE, QCOM_ENCODING, @@ -130,7 +130,7 @@ def _annotate_quant_attrs( self._annotate_requant(n) # With fold_quant enabled, check if the input of dq op is quantized param. param = None - if n.target in dq_ops: + if n.target in dq_ops and is_parameter(n.args[0], self.edge_program): param = get_parameter(n.args[0], self.edge_program) if n.target not in q_ops and param is None: continue diff --git a/backends/qualcomm/builders/op_layer_norm.py b/backends/qualcomm/builders/op_layer_norm.py index d4ca7e4589a..abb6bfd657f 100644 --- a/backends/qualcomm/builders/op_layer_norm.py +++ b/backends/qualcomm/builders/op_layer_norm.py @@ -63,6 +63,7 @@ def define_node( has_weight = len(node.args) > 2 and node.args[2] is not None if has_weight: weight_node = self.get_node(node.args[2]) + assert weight_node is not None weight_tensor = get_parameter(weight_node, self.edge_program) else: # elementwise_affine=False: use all-ones weight as identity @@ -87,32 +88,21 @@ def define_node( nodes_to_wrappers, ) - # Fake node: even when original bias is absent, QNN still needs it + layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper] + has_bias = len(node.args) > 3 and node.args[3] is not None if has_bias: bias_node = self.get_node(node.args[3]) + assert bias_node is not None bias_tensor = get_parameter(bias_node, self.edge_program) - else: - bias_tensor = torch.zeros(normalized_shapes, dtype=torch.float32) - bias_node = torch.fx.Node( - node.graph, - node.name + "_runtime_bias", - "call_function", - exir_ops.edge.aten.tensor.default, - (), - {}, + bias_tensor_wrapper = self.define_tensor( + bias_node, + node, + bias_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, ) - if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS): - quant_attrs = quant_attrs.copy() - quant_attrs[QCOM_ZERO_POINT] = 0 - bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs - bias_tensor_wrapper = self.define_tensor( - bias_node, - node, - bias_tensor, - PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, - nodes_to_wrappers, - ) + 
layer_norm_input_tensors.append(bias_tensor_wrapper)
 
         epsilon = node.args[4] if len(node.args) > 4 else 1e-05
 
         output_tensor = self.get_tensor(node, node, 0)
         output_tensor_wrapper = self.define_tensor(
@@ -130,9 +120,7 @@ def define_node(
             QNN_OP_PACKAGE_NAME_QTI_AISW,
             OpLayerNorm.op_name,
         )
-        layer_norm_op.AddInputTensors(
-            [input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper]
-        )
+        layer_norm_op.AddInputTensors(layer_norm_input_tensors)
         layer_norm_op.AddOutputTensors([output_tensor_wrapper])
         layer_norm_op.AddScalarParam(
             OpLayerNorm.param_epsilon,
diff --git a/backends/qualcomm/builders/utils.py b/backends/qualcomm/builders/utils.py
index acf4400f212..eb82b6335e8 100755
--- a/backends/qualcomm/builders/utils.py
+++ b/backends/qualcomm/builders/utils.py
@@ -29,9 +29,7 @@ def is_parameter(
 
 def get_parameter(
     node: torch.fx.Node, edge_program: torch.export.ExportedProgram
-) -> Optional[torch.Tensor]:
-    if node is None:
-        return None
+) -> torch.Tensor:
     param = None
     if is_param(edge_program, node):
         param = get_param(edge_program, node)
@@ -39,10 +37,12 @@ def get_parameter(
         param = get_buffer(edge_program, node)
     if is_lifted_tensor_constant(edge_program, node):
         param = get_lifted_tensor_constant(edge_program, node)
-    if param is not None:
-        # update node.meta["val"] to qualified QNN datatype (e.g. i64 to i32)
-        assert isinstance(param, torch.Tensor), "Expect parameter to be tensor"
-        param = param.type(node.meta["val"].dtype)
+    assert param is not None, (
+        f"Expect {node.name} to be parameter, buffer, or lifted tensor constant"
+    )
+    # update node.meta["val"] to qualified QNN datatype (e.g. i64 to i32)
+    assert isinstance(param, torch.Tensor), "Expect parameter to be tensor"
+    param = param.type(node.meta["val"].dtype)
     return param
 
 
diff --git a/backends/qualcomm/quantizer/rules.py b/backends/qualcomm/quantizer/rules.py
index 77989f8e874..878acfea422 100644
--- a/backends/qualcomm/quantizer/rules.py
+++ b/backends/qualcomm/quantizer/rules.py
@@ -29,8 +29,6 @@ def _mark_nodes_as_annotated(nodes: List[Node]):
     for node in nodes:
-        if node is None:
-            continue
         if Q_ANNOTATION_KEY not in node.meta:
             node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
         node.meta[Q_ANNOTATION_KEY]._annotated = True
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index d5cb42d72f2..9d0bd964d33 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -1379,35 +1379,20 @@ def forward(self, x):
 
 
 class LayerNorm(torch.nn.Module):
-    def __init__(self, bias=True):
+    def __init__(self, elementwise_affine=True, bias=True):
         super().__init__()
-        self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6, bias=bias)
+        self.layer_norm = torch.nn.LayerNorm(
+            [768],
+            eps=1e-6,
+            elementwise_affine=elementwise_affine,
+            bias=bias,
+        )
         self.linear = torch.nn.Linear(768, 196)
 
     def forward(self, x):
         return self.linear(self.layer_norm(x))
 
 
-class NativeLayerNorm(torch.nn.Module):
-    def __init__(self, affine=True):
-        super().__init__()
-        self.affine = affine
-        self.weight = torch.nn.Parameter(torch.ones(768))
-        self.bias = torch.nn.Parameter(torch.zeros(768))
-        self.normalized_shape = [768]
-        self.eps = 1e-6
-
-    def forward(self, x):
-        if self.affine:
-            return torch.native_layer_norm(
-                x, self.normalized_shape, self.weight, self.bias, self.eps
-            )[0]
-        else:
-            return torch.native_layer_norm(
-                x, self.normalized_shape, None, None, self.eps
-            )[0]
-
-
 class LayerNormAdd(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 
ef1a61715e8..4ed496c07ad 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -1378,14 +1378,11 @@ def test_qnn_backend_up_sampling_nearest_2d_with_size(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_layer_norm(self):
-        modules = [LayerNorm(), LayerNorm(bias=False)]  # noqa: F405
-        sample_input = (torch.randn(196, 768),)
-        for i, module in enumerate(modules):
-            with self.subTest(i=i):
-                self.lower_module_and_test_output(module, sample_input)
-
-    def test_qnn_backend_native_layer_norm(self):
-        modules = [NativeLayerNorm(), NativeLayerNorm(affine=False)]  # noqa: F405
+        modules = [
+            LayerNorm(),  # noqa: F405
+            LayerNorm(bias=False),  # noqa: F405
+            LayerNorm(elementwise_affine=False),  # noqa: F405
+        ]
         sample_input = (torch.randn(196, 768),)
         for i, module in enumerate(modules):
             with self.subTest(i=i):
@@ -3811,15 +3808,11 @@ def test_qnn_backend_up_sampling_nearest_2d_with_size(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_layer_norm(self):
-        modules = [LayerNorm(), LayerNorm(bias=False)]  # noqa: F405
-        sample_input = (torch.randn(196, 768),)
-        for i, module in enumerate(modules):
-            with self.subTest(i=i):
-                module = self.get_qdq_module(module, sample_input)
-                self.lower_module_and_test_output(module, sample_input)
-
-    def test_qnn_backend_native_layer_norm(self):
-        modules = [NativeLayerNorm(), NativeLayerNorm(affine=False)]  # noqa: F405
+        modules = [
+            LayerNorm(),  # noqa: F405
+            LayerNorm(bias=False),  # noqa: F405
+            LayerNorm(elementwise_affine=False),  # noqa: F405
+        ]
         sample_input = (torch.randn(196, 768),)
         for i, module in enumerate(modules):
             with self.subTest(i=i):

From acef7be70ddf877acd5f264e593e789d6e928046 Mon Sep 17 00:00:00 2001
From: KevinJie
Date: Tue, 28 Apr 2026 18:51:14 +0000
Subject: [PATCH 6/6] Fix lint formatting of the get_parameter assert

---
 backends/qualcomm/builders/utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/backends/qualcomm/builders/utils.py b/backends/qualcomm/builders/utils.py
index eb82b6335e8..7e3b0aaa5f3 100755
--- a/backends/qualcomm/builders/utils.py
+++ b/backends/qualcomm/builders/utils.py
@@ -37,9 +37,9 @@ def get_parameter(
         param = get_buffer(edge_program, node)
     if is_lifted_tensor_constant(edge_program, node):
         param = get_lifted_tensor_constant(edge_program, node)
-    assert param is not None, (
-        f"Expect {node.name} to be parameter, buffer, or lifted tensor constant"
-    )
+    assert (
+        param is not None
+    ), f"Expect {node.name} to be parameter, buffer, or lifted tensor constant"
     # update node.meta["val"] to qualified QNN datatype (e.g. i64 to i32)
     assert isinstance(param, torch.Tensor), "Expect parameter to be tensor"
     param = param.type(node.meta["val"].dtype)
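
Endnote (illustrative sketch, not part of the series): after these patches, LayerNorm modules without learnable affine parameters can be annotated and lowered, since both the annotators and the op builder now tolerate weight=None/bias=None. A minimal export sketch showing where that case comes from, assuming a torch build with torch.export:

    import torch

    class AffineFreeLN(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # elementwise_affine=False leaves weight and bias as None, the
            # case exercised by LayerNorm(elementwise_affine=False) in
            # backends/qualcomm/tests/models.py.
            self.ln = torch.nn.LayerNorm(
                [768], eps=1e-6, elementwise_affine=False
            )

        def forward(self, x):
            return self.ln(x)

    ep = torch.export.export(AffineFreeLN(), (torch.randn(196, 768),))
    # After decomposition the graph carries aten.native_layer_norm with
    # None for the optional weight/bias arguments -- the inputs the QNN
    # annotators and LayerNorm builder must now guard against.
    print(ep.run_decompositions().graph_module.code)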