// Patch summary (commentary placed ahead of the first `diff --git` header).
//
// Touches two files:
//
// 1. source/tnn/device/arm/acc/arm_swish_layer_acc.cc
//    - Adds a `fast_op(const Float4&)` member to `arm_swish_operator` that
//      returns `v * Float4::fast_sigmoid(v)`, next to the existing
//      `operator()(const Float4&)` which returns `v * Float4::sigmoid(v)`.
//    - NOTE(review): presumably `fast_op` overrides a virtual declared on
//      `arm_unary_operator`; that base is not visible here - confirm.
//
// 2. source/tnn/device/x86/acc/x86_swish_layer_acc.cc
//    - Switches the project include from `x86_unary_layer_acc.h` to
//      `x86_unary2_layer_acc.h` and the base of `x86_swish_operator` from
//      `x86_unary_operator` to `x86_unary2_operator`.
//    - Leaves the scalar `operator()(const float)` untouched:
//      `in * (1.0f / (1.0f + exp(-in)))`.
//    - Adds `operator()(const Float4&)` and `operator()(const Float8&)`
//      overloads, each returning `v * <type>::sigmoid(v)`.
//    - Removes `DECLARE_X86_UNARY_ACC(Swish, X86_SWISH_OP)` and adds two
//      `X86_REGISTER_UNARY2_KERNEL(LAYER_SWISH, ...)` lines (`avx2` with
//      `unary2_kernel_avx`, `sse42` with `unary2_kernel_sse`) followed by
//      `DECLARE_X86_UNARY2_ACC(Swish, LAYER_SWISH)`.
//    - `REGISTER_X86_ACC(Swish, LAYER_SWISH)` is unchanged context.
//    - The file carries the "\ No newline at end of file" marker.
//
// NOTE(review): the patch text below has lost its line breaks, so it cannot be
// applied as-is. The hunk headers (`-20,6 +20,9` and `-12,18 +12,27`) count
// blank context lines whose positions cannot be recovered from this text;
// regenerate the patch from the repository rather than re-wrapping by hand.
// NOTE(review): the bare `#include` in the x86 hunk has no header name - it
// looks like an angle-bracket header was stripped during extraction (the
// scalar path calls `exp`, so possibly a math header) - verify upstream.
diff --git a/source/tnn/device/arm/acc/arm_swish_layer_acc.cc b/source/tnn/device/arm/acc/arm_swish_layer_acc.cc index bf9366989..e9ddf6673 100644 --- a/source/tnn/device/arm/acc/arm_swish_layer_acc.cc +++ b/source/tnn/device/arm/acc/arm_swish_layer_acc.cc @@ -20,6 +20,9 @@ typedef struct arm_swish_operator : arm_unary_operator { virtual Float4 operator()(const Float4 &v) { return v * Float4::sigmoid(v); } + virtual Float4 fast_op(const Float4& v) { + return v * Float4::fast_sigmoid(v); + } } ARM_SWISH_OP; DECLARE_ARM_UNARY_ACC_FP16(Swish, ARM_SWISH_OP); diff --git a/source/tnn/device/x86/acc/x86_swish_layer_acc.cc b/source/tnn/device/x86/acc/x86_swish_layer_acc.cc index a3d8bc256..298d0bb76 100644 --- a/source/tnn/device/x86/acc/x86_swish_layer_acc.cc +++ b/source/tnn/device/x86/acc/x86_swish_layer_acc.cc @@ -12,18 +12,27 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "tnn/device/x86/acc/x86_unary_layer_acc.h" +#include "tnn/device/x86/acc/x86_unary2_layer_acc.h" #include namespace TNN_NS { -typedef struct x86_swish_operator : x86_unary_operator { +typedef struct x86_swish_operator : x86_unary2_operator { virtual float operator()(const float in) { return in * (1.0f / (1.0f + exp(-in))); } -} X86_SWISH_OP; -DECLARE_X86_UNARY_ACC(Swish, X86_SWISH_OP); + virtual Float4 operator()(const Float4 &v) { + return v * Float4::sigmoid(v); + } + + virtual Float8 operator()(const Float8 &v) { + return v * Float8::sigmoid(v); + } +} X86_SWISH_OP; +X86_REGISTER_UNARY2_KERNEL(LAYER_SWISH, avx2, unary2_kernel_avx); +X86_REGISTER_UNARY2_KERNEL(LAYER_SWISH, sse42, unary2_kernel_sse); +DECLARE_X86_UNARY2_ACC(Swish, LAYER_SWISH); REGISTER_X86_ACC(Swish, LAYER_SWISH); } // namespace TNN_NS \ No newline at end of file