From 82309e4180e2fd91557e6084614185df3a06c70a Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Tue, 12 May 2026 10:06:46 -0700 Subject: [PATCH] Fused quant hardswish kernel (#19488) Summary: Fused quant hardswish kernel with optional dequantize/quantize. Unary op that applies x * min(max(x+3, 0), 6) / 6. Supports per-tensor and per-channel quantization. Reviewed By: mvartani-meta Differential Revision: D103754780 --- backends/cadence/fused_quant/op_hardswish.cpp | 92 ++++ backends/cadence/fused_quant/op_hardswish.h | 37 ++ backends/cadence/fused_quant/targets.bzl | 12 + backends/cadence/fused_quant/tests/BUCK | 11 + .../fused_quant/tests/test_op_hardswish.cpp | 405 ++++++++++++++++++ 5 files changed, 557 insertions(+) create mode 100644 backends/cadence/fused_quant/op_hardswish.cpp create mode 100644 backends/cadence/fused_quant/op_hardswish.h create mode 100644 backends/cadence/fused_quant/tests/test_op_hardswish.cpp diff --git a/backends/cadence/fused_quant/op_hardswish.cpp b/backends/cadence/fused_quant/op_hardswish.cpp new file mode 100644 index 00000000000..0d653a1bfae --- /dev/null +++ b/backends/cadence/fused_quant/op_hardswish.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */
+
+#include
+
+#include
+#include
+#include
+
+namespace cadence {
+namespace fused_quant {
+namespace native {
+
+using executorch::aten::optional;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
+using executorch::runtime::KernelRuntimeContext;
+
+namespace { // File-local helper(s).
+
+void hardswish_kernel(const float* inp, float* out, int64_t numel) { // Elementwise float hardswish over numel values; reads inp[i], writes out[i].
+  for (int64_t i = 0; i < numel; ++i) {
+    float x = inp[i];
+    out[i] = x * std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f; // x * clamp(x + 3, 0, 6) / 6
+  }
+}
+
+} // namespace
+
+Tensor& hardswish_out( // Fused [dequantize ->] hardswish [-> quantize]; writes into `out` and returns it.
+    KernelRuntimeContext& ctx, // Not referenced anywhere in this function body.
+    const Tensor& inp, // Source tensor; its storage is used directly as float when inp_scale is empty.
+    const optional& inp_scale, // has_value() selects the dequantize-input path.
+    const optional& inp_zero_point, // Only forwarded to extract_qparams on the quantized-input path.
+    ScalarType inp_dtype, // Not referenced here; dispatch below uses inp.scalar_type().
+    int64_t inp_quant_min, // Only forwarded to extract_qparams on the quantized-input path.
+    int64_t inp_quant_max, // Only forwarded to extract_qparams on the quantized-input path.
+    optional inp_axis, // Only forwarded to extract_qparams on the quantized-input path.
+    const optional& out_scale, // has_value() selects the quantize-output path.
+    const optional& out_zero_point, // Only forwarded to extract_qparams on the quantized-output path.
+    ScalarType out_dtype, // Not referenced here; dispatch below uses out.scalar_type().
+    int64_t out_quant_min, // Only forwarded to extract_qparams on the quantized-output path.
+    int64_t out_quant_max, // Only forwarded to extract_qparams on the quantized-output path.
+    optional out_axis, // Only forwarded to extract_qparams on the quantized-output path.
+    Tensor& out) { // Destination; NOTE(review): no visible check that its element count / dtype match inp - confirm callers guarantee it.
+  int64_t numel = inp.numel(); // Element count comes from inp alone.
+
+  bool inp_quantized = inp_scale.has_value();
+  bool out_quantized = out_scale.has_value();
+
+  std::vector inp_buf; // Scratch for the dequantized input; stays empty on the float-input path.
+  const float* const inp_float = [&]() -> const float* { // Immediately-invoked lambda yielding a float view of the input.
+    if (!inp_quantized) {
+      return inp.const_data_ptr(); // Float input: no copy is made.
+    }
+    inp_buf.resize(numel); // NOTE(review): heap allocation on every call - confirm this is acceptable for the target.
+    QParams qp = extract_qparams(
+        inp_scale, inp_zero_point, inp_quant_min, inp_quant_max, inp_axis, inp);
+    FUSED_QUANT_DTYPE_SWITCH( // Macro defined outside this patch; presumably instantiates the statement with scalar_t bound to inp's dtype - TODO confirm.
+        inp.scalar_type(),
+        scalar_t,
+        dequantize_buffer(
+            inp.const_data_ptr(), inp_buf.data(), numel, qp);)
+    return inp_buf.data(); // Pointer stays valid: inp_buf outlives the lambda and is not resized again.
+  }();
+
+  if (out_quantized) {
+    std::vector result_float(numel); // Float staging buffer holding hardswish results before quantization (second per-call allocation).
+    hardswish_kernel(inp_float, result_float.data(), numel);
+
+    QParams qp = extract_qparams(
+        out_scale, out_zero_point, out_quant_min, out_quant_max, out_axis, out);
+    FUSED_QUANT_DTYPE_SWITCH( // Same dispatch as above, keyed on out's dtype.
+        out.scalar_type(),
+        scalar_t,
+        quantize_buffer(
+            result_float.data(), out.mutable_data_ptr(), numel, qp);)
+  } else {
+    hardswish_kernel(inp_float, out.mutable_data_ptr(), numel); // Float output: compute straight into out's storage.
+  }
+
+  return out;
+}
+
+} // namespace native
+} // namespace fused_quant
+} // namespace cadence
diff --git a/backends/cadence/fused_quant/op_hardswish.h b/backends/cadence/fused_quant/op_hardswish.h
new file mode 100644
index 00000000000..7cba5b07788
--- /dev/null
+++ b/backends/cadence/fused_quant/op_hardswish.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+#include
+
+namespace cadence {
+namespace fused_quant {
+namespace native {
+
+executorch::aten::Tensor& hardswish_out( // Fused-quant hardswish, out-variant: fills `out` and returns a reference to it.
+    executorch::runtime::KernelRuntimeContext& ctx, // Kernel runtime context.
+    const executorch::aten::Tensor& inp, // Input tensor (float, or quantized when inp_scale is set).
+    const executorch::aten::optional& inp_scale, // Empty means the input is not quantized.
+    const executorch::aten::optional& inp_zero_point, // Input zero point(s), paired with inp_scale.
+    executorch::aten::ScalarType inp_dtype, // Input dtype argument.
+    int64_t inp_quant_min, // Lower bound of the input quantized range.
+    int64_t inp_quant_max, // Upper bound of the input quantized range.
+    executorch::aten::optional inp_axis, // Channel axis for the input params; empty in the per-tensor tests.
+    const executorch::aten::optional& out_scale, // Empty means the output is not quantized.
+    const executorch::aten::optional& out_zero_point, // Output zero point(s), paired with out_scale.
+    executorch::aten::ScalarType out_dtype, // Output dtype argument.
+    int64_t out_quant_min, // Lower bound of the output quantized range.
+    int64_t out_quant_max, // Upper bound of the output quantized range.
+    executorch::aten::optional out_axis, // Channel axis for the output params; empty in the per-tensor tests.
+    executorch::aten::Tensor& out); // Destination tensor.
+
+} // namespace native
+} // namespace fused_quant
+} // namespace cadence
diff --git a/backends/cadence/fused_quant/targets.bzl b/backends/cadence/fused_quant/targets.bzl
index 0995f73e9e8..902d4d2727f 100644
--- a/backends/cadence/fused_quant/targets.bzl
+++ b/backends/cadence/fused_quant/targets.bzl
@@ -46,3 +46,15 @@ def define_common_targets():
         ],
         visibility = ["PUBLIC"],
     )
