4 changes: 2 additions & 2 deletions backends/qualcomm/_passes/annotate_quant_attrs.py
@@ -8,7 +8,7 @@

 import torch
 from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops
-from executorch.backends.qualcomm.builders.utils import get_parameter
+from executorch.backends.qualcomm.builders.utils import get_parameter, is_parameter
 from executorch.backends.qualcomm.utils.constants import (
     QCOM_DTYPE,
     QCOM_ENCODING,
@@ -130,7 +130,7 @@ def _annotate_quant_attrs(
             self._annotate_requant(n)
             # With fold_quant enabled, check if the input of dq op is quantized param.
             param = None
-            if n.target in dq_ops:
+            if n.target in dq_ops and is_parameter(n.args[0], self.edge_program):
                 param = get_parameter(n.args[0], self.edge_program)
             if n.target not in q_ops and param is None:
                 continue
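
Note on the is_parameter guard above: with fold_quant enabled, a dq node's input can be either a folded quantized parameter or an ordinary activation, and get_parameter now asserts when handed a non-parameter (see the utils.py change below). A minimal sketch of the resulting caller contract, assuming n is a dq node in edge_program; quantized_param_or_none is a hypothetical helper, not part of this PR:

from executorch.backends.qualcomm.builders.utils import get_parameter, is_parameter

def quantized_param_or_none(n, edge_program):
    # n.args[0] may be an activation produced by another op; in that case
    # get_parameter would trip its "parameter, buffer, or lifted tensor
    # constant" assert, so gate the lookup on is_parameter first.
    if is_parameter(n.args[0], edge_program):
        return get_parameter(n.args[0], edge_program)
    return None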
38 changes: 32 additions & 6 deletions backends/qualcomm/builders/op_layer_norm.py
@@ -11,7 +11,12 @@

 import numpy as np
 import torch
-from executorch.backends.qualcomm.utils.constants import QCOM_DATA
+from executorch.backends.qualcomm.utils.constants import (
+    QCOM_DATA,
+    QCOM_QUANT_ATTRS,
+    QCOM_ZERO_POINT,
+)
+from executorch.exir.dialects._ops import ops as exir_ops

 from .node_visitor import NodeVisitor
 from .node_visitor_manager import register_node_visitor
@@ -31,6 +36,7 @@ def define_node(
         node: torch.fx.Node,
         nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
     ) -> PyQnnManager.PyQnnOpWrapper:
+        # args of node : ['input', 'normalized_shape', 'weight', 'bias', 'eps']
         input_node = self.get_node(node.args[0])
         input_tensor = self.get_tensor(input_node, node)
         input_tensor_wrapper = self.define_tensor(
@@ -54,8 +60,26 @@
         axis = [len(input_tensor.shape) - 1]
         axis_shape = [len(axis)]

-        weight_node = self.get_node(node.args[2])
-        weight_tensor = get_parameter(weight_node, self.edge_program)
+        has_weight = len(node.args) > 2 and node.args[2] is not None
+        if has_weight:
+            weight_node = self.get_node(node.args[2])
+            assert weight_node is not None
+            weight_tensor = get_parameter(weight_node, self.edge_program)
+        else:
+            # elementwise_affine=False: use all-ones weight as identity
+            weight_tensor = torch.ones(normalized_shapes, dtype=torch.float32)
+            weight_node = torch.fx.Node(
+                node.graph,
+                node.name + "_runtime_weight",
+                "call_function",
+                exir_ops.edge.aten.tensor.default,
+                (),
+                {},
+            )
+            if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
+                quant_attrs = quant_attrs.copy()
+                quant_attrs[QCOM_ZERO_POINT] = 0
+                weight_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
         weight_tensor_wrapper = self.define_tensor(
             weight_node,
             node,
@@ -66,8 +90,10 @@

         layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper]

-        bias_node = self.get_node(node.args[3])
-        if bias_node is not None:
+        has_bias = len(node.args) > 3 and node.args[3] is not None
+        if has_bias:
+            bias_node = self.get_node(node.args[3])
+            assert bias_node is not None
             bias_tensor = get_parameter(bias_node, self.edge_program)
             bias_tensor_wrapper = self.define_tensor(
                 bias_node,
@@ -78,7 +104,7 @@
             )
             layer_norm_input_tensors.append(bias_tensor_wrapper)

-        epsilon = node.args[4]
+        epsilon = node.args[4] if len(node.args) > 4 else 1e-05

         output_tensor = self.get_tensor(node, node, 0)
         output_tensor_wrapper = self.define_tensor(
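
Why the synthetic all-ones weight: QNN's LayerNorm op expects a gamma tensor, but torch.nn.LayerNorm(..., elementwise_affine=False) exports with None in the weight and bias argument slots. A quick way to observe those None args, as a sketch using the public torch.export API (module name and shapes are illustrative):

import torch

class NoAffine(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ln = torch.nn.LayerNorm([768], elementwise_affine=False)

    def forward(self, x):
        return self.ln(x)

ep = torch.export.export(NoAffine(), (torch.randn(196, 768),))
for n in ep.graph.nodes:
    if "layer_norm" in str(n.target):
        # the weight/bias positions hold plain None instead of fx Nodes
        print(n.target, n.args)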
10 changes: 6 additions & 4 deletions backends/qualcomm/builders/utils.py
@@ -37,10 +37,12 @@ def get_parameter(
         param = get_buffer(edge_program, node)
     if is_lifted_tensor_constant(edge_program, node):
         param = get_lifted_tensor_constant(edge_program, node)
-    if param is not None:
-        # update node.meta["val"] to qualified QNN datatype (e.g. i64 to i32)
-        assert isinstance(param, torch.Tensor), "Expect parameter to be tensor"
-        param = param.type(node.meta["val"].dtype)
+    assert (
+        param is not None
+    ), f"Expect {node.name} to be parameter, buffer, or lifted tensor constant"
+    # update node.meta["val"] to qualified QNN datatype (e.g. i64 to i32)
+    assert isinstance(param, torch.Tensor), "Expect parameter to be tensor"
+    param = param.type(node.meta["val"].dtype)
     return param
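
The trailing cast is the part worth calling out: get_parameter returns the tensor converted to the dtype recorded in node.meta["val"], e.g. an int64 buffer narrowed to int32 for QNN. A standalone illustration of that Tensor.type cast, with made-up stand-in values:

import torch

buf = torch.tensor([1, 2, 3], dtype=torch.int64)  # stand-in for a buffer
val_dtype = torch.int32  # stand-in for node.meta["val"].dtype
param = buf.type(val_dtype)
print(param.dtype)  # torch.int32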
1 change: 1 addition & 0 deletions backends/qualcomm/export_utils.py
@@ -617,6 +617,7 @@ def build_executorch_binary(
     with open(pte_name, "wb") as file:
         exec_prog_mgr.write_to_file(file)

+    print(f"Successfully generated {pte_name}.")
     if qnn_config.compile_only:
         sys.exit(0)
39 changes: 20 additions & 19 deletions backends/qualcomm/quantizer/annotators/htp_rules.py
@@ -828,16 +828,15 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:


 @register_annotator(
-    [torch.ops.aten.layer_norm.default], QnnConstants.OpLayerNorm.op_name
+    [torch.ops.aten.layer_norm.default, torch.ops.aten.native_layer_norm.default],
+    QnnConstants.OpLayerNorm.op_name,
 )
 class LayerNorm(GeneralOpDef):
     @staticmethod
     def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
         act_node = node.args[0]
-        weight_node = node.args[2]
-        bias_node = None
-        if len(node.args) > 2:
-            bias_node = node.args[3]
+        weight_node = node.args[2] if len(node.args) > 2 else None
+        bias_node = node.args[3] if len(node.args) > 3 else None

         if _is_annotated([node]):
             return
@@ -848,20 +847,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
             act_node,
             input_act_qspec,
         )
-        if input_act_qspec.dtype == torch.int32:
-            annotate_input_qspec_map(
-                node,
-                weight_node,
-                get_16a16w_qnn_ptq_config().weight,
-            )
-        else:
-            annotate_input_qspec_map(
-                node,
-                weight_node,
-                input_act_qspec,
-            )
-        nodes_to_mark_annotated = [node, weight_node]
-        if bias_node:
+        nodes_to_mark_annotated = [node]
+        if isinstance(weight_node, Node):
+            if input_act_qspec.dtype == torch.int32:
+                annotate_input_qspec_map(
+                    node,
+                    weight_node,
+                    get_16a16w_qnn_ptq_config().weight,
+                )
+            else:
+                annotate_input_qspec_map(
+                    node,
+                    weight_node,
+                    input_act_qspec,
+                )
+            nodes_to_mark_annotated.append(weight_node)
+        if isinstance(bias_node, Node):
             annotate_input_qspec_map(
                 node,
                 bias_node,
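
The isinstance(..., Node) guards are the functional core of this annotator change: for native_layer_norm exported from LayerNorm(elementwise_affine=False), args[2] and args[3] are literal None rather than fx Nodes, so unconditionally attaching a qspec to them would fail. A reduced sketch of the guard pattern (affine_inputs is a hypothetical helper mirroring the logic above):

from torch.fx import Node

def affine_inputs(node):
    # For layer_norm / native_layer_norm the optional affine params sit at
    # args[2] (weight) and args[3] (bias); either may be missing or None.
    weight = node.args[2] if len(node.args) > 2 else None
    bias = node.args[3] if len(node.args) > 3 else None
    # Only actual fx Nodes can carry a quantization spec.
    weight = weight if isinstance(weight, Node) else None
    bias = bias if isinstance(bias, Node) else None
    return weight, bias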
36 changes: 18 additions & 18 deletions backends/qualcomm/quantizer/annotators/lpai_rules.py
@@ -475,10 +475,8 @@ class LayerNorm(GeneralOpDef):
     @staticmethod
     def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
         act_node = node.args[0]
-        weight_node = node.args[2]
-        bias_node = None
-        if len(node.args) > 2:
-            bias_node = node.args[3]
+        weight_node = node.args[2] if len(node.args) > 2 else None
+        bias_node = node.args[3] if len(node.args) > 3 else None

         if _is_annotated([node]):
             return
@@ -489,20 +487,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
             act_node,
             input_act_qspec,
         )
-        if input_act_qspec.dtype == torch.int32:
-            annotate_input_qspec_map(
-                node,
-                weight_node,
-                get_16a16w_qnn_ptq_config().weight,
-            )
-        else:
-            annotate_input_qspec_map(
-                node,
-                weight_node,
-                input_act_qspec,
-            )
-        nodes_to_mark_annotated = [node, weight_node]
-        if bias_node:
+        nodes_to_mark_annotated = [node]
+        if isinstance(weight_node, Node):
+            if input_act_qspec.dtype == torch.int32:
+                annotate_input_qspec_map(
+                    node,
+                    weight_node,
+                    get_16a16w_qnn_ptq_config().weight,
+                )
+            else:
+                annotate_input_qspec_map(
+                    node,
+                    weight_node,
+                    input_act_qspec,
+                )
+            nodes_to_mark_annotated.append(weight_node)
+        if isinstance(bias_node, Node):
             annotate_input_qspec_map(
                 node,
                 bias_node,
9 changes: 7 additions & 2 deletions backends/qualcomm/tests/models.py
@@ -1418,9 +1418,14 @@ def forward(self, x):


 class LayerNorm(torch.nn.Module):
-    def __init__(self, bias=True):
+    def __init__(self, elementwise_affine=True, bias=True):
         super().__init__()
-        self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6, bias=bias)
+        self.layer_norm = torch.nn.LayerNorm(
+            [768],
+            eps=1e-6,
+            elementwise_affine=elementwise_affine,
+            bias=bias,
+        )
         self.linear = torch.nn.Linear(768, 196)

     def forward(self, x):
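
Hypothetical usage of the extended test module; assuming forward applies layer_norm then the 768 -> 196 linear (the forward body is outside this hunk), the new flag exercises the no-affine path end to end:

import torch

module = LayerNorm(elementwise_affine=False)
out = module(torch.randn(196, 768))
print(out.shape)  # expected torch.Size([196, 196]) under that assumption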
12 changes: 10 additions & 2 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -1385,7 +1385,11 @@ def test_qnn_backend_up_sampling_nearest_2d_with_size(self):
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_layer_norm(self):
-        modules = [LayerNorm(), LayerNorm(bias=False)]  # noqa: F405
+        modules = [
+            LayerNorm(),  # noqa: F405
+            LayerNorm(bias=False),  # noqa: F405
+            LayerNorm(elementwise_affine=False),  # noqa: F405
+        ]
         sample_input = (torch.randn(196, 768),)
         for i, module in enumerate(modules):
             with self.subTest(i=i):
@@ -3871,7 +3875,11 @@ def test_qnn_backend_up_sampling_nearest_2d_with_size(self):
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_layer_norm(self):
-        modules = [LayerNorm(), LayerNorm(bias=False)]  # noqa: F405
+        modules = [
+            LayerNorm(),  # noqa: F405
+            LayerNorm(bias=False),  # noqa: F405
+            LayerNorm(elementwise_affine=False),  # noqa: F405
+        ]
         sample_input = (torch.randn(196, 768),)
         for i, module in enumerate(modules):
             with self.subTest(i=i):