diff --git a/backends/arm/test/ops/test_softmax.py b/backends/arm/test/ops/test_softmax.py
index 847dbb94b58..4f98e2851e9 100644
--- a/backends/arm/test/ops/test_softmax.py
+++ b/backends/arm/test/ops/test_softmax.py
@@ -7,6 +7,8 @@
 from typing import Tuple
 
+import pytest
+
 import torch
 
 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import (
@@ -120,3 +122,118 @@ def test_softmax_vgf_quant(test_data):
     # TODO: MLETORCH-1136 Change args of run_method_and_compare_outputs of the vgf tests
     # pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.run()
+
+
+# ---------------------------------------------------------------------------
+# a16w8 (int16 IO + int8 weights) softmax FVP coverage.
+#
+# Sweeps a multi-head-attention-shaped softmax over a wide range of
+# pre-softmax input magnitudes to surface int16 numerics issues in the
+# lowered graph (e.g. the Ethos-U85 ReduceSum int16 silent-zero issue in the
+# softmax decomposition, fixed by the follow-up Vela patch in this stack).
+# ---------------------------------------------------------------------------
+
+
+class MultiHeadAttentionSoftmax(torch.nn.Module):
+    """Generic multi-head-attention softmax: reshape -> softmax(dim=-1) -> flatten.
+
+    H heads, M query tokens, W K/V window. Output shape: (N, T, H*M*W).
+    """
+
+    H = 4
+    M = 1
+    W = 16
+    IN_FEATURES = H * M * W  # 64
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        n, t, _ = x.shape
+        x = x.reshape(n, t, self.H, self.M, self.W)
+        x = torch.softmax(x, dim=-1)
+        x = x.reshape(n, t, self.IN_FEATURES)
+        return x
+
+
+# (input_low, input_high) per case. Keys are the parametrize ids.
+# Range coverage spans realistic post-1/sqrt(d) attention logits (typically
+# in [-10, +10]) plus a couple of wider buffer cases. atol below is sized
+# at ~1.5x the observed FVP max-abs softmax error across the sweep at
+# qtol=0, measured against the quantized reference.
+mha_softmax_sweep = {
+    "range_neg0p01_to_0p01": (-0.01, 0.01),
+    "range_neg0p1_to_0p1": (-0.1, 0.1),
+    "range_neg1_to_1": (-1.0, 1.0),
+    "range_neg3_to_3": (-3.0, 3.0),
+    "range_neg5_to_5": (-5.0, 5.0),
+    "range_neg10_to_10": (-10.0, 10.0),
+    "range_neg30_to_30": (-30.0, 30.0),
+}
+
+_MHA_ATOL = 0.003
+
+
+def _make_mha_softmax_inputs(
+    input_low: float, input_high: float, num_test: int = 8, seed: int = 42
+) -> Tuple[torch.Tensor]:
+    # Local Generator so this helper does not mutate the global RNG state
+    # and the test suite stays order-independent.
+    gen = torch.Generator().manual_seed(seed)
+    span = input_high - input_low
+    return (
+        torch.rand(
+            num_test,
+            1,
+            MultiHeadAttentionSoftmax.IN_FEATURES,
+            generator=gen,
+        )
+        * span
+        + input_low,
+    )
+
+
+@common.parametrize("sweep_case", mha_softmax_sweep)
+@common.XfailIfNoCorstone300
+def test_mha_softmax_a16w8_u55_INT(sweep_case: Tuple[float, float]) -> None:
+    input_low, input_high = sweep_case
+    pipeline = EthosU55PipelineINT[input_t1](
+        MultiHeadAttentionSoftmax(),
+        _make_mha_softmax_inputs(input_low, input_high),
+        [],
+        exir_ops=[],
+        a16w8_quantization=True,
+        symmetric_io_quantization=True,
+        epsilon=2**-16,
+        atol=_MHA_ATOL,
+    )
+    pipeline.add_stage_after("quantize", pipeline.tester.check_not, [aten_op])
+    pipeline.run()
+
+
+# All cases hit the Ethos-U85 int16 ReduceSum silent-zero issue inside the
+# softmax decomposition. strict=False so the test target stays green both
+# on stock Vela 5.0 (cases XFAIL) and once the upstream Vela fix lands
+# (cases XPASS).
+# Upstream report:
+# https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/issues/23
+@common.parametrize("sweep_case", mha_softmax_sweep)
+@common.XfailIfNoCorstone320
+@pytest.mark.xfail(
+    reason=(
+        "Ethos-U85 int16 ReduceSum silent-zero in softmax decomposition: "
+        "https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/issues/23"
+    ),
+    strict=False,
+)
+def test_mha_softmax_a16w8_u85_INT(sweep_case: Tuple[float, float]) -> None:
+    input_low, input_high = sweep_case
+    pipeline = EthosU85PipelineINT[input_t1](
+        MultiHeadAttentionSoftmax(),
+        _make_mha_softmax_inputs(input_low, input_high),
+        [],
+        exir_ops=[],
+        a16w8_quantization=True,
+        symmetric_io_quantization=True,
+        epsilon=2**-16,
+        atol=_MHA_ATOL,
+    )
+    pipeline.add_stage_after("quantize", pipeline.tester.check_not, [aten_op])
+    pipeline.run()
diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl
index 6a39d1fe5c1..054016a79f7 100644
--- a/backends/arm/test/targets.bzl
+++ b/backends/arm/test/targets.bzl
@@ -29,6 +29,7 @@ def define_arm_tests():
         "ops/test_rsqrt.py",
         "ops/test_slice.py",
         "ops/test_sigmoid.py",
+        "ops/test_softmax.py",
         "ops/test_sub.py",
         "ops/test_sum.py",
         "ops/test_tanh.py",