+
+    runtime.cxx_library(  # Library target for the fused-quant hardswish kernel.
+        name = "op_hardswish",
+        srcs = ["op_hardswish.cpp"],
+        exported_headers = ["op_hardswish.h"],
+        platforms = CXX,
+        deps = [
+            ":quant_utils",  # Sibling target in this package.
+            "//executorch/runtime/kernel:kernel_includes",
+        ],
+        visibility = ["PUBLIC"],
+    )
diff --git a/backends/cadence/fused_quant/tests/BUCK b/backends/cadence/fused_quant/tests/BUCK
index f20c4472c57..c503049dc69
100644
--- a/backends/cadence/fused_quant/tests/BUCK
+++ b/backends/cadence/fused_quant/tests/BUCK
@@ -35,3 +35,14 @@ runtime.cxx_test(
         "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
     ],
 )
+
+runtime.cxx_test(  # Unit tests for the op_hardswish library target.
+    name = "test_op_hardswish",
+    srcs = ["test_op_hardswish.cpp"],
+    platforms = CXX,
+    deps = [
+        "//executorch/backends/cadence/fused_quant:op_hardswish",
+        "//executorch/kernels/test:gtest_utils",
+        "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
+    ],
+)
diff --git a/backends/cadence/fused_quant/tests/test_op_hardswish.cpp b/backends/cadence/fused_quant/tests/test_op_hardswish.cpp
new file mode 100644
index 00000000000..e92989c64d2
--- /dev/null
+++ b/backends/cadence/fused_quant/tests/test_op_hardswish.cpp
@@ -0,0 +1,405 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+using executorch::aten::optional;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
+using executorch::runtime::testing::TensorFactory;
+
+namespace {
+
+optional none_tensor() { // Empty optional, passed for an absent scale / zero-point tensor.
+  return optional();
+}
+
+optional none_axis() { // Empty optional, passed when no quantization axis is given.
+  return optional();
+}
+
+} // namespace
+
+class FusedQuantHardswishTest : public OperatorTest {}; // Fixture; the tests below use its context_ member. NOTE(review): no case rounds a tie or saturates at quant_min/quant_max - consider adding one.
+
+// All quantized: int8 → int8 (per-tensor)
+TEST_F(FusedQuantHardswishTest, AllQuantizedPerTensor) {
+  TensorFactory tf_int8;
+  TensorFactory tf_float;
+  TensorFactory tf_long;
+
+  const std::vector sizes{6};
+
+  Tensor inp = tf_int8.make(sizes, {-6, -3, 0, 3, 6, 10});
+
+  Tensor inp_scale = tf_float.make({1}, {1.0});
+  Tensor inp_zp = tf_long.make({1}, {0});
+  Tensor out_scale = tf_float.make({1}, {1.0});
+  Tensor out_zp = tf_long.make({1}, {0});
+
+  Tensor out = tf_int8.zeros(sizes);
+
+  // dequant inp: {-6, -3, 0, 3, 6, 10}
+  // hardswish(-6) = -6 * min(max(-3,0),6)/6 = 0
+  // hardswish(-3) = -3 * min(max(0,0),6)/6 = 0
+  // hardswish(0) = 0 * min(max(3,0),6)/6 = 0
+  // hardswish(3) = 3 * min(max(6,0),6)/6 = 3
+  // hardswish(6) = 6 * min(max(9,0),6)/6 = 6
+  // hardswish(10) = 10 * min(max(13,0),6)/6 = 10
+  // requant (scale=1.0, zp=0): {0, 0, 0, 3, 6, 10}
+  cadence::fused_quant::native::hardswish_out(
+      context_, // ctx
+      inp, // inp
+      optional(inp_scale), // inp_scale
+      optional(inp_zp), // inp_zero_point
+      ScalarType::Float, // inp_dtype. NOTE(review): inp is built with tf_int8 and hardswish_out never reads this argument - confirm the intended value.
+      -128, // inp_quant_min
+      127, // inp_quant_max
+      none_axis(), // inp_axis (none: per-tensor input params)
+      optional(out_scale), // out_scale
+      optional(out_zp), // out_zero_point
+      ScalarType::Char, // out_dtype
+      -128, // out_quant_min
+      127, // out_quant_max
+      none_axis(), // out_axis (none: per-tensor output params)
+      out); // out
+
+  EXPECT_TENSOR_EQ(out, tf_int8.make(sizes, {0, 0, 0, 3, 6, 10}));
+}
+
+// float → int8
+TEST_F(FusedQuantHardswishTest, FloatInputQuantizedOutput) {
+  TensorFactory tf_int8;
+  TensorFactory tf_float;
+  TensorFactory tf_long;
+
+  const std::vector sizes{6};
+
+  Tensor inp = tf_float.make(sizes, {-6.0, -3.0, 0.0, 3.0, 6.0, 10.0});
+
+  Tensor out_scale = tf_float.make({1}, {1.0});
+  Tensor out_zp = tf_long.make({1}, {0});
+
+  Tensor out = tf_int8.zeros(sizes);
+
+  // hardswish: {0, 0, 0, 3, 6, 10}
+  // requant (scale=1.0, zp=0): {0, 0, 0, 3, 6, 10}
+  cadence::fused_quant::native::hardswish_out(
+      context_,
+      inp,
+      none_tensor(), // no inp_scale: input is consumed as float
+      none_tensor(),
+      ScalarType::Float,
+      0, // inp_quant_min: not consulted when the input is float
+      0, // inp_quant_max: not consulted when the input is float
+      none_axis(),
+      optional(out_scale),
+      optional(out_zp),
+      ScalarType::Char,
+      -128,
+      127,
+      none_axis(),
+      out);
+
+  EXPECT_TENSOR_EQ(out, tf_int8.make(sizes, {0, 0, 0, 3, 6, 10}));
+}
+
+// int8 → float
+TEST_F(FusedQuantHardswishTest, QuantizedInputFloatOutput) {
+  TensorFactory tf_int8;
+  TensorFactory tf_float;
+  TensorFactory tf_long;
+
+  const std::vector sizes{6};
+
+  Tensor inp = tf_int8.make(sizes, {-6, -3, 0, 3, 6, 10});
+
+  Tensor inp_scale = tf_float.make({1}, {1.0});
+  Tensor inp_zp = tf_long.make({1}, {0});
+
+  Tensor out = tf_float.zeros(sizes);
+
+  // dequant inp: {-6, -3, 0, 3, 6, 10}
+  // hardswish: {0.0, 0.0, 0.0, 3.0, 6.0, 10.0}
+  cadence::fused_quant::native::hardswish_out(
+      context_,
+      inp,
+      optional(inp_scale),
+      optional(inp_zp),
+      ScalarType::Float,
+      -128,
+      127,
+      none_axis(),
+      none_tensor(), // no out_scale: result is written as float
+      none_tensor(),
+      ScalarType::Float,
+      0, // out_quant_min: not consulted when the output is float
+      0, // out_quant_max: not consulted when the output is float
+      none_axis(),
+      out);
+
+  EXPECT_TENSOR_EQ(out, tf_float.make(sizes, {0.0, 0.0, 0.0, 3.0, 6.0, 10.0}));
+}
+
+// Per-channel dequantization on input, per-tensor output
+TEST_F(FusedQuantHardswishTest, PerChannelInput) {
+  TensorFactory tf_int8;
+  TensorFactory tf_float;
+  TensorFactory tf_long;
+
+  // Shape [2, 3], axis=0 → 2 channels, axis_stride=3
+  const std::vector sizes{2, 3};
+
+  Tensor inp = tf_int8.make(sizes, {-6, -3, 0, 3, 6, 10});
+
+  // Per-channel: channel 0 scale=1.0, channel 1 scale=0.5
+  Tensor inp_scale = tf_float.make({2}, {1.0, 0.5});
+  Tensor inp_zp = tf_long.make({2}, {0, 0});
+  Tensor out_scale = tf_float.make({1}, {0.5});
+  Tensor out_zp = tf_long.make({1}, {0});
+
+  Tensor out = tf_int8.zeros(sizes);
+
+  // dequant channel 0 (scale=1.0): {-6, -3, 0}
+  // dequant channel 1 (scale=0.5): {1.5, 3.0, 5.0}
+  // hardswish(-6) = 0, hardswish(-3) = 0, hardswish(0) = 0
+  // hardswish(1.5) = 1.5 * min(max(4.5,0),6)/6 = 1.5*4.5/6 = 1.125
+  // hardswish(3.0) = 3 * min(max(6,0),6)/6 = 3*6/6 = 3.0
+  // hardswish(5.0) = 5 * min(max(8,0),6)/6 = 5*6/6 = 5.0
+  // requant (scale=0.5, zp=0): round(0/0.5)=0, 0, 0,
+  // round(1.125/0.5)=round(2.25)=2, round(3.0/0.5)=6, round(5.0/0.5)=10
+  cadence::fused_quant::native::hardswish_out(
+      context_,
+      inp,
+      optional(inp_scale),
+      optional(inp_zp),
+      ScalarType::Float,
+      -128,
+      127,
+      optional(0), // inp_axis = 0: one scale / zero point per row of the [2, 3] input
+      optional(out_scale),
+      optional(out_zp),
+      ScalarType::Char,
+      -128,
+      127,
+      none_axis(), // out_axis: none, output stays per-tensor
+      out);
+
+  EXPECT_TENSOR_EQ(out, tf_int8.make(sizes, {0, 0, 0, 2, 6, 10}));
+}
+
+// Per-channel quantization on output
+TEST_F(FusedQuantHardswishTest, PerChannelOutput) {
+  TensorFactory tf_int8;
+  TensorFactory tf_float;
+  TensorFactory tf_long;
+
+  // Shape [2, 3], axis=0 → 2 channels
+  const std::vector sizes{2, 3};
+
+  Tensor inp = tf_float.make(sizes, {-6.0, 0.0, 3.0, 6.0, 10.0, 12.0});
+
+  // Per-channel output: channel 0 scale=1.0, channel 1 scale=0.5
+  Tensor out_scale = tf_float.make({2}, {1.0, 0.5});
+  Tensor out_zp = tf_long.make({2}, {0, 0});
+
+  Tensor out = tf_int8.zeros(sizes);
+
+  // hardswish(-6) = 0, hardswish(0) = 0, hardswish(3) = 3
+  // hardswish(6) = 6, hardswish(10) = 10, hardswish(12) = 12
+  // requant channel 0 (scale=1.0): round(0/1)=0, round(0/1)=0, round(3/1)=3
+  // requant channel 1 (scale=0.5): round(6/0.5)=12, round(10/0.5)=20,
+  // round(12/0.5)=24
+  cadence::fused_quant::native::hardswish_out(
+      context_,
+      inp,
+      none_tensor(),
+      none_tensor(),
+      ScalarType::Float,
+      0,
+      0,
+      none_axis(),
+      optional(out_scale),
+      optional(out_zp),
+      ScalarType::Char,
+      -128,
+      127,
+      optional(0), // out_axis = 0: one scale / zero point per row of the [2, 3] output
+      out);
+
+  EXPECT_TENSOR_EQ(out, tf_int8.make(sizes, {0, 0, 3, 12, 20, 24}));
+}
+
+// Non-zero zero points
+TEST_F(FusedQuantHardswishTest, NonZeroZeroPoint) {
+  TensorFactory tf_int8;
+  TensorFactory tf_float;
+  TensorFactory tf_long;
+
+  const std::vector sizes{6};
+
+  Tensor inp = tf_int8.make(sizes, {-4, -1, 2, 5, 8, 12});
+
+  // scale=1.0, zp=2 → dequant: (v-2)*1.0
+  Tensor inp_scale = tf_float.make({1}, {1.0});
+  Tensor inp_zp = tf_long.make({1}, {2});
+  // out scale=1.0, zp=1 → requant: round(f/1.0)+1
+  Tensor out_scale = tf_float.make({1}, {1.0});
+  Tensor out_zp = tf_long.make({1}, {1});
+
+  Tensor out = tf_int8.zeros(sizes);
+
+  // dequant inp: {-6, -3, 0, 3, 6, 10}
+  // hardswish: {0, 0, 0, 3, 6, 10}
+  // requant (scale=1.0, zp=1): round(0/1)+1=1, 1, 1,
+  // round(3/1)+1=4, round(6/1)+1=7, round(10/1)+1=11
+  cadence::fused_quant::native::hardswish_out(
+      context_,
+      inp,
+      optional(inp_scale),
+      optional(inp_zp),
+      ScalarType::Float,
+      -128,
+      127,
+      none_axis(),
+      optional(out_scale),
+      optional(out_zp),
+      ScalarType::Char,
+      -128,
+      127,
+      none_axis(),
+      out);
+
+  EXPECT_TENSOR_EQ(out, tf_int8.make(sizes, {1, 1, 1, 4, 7, 11}));
+}
+
+// All values <= -3 should give 0 (negative saturation region)
+TEST_F(FusedQuantHardswishTest, NegativeRegion) {
+  TensorFactory tf_int8;
+  TensorFactory tf_float;
+  TensorFactory tf_long;
+
+  const std::vector sizes{4};
+
+  Tensor inp = tf_float.make(sizes, {-10.0, -6.0, -4.0, -3.0});
+
+  Tensor out_scale = tf_float.make({1}, {1.0});
+  Tensor out_zp = tf_long.make({1}, {0});
+
+  Tensor out = tf_int8.zeros(sizes);
+
+  // hardswish(-10) = -10 * min(max(-7,0),6)/6 = 0
+  // hardswish(-6) = -6 * min(max(-3,0),6)/6 = 0
+  // hardswish(-4) = -4 * min(max(-1,0),6)/6 = 0
+  // hardswish(-3) = -3 * min(max(0,0),6)/6 = 0
+  // requant (scale=1.0, zp=0): {0, 0, 0, 0}
+  cadence::fused_quant::native::hardswish_out(
+      context_,
+      inp,
+      none_tensor(),
+      none_tensor(),
+      ScalarType::Float,
+      0,
+      0,
+      none_axis(),
+      optional(out_scale),
+      optional(out_zp),
+      ScalarType::Char,
+      -128,
+      127,
+      none_axis(),
+      out);
+
+  EXPECT_TENSOR_EQ(out, tf_int8.make(sizes, {0, 0, 0, 0}));
+}
+
+// All values >= 3 should pass through unchanged (linear region)
+TEST_F(FusedQuantHardswishTest, LinearRegion) { // The only case here with both scales absent (float in, float out).
+  TensorFactory tf_float;
+
+  const std::vector sizes{4};
+
+  Tensor inp = tf_float.make(sizes, {3.0, 4.0, 6.0, 10.0});
+
+  Tensor out = tf_float.zeros(sizes);
+
+  // hardswish(3) = 3 * min(max(6,0),6)/6 = 3
+  // hardswish(4) = 4 * min(max(7,0),6)/6 = 4
+  // hardswish(6) = 6 * min(max(9,0),6)/6 = 6
+  // hardswish(10) = 10 * min(max(13,0),6)/6 = 10
+  cadence::fused_quant::native::hardswish_out(
+      context_,
+      inp,
+      none_tensor(),
+      none_tensor(),
+      ScalarType::Float,
+      0,
+      0,
+      none_axis(),
+      none_tensor(),
+      none_tensor(),
+      ScalarType::Float,
+      0,
+      0,
+      none_axis(),
+      out);
+
+  EXPECT_TENSOR_EQ(out, tf_float.make(sizes, {3.0, 4.0, 6.0, 10.0}));
+}
+
+// Values in [-3, 3] fall on the quadratic segment x * (x + 3) / 6
+TEST_F(FusedQuantHardswishTest, TransitionRegion) {
+  TensorFactory tf_int8;
+  TensorFactory tf_float;
+  TensorFactory tf_long;
+
+  const std::vector sizes{5};
+
+  // int8 input with scale=0.5, zp=0 → float {-3.0, -1.5, 0.0, 1.5, 3.0}
+  Tensor inp = tf_int8.make(sizes, {-6, -3, 0, 3, 6});
+
+  Tensor inp_scale = tf_float.make({1}, {0.5});
+  Tensor inp_zp = tf_long.make({1}, {0});
+  Tensor out_scale = tf_float.make({1}, {0.125});
+  Tensor out_zp = tf_long.make({1}, {0});
+
+  Tensor out = tf_int8.zeros(sizes);
+
+  // dequant: {-3.0, -1.5, 0.0, 1.5, 3.0}
+  // hardswish(-3.0) = -3*min(max(0,0),6)/6 = 0
+  // hardswish(-1.5) = -1.5*min(max(1.5,0),6)/6 = -1.5*1.5/6 = -0.375
+  // hardswish(0) = 0*min(max(3,0),6)/6 = 0
+  // hardswish(1.5) = 1.5*min(max(4.5,0),6)/6 = 1.5*4.5/6 = 1.125
+  // hardswish(3.0) = 3*min(max(6,0),6)/6 = 3*6/6 = 3.0
+  // requant (scale=0.125, zp=0): round(0/0.125)=0, round(-0.375/0.125)=-3,
+  // round(0/0.125)=0, round(1.125/0.125)=9, round(3.0/0.125)=24
+  cadence::fused_quant::native::hardswish_out(
+      context_,
+      inp,
+      optional(inp_scale),
+      optional(inp_zp),
+      ScalarType::Float,
+      -128,
+      127,
+      none_axis(),
+      optional(out_scale),
+      optional(out_zp),
+      ScalarType::Char,
+      -128,
+      127,
+      none_axis(),
+      out);
+
+  EXPECT_TENSOR_EQ(out, tf_int8.make(sizes, {0, -3, 0, 9, 24}));
+}