Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,31 +1,59 @@
# Copyright 2025 NXP
# Copyright 2025-2026 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging

import executorch.backends.nxp.backend.ir.lib.tflite.Padding as tflPadding
import torch

from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT
from executorch.backends.nxp.backend.ir.converter.conversion import common
from executorch.backends.nxp.backend.ir.converter.node_converter import (
CustomDelegationOptions,
NodeConverter,
)
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
average_pool_2d_options,
)
from torch import Size

from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from torch.fx import Node
from torch.nn import Parameter

KernelSize = tuple[int, int]
Stride = tuple[int, int]


class AdaptiveAvgPool2dConverter(NodeConverter):

@staticmethod
def _get_equivalent_avg_pool_parameters(node: Node) -> tuple[KernelSize, Stride]:
input_size = node.args[0].meta["val"].shape[2:] # Spatial dims from NCHW shape.
Comment thread
MartinPavella marked this conversation as resolved.
output_size = node.args[1]
stride = (input_size[0] // output_size[0], input_size[1] // output_size[1])
kernel_size = (
input_size[0] - (output_size[0] - 1) * stride[0],
input_size[1] - (output_size[1] - 1) * stride[1],
)

return kernel_size, stride

@staticmethod
def _is_supported_in_IR(
node: Node,
parameters_mapping: dict[str, Parameter],
custom_delegation_options: CustomDelegationOptions,
) -> bool:
if (
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would move this check to is_supported_on_target, since it's not an IR's problem we did not determine node format yet.
It also makes more sense because _get_equivalent_avg_pool_parameters is also used in is_supported_on_target and the code concerning node formats would be in one place.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check is there to make sure our NodeFormatInference handles the operator as we expected. It is there to make sure that the principles on which EdgeProgramToIRConverter is built are obeyed. It has nothing to do with the RT700 platform and its Neutron parameters. It is strictly related to conversion to the IR. Therefore it doesn't make sense for it to be in is_supported_on_target. I agree that including it in is_supported_in_IR is not ideal (though it's not a big stretch).

The PR which originally added the adaptive avgpool2d to our backend didn't update the format inference, so I added this check. I believe we should add it to all node converters of operators which require channels last format in NeutronIR.

Let's leave it as is in this PR. I have created a ticket to address this: https://jira.sw.nxp.com/browse/EIEX-913
I think I'll get on it soon.

format_ := node.meta.get(NXP_NODE_FORMAT)
) is None or not format_.is_channels_first():
logging.warning(
"NXP backend: `adaptive_avg_pool_2d` doesn't have the required input format for delegation. "
"Please run `NodeFormatInference.identify_node_formats()` during lowering or report this issue."
)
return False

input_size = node.args[0].meta["val"].shape
output_size = node.args[1]

Expand All @@ -39,30 +67,53 @@ def _is_supported_in_IR(

return True

@staticmethod
def _is_supported_on_target(
    node: Node,
    neutron_target_spec: NeutronTargetSpec,
    parameters_mapping: dict[str, Parameter],
    custom_delegation_options: CustomDelegationOptions,
) -> bool:
    """Check the Neutron-target-specific constraints for delegating the
    `_adaptive_avg_pool2d` operator.

    :param node: The `_adaptive_avg_pool2d` node to check.
    :param neutron_target_spec: Target spec (unused here; part of the
        common `_is_supported_on_target` interface).
    :param parameters_mapping: Mapping of parameter names to values (unused
        here; part of the common interface).
    :param custom_delegation_options: Options controlling delegation; only
        the new Neutron-C flow adds constraints.
    :return: True if the node can be delegated to the target.
    """
    kernel_size, stride = (
        AdaptiveAvgPool2dConverter._get_equivalent_avg_pool_parameters(node)
    )

    # AdaptiveAvgPool2d node schema: (Tensor self, SymInt[2] output_size)
    if custom_delegation_options.use_new_flow_neutron_c:
        # Requirements specified by the new Neutron flow documentation.

        # I/O must be quantized to (u)int8.
        if not NodeConverter.uses_quantization_type_for_io(
            node,
            supported_types=[torch.int8, torch.uint8],
            input_indices=[0],
            output_indices=[0],
        ):
            return False

        # The equivalent pooling window and stride are limited to 4096
        # in each spatial dimension.
        if any(k > 4096 for k in kernel_size):
            return False

        if any(s > 4096 for s in stride):
            return False

    return True

def convert(self, node: Node):
    """Convert the '_adaptive_avg_pool2d' operator to NeutronIR 'AveragePool2D'.

    The ExecuTorch schema is:
        _adaptive_avg_pool2d(
            Tensor self,
            SymInt[2] output_size
        ) -> Tensor

    :param node: The `_adaptive_avg_pool2d` node to convert.
    """
    self.assert_convertible(node)

    t_op = self._create_tflite_op_with_io_tensors(node)
    t_op.builtin_options = average_pool_2d_options.AveragePool2D()

    # Express the adaptive pooling as a fixed-window average pool.
    kernel_size, stride = self._get_equivalent_avg_pool_parameters(node)

    common.assign_2d_strides(t_op.builtin_options, stride)
    t_op.builtin_options.filter_h, t_op.builtin_options.filter_w = kernel_size
    # The equivalent fixed pooling never pads, so VALID padding is exact.
    t_op.builtin_options.padding = tflPadding.Padding.VALID

    self.builder.append_operators([t_op])
2 changes: 2 additions & 0 deletions backends/nxp/backend/node_format_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class NodeFormatInference:
# The op in the dictionary is mapped to a dictionary, which holds indices to input nodes
# that are always channels first.
ops_with_channels_first_nodes = {
exir_ops.edge.aten._adaptive_avg_pool2d.default: {"inputs": [0]},
torch.ops.aten.adaptive_avg_pool2d.default: {"inputs": [0]},
exir_ops.edge.aten.avg_pool2d.default: {"inputs": [0]},
exir_ops.edge.aten.convolution.default: {"inputs": [0, 1]},
exir_ops.edge.aten.max_pool2d_with_indices.default: {"inputs": [0]},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 NXP
# Copyright 2025-2026 NXP
#
# This source code is licensed under the BSD-style license found in the
Comment thread
MartinPavella marked this conversation as resolved.
# LICENSE file in the root directory of this source tree.
Expand All @@ -10,15 +10,29 @@
from executorch.backends.nxp.backend.edge_program_converter import (
EdgeProgramToIRConverter,
)
from executorch.backends.nxp.tests.dataset_creator import RandomDatasetCreator
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
from executorch.backends.nxp.tests.executors import (
convert_run_compare,
graph_contains_any_of_ops,
ToChannelFirstPreprocess,
ToChannelLastPreprocess,
)
from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier
from executorch.backends.nxp.tests.model_output_comparator import (
AllCloseOutputComparator,
)
from executorch.backends.nxp.tests.models import (
AdaptiveAvgPool2dConvMeanDimModule,
AdaptiveAvgPool2dConvModule,
AdaptiveAvgPool2dModule,
)

from executorch.backends.nxp.tests.nsys_testing import lower_run_compare

from executorch.backends.nxp.tests.ops_aliases import (
AdaptiveAvgPool2D,
ExecutorchDelegateCall,
)
from torch.export import ExportedProgram
from executorch.backends.nxp.tests.use_qat import * # noqa F403
Expand Down Expand Up @@ -151,3 +165,104 @@ def test_adaptive_avg_pool_2d_mean_dim_quant_conversion(mocker, use_qat):
tflite_output_preprocess=ToChannelFirstPreprocess(),
input_data=input_data,
)


class TestAdaptiveAvgPool2DNewNeutronFlow:
    """Tests for delegating `_adaptive_avg_pool2d` via the new Neutron-C flow."""

    def test__basic_nsys_inference(self, mocker, use_qat):
        input_shape = (2, 3, 4, 6)
        output_size = (2, 3)
        model = AdaptiveAvgPool2dModule(output_size)
        graph_verifier = DetailedGraphVerifier(
            mocker,
            expected_delegated_ops={AdaptiveAvgPool2D: 1},
            expected_non_delegated_ops={},
        )

        output_comparator = AllCloseOutputComparator(
            7.8e-3
        )  # Accept small error due to Neutron bug (AIR-14585).

        lower_run_compare(
            model,
            input_shape,
            graph_verifier,
            RandomDatasetCreator(low=-1, high=1),
            output_comparator=output_comparator,
            use_qat=use_qat,
            use_new_flow_neutron_c=True,
        )

    @pytest.mark.xfail(
        strict=True,
        reason="Known Neutron bad compute issue. Will be fixed in Neutron SW 3.1.2.",
    )
    def test__known_neutron_issue(self, mocker):
        input_shape = (2, 3, 10, 15)
        output_size = (5, 5)
        model = AdaptiveAvgPool2dModule(output_size)
        graph_verifier = DetailedGraphVerifier(
            mocker,
            expected_delegated_ops={AdaptiveAvgPool2D: 1},
            expected_non_delegated_ops={},
        )

        # Use high tolerance so we notice when the issue is fixed.
        output_comparator = AllCloseOutputComparator(7.8e-3)

        lower_run_compare(
            model,
            input_shape,
            graph_verifier,
            RandomDatasetCreator(low=-1, high=1),
            output_comparator=output_comparator,
            use_new_flow_neutron_c=True,
        )

    def test__kernel_size_and_stride_limit(self, mocker):
        input_shape = (1, 3, 4, 4096)  # input_size = (1, 4096)
        output_size = (
            2,
            1,
        )  # If we reduced both dims to 1, ExecuTorch would replace the op with mean.
        # stride = input_size // output_size = 4096 / 1 = 4096
        # kernel_size = input_size - (output_size - 1) * stride = 4096 - 0 * 4096 = 4096

        model = AdaptiveAvgPool2dModule(output_size)
        graph_verifier = DetailedGraphVerifier(
            mocker,
            expected_delegated_ops={AdaptiveAvgPool2D: 1},
            expected_non_delegated_ops={},
        )

        output_comparator = AllCloseOutputComparator(
            7.9e-3
        )  # Accept small error due to Neutron bug (AIR-14585).

        lower_run_compare(
            model,
            input_shape,
            graph_verifier,
            RandomDatasetCreator(low=-1, high=1),
            output_comparator=output_comparator,
            use_new_flow_neutron_c=True,
        )

    def test__kernel_size_and_stride_limit_exceeded(self):
        input_shape = (1, 3, 4, 4097)  # input_size = (1, 4097)
        output_size = (
            2,
            1,
        )  # If we reduced both dims to 1, ExecuTorch would replace the op with mean.
        # stride = input_size // output_size = 4097 / 1 = 4097
        # kernel_size = input_size - (output_size - 1) * stride = 4097 - 0 * 4097 = 4097

        model = AdaptiveAvgPool2dModule(output_size)
        delegated_ep = to_quantized_edge_program(
            model, input_shape, use_new_flow_neutron_c=True
        ).exported_program()

        # Make sure the `adaptive_avg_pool2d` was NOT delegated.
        assert not graph_contains_any_of_ops(
            delegated_ep.graph, [ExecutorchDelegateCall]
        )
        assert graph_contains_any_of_ops(delegated_ep.graph, [AdaptiveAvgPool2D])
8 changes: 7 additions & 1 deletion backends/nxp/tests/model_output_comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,13 @@ def compare_sample(self, sample_dir, cpu_output_tensors, npu_output_tensors):
assert np.any(
cpu_tensor
), "Output tensor contains only zeros. This is suspicious."
assert np.allclose(cpu_tensor, npu_tensor, atol=self.atol)
all_close = np.allclose(cpu_tensor, npu_tensor, atol=self.atol)
Comment thread
MartinPavella marked this conversation as resolved.
if not all_close:
max_diff = np.abs(cpu_tensor - npu_tensor).max()
print(
f"NPU output doesn't match reference. Maximum absolute difference: {max_diff}"
)
assert all_close


def _default_postprocess_fn(outputs: np.ndarray, _: str):
Expand Down
1 change: 1 addition & 0 deletions backends/nxp/tests/ops_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from executorch.exir.dialects._ops import ops as exir_ops

Abs = exir_ops.edge.aten.abs.default
AdaptiveAvgPool2D = exir_ops.edge.aten._adaptive_avg_pool2d.default
Comment thread
MartinPavella marked this conversation as resolved.
AvgPool2D = exir_ops.edge.aten.avg_pool2d.default
Bmm = exir_ops.edge.aten.bmm.default
Convolution = exir_ops.edge.aten.convolution.default
Expand Down
Loading