Skip to content

[AArch64][TTI] Allow mixed-extension partial reductions with +dotprod #199762

Open
fhahn wants to merge 1 commit into
llvm:mainfrom
fhahn:aarch64-usdot-cost-plain-dotprod
Open

[AArch64][TTI] Allow mixed-extension partial reductions with +dotprod #199762
fhahn wants to merge 1 commit into
llvm:mainfrom
fhahn:aarch64-usdot-cost-plain-dotprod

Conversation

@fhahn
Copy link
Copy Markdown
Contributor

@fhahn fhahn commented May 26, 2026

With the backend now lowering SUMLA via two udot products on targets that have +dotprod (#199761), lower the cost on targets without +i8mm.

…rod.

With the backend now lowering SUMLA via two udot products on targets that
have +dotprod, lower the cost on targets without +i8mm.
@llvmorg-github-actions
Copy link
Copy Markdown

llvmorg-github-actions Bot commented May 26, 2026

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-backend-aarch64

Author: Florian Hahn (fhahn)

Changes

With the backend now lowering SUMLA via two udot products on targets that have +dotprod (#199761), lower the cost on targets without +i8mm.


Full diff: https://github.com/llvm/llvm-project/pull/199762.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+7-2)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll (+36-36)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 4e6059412662e..259520600da1f 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -6018,8 +6018,9 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
     return Invalid;
 
   bool IsUSDot = OpBExtend != TTI::PR_None && OpAExtend != OpBExtend;
-  if (IsUSDot && !ST->hasMatMulInt8())
-    // FIXME: Remove this early bailout in favour of expand cost.
+  // USDot is natively supported with +i8mm. With plain +dotprod, SUMLA is
+  // lowered to two udots plus an eor and a sub.
+  if (IsUSDot && !ST->hasMatMulInt8() && !ST->hasDotProd())
     return Invalid;
 
   unsigned Ratio =
@@ -6075,6 +6076,10 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
     // i8 -> i32 usdot requires +i8mm
     if (IsUSDot && IsSupported(ST->hasMatMulInt8(), ST->hasMatMulInt8()))
       return Cost + INegCost;
+    // Without +i8mm, lower SUMLA via two udots plus an eor and a sub on plain
+    // +dotprod targets. Charge an extra factor for the expansion.
+    if (IsUSDot && IsSupported(false, ST->hasDotProd()))
+      return Cost * 3 + INegCost;
   }
 
   if (ST->isSVEorStreamingSVEAvailable() && !IsUSDot) {
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll
index 98fdc0647f2fb..f6a632b7b33c0 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll
@@ -52,30 +52,30 @@ define i32 @sudot(ptr %a, ptr %b) #0 {
 ; CHECK-NOI8MM-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK-NOI8MM:       vector.body:
 ; CHECK-NOI8MM-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NOI8MM-NEXT:    [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP10:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NOI8MM-NEXT:    [[VEC_PHI1:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP11:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NOI8MM-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NOI8MM-NEXT:    [[VEC_PHI1:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE5:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NOI8MM-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
 ; CHECK-NOI8MM-NEXT:    [[TMP1:%.*]] = getelementptr i8, ptr [[TMP6]], i64 16
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP6]], align 1
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP1]], align 1
-; CHECK-NOI8MM-NEXT:    [[TMP2:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
 ; CHECK-NOI8MM-NEXT:    [[TMP5:%.*]] = getelementptr i8, ptr [[TMP13]], i64 16
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP13]], align 1
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP5]], align 1
 ; CHECK-NOI8MM-NEXT:    [[TMP14:%.*]] = sext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP7:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP2:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP8:%.*]] = mul <16 x i32> [[TMP14]], [[TMP2]]
+; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP8]])
+; CHECK-NOI8MM-NEXT:    [[TMP7:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP9:%.*]] = mul <16 x i32> [[TMP7]], [[TMP3]]
-; CHECK-NOI8MM-NEXT:    [[TMP10]] = add <16 x i32> [[TMP8]], [[VEC_PHI]]
-; CHECK-NOI8MM-NEXT:    [[TMP11]] = add <16 x i32> [[TMP9]], [[VEC_PHI1]]
+; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP9]])
 ; CHECK-NOI8MM-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32
 ; CHECK-NOI8MM-NEXT:    [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
 ; CHECK-NOI8MM-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; CHECK-NOI8MM:       middle.block:
-; CHECK-NOI8MM-NEXT:    [[BIN_RDX:%.*]] = add <16 x i32> [[TMP11]], [[TMP10]]
-; CHECK-NOI8MM-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[BIN_RDX]])
+; CHECK-NOI8MM-NEXT:    [[BIN_RDX:%.*]] = add <4 x i32> [[PARTIAL_REDUCE5]], [[PARTIAL_REDUCE]]
+; CHECK-NOI8MM-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[BIN_RDX]])
 ; CHECK-NOI8MM-NEXT:    br label [[FOR_EXIT:%.*]]
 ; CHECK-NOI8MM:       for.exit:
 ; CHECK-NOI8MM-NEXT:    ret i32 [[TMP15]]
@@ -149,30 +149,30 @@ define i32 @usdot(ptr %a, ptr %b) #0 {
 ; CHECK-NOI8MM-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK-NOI8MM:       vector.body:
 ; CHECK-NOI8MM-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NOI8MM-NEXT:    [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP10:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NOI8MM-NEXT:    [[VEC_PHI1:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP11:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NOI8MM-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NOI8MM-NEXT:    [[VEC_PHI1:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE5:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NOI8MM-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
 ; CHECK-NOI8MM-NEXT:    [[TMP1:%.*]] = getelementptr i8, ptr [[TMP6]], i64 16
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP6]], align 1
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP1]], align 1
-; CHECK-NOI8MM-NEXT:    [[TMP2:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
 ; CHECK-NOI8MM-NEXT:    [[TMP5:%.*]] = getelementptr i8, ptr [[TMP13]], i64 16
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP13]], align 1
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP5]], align 1
 ; CHECK-NOI8MM-NEXT:    [[TMP14:%.*]] = zext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP7:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP2:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP8:%.*]] = mul <16 x i32> [[TMP14]], [[TMP2]]
+; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP8]])
+; CHECK-NOI8MM-NEXT:    [[TMP7:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP9:%.*]] = mul <16 x i32> [[TMP7]], [[TMP3]]
-; CHECK-NOI8MM-NEXT:    [[TMP10]] = add <16 x i32> [[TMP8]], [[VEC_PHI]]
-; CHECK-NOI8MM-NEXT:    [[TMP11]] = add <16 x i32> [[TMP9]], [[VEC_PHI1]]
+; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP9]])
 ; CHECK-NOI8MM-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32
 ; CHECK-NOI8MM-NEXT:    [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
 ; CHECK-NOI8MM-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
 ; CHECK-NOI8MM:       middle.block:
-; CHECK-NOI8MM-NEXT:    [[BIN_RDX:%.*]] = add <16 x i32> [[TMP11]], [[TMP10]]
-; CHECK-NOI8MM-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[BIN_RDX]])
+; CHECK-NOI8MM-NEXT:    [[BIN_RDX:%.*]] = add <4 x i32> [[PARTIAL_REDUCE5]], [[PARTIAL_REDUCE]]
+; CHECK-NOI8MM-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[BIN_RDX]])
 ; CHECK-NOI8MM-NEXT:    br label [[FOR_EXIT:%.*]]
 ; CHECK-NOI8MM:       for.exit:
 ; CHECK-NOI8MM-NEXT:    ret i32 [[TMP15]]
@@ -244,30 +244,30 @@ define i32 @sudot_neon(ptr %a, ptr %b) #1 {
 ; CHECK-NOI8MM-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK-NOI8MM:       vector.body:
 ; CHECK-NOI8MM-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NOI8MM-NEXT:    [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP12:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NOI8MM-NEXT:    [[VEC_PHI1:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP13:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NOI8MM-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NOI8MM-NEXT:    [[VEC_PHI1:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE5:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NOI8MM-NEXT:    [[TMP0:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
 ; CHECK-NOI8MM-NEXT:    [[TMP2:%.*]] = getelementptr i8, ptr [[TMP0]], i64 16
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP0]], align 1
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1
-; CHECK-NOI8MM-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP4:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP5:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
 ; CHECK-NOI8MM-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP5]], i64 16
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP5]], align 1
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
 ; CHECK-NOI8MM-NEXT:    [[TMP8:%.*]] = sext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP10:%.*]] = mul <16 x i32> [[TMP8]], [[TMP3]]
+; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP10]])
+; CHECK-NOI8MM-NEXT:    [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP4:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]]
-; CHECK-NOI8MM-NEXT:    [[TMP12]] = add <16 x i32> [[TMP10]], [[VEC_PHI]]
-; CHECK-NOI8MM-NEXT:    [[TMP13]] = add <16 x i32> [[TMP11]], [[VEC_PHI1]]
+; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP11]])
 ; CHECK-NOI8MM-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32
 ; CHECK-NOI8MM-NEXT:    [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
 ; CHECK-NOI8MM-NEXT:    br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
 ; CHECK-NOI8MM:       middle.block:
-; CHECK-NOI8MM-NEXT:    [[BIN_RDX:%.*]] = add <16 x i32> [[TMP13]], [[TMP12]]
-; CHECK-NOI8MM-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[BIN_RDX]])
+; CHECK-NOI8MM-NEXT:    [[BIN_RDX:%.*]] = add <4 x i32> [[PARTIAL_REDUCE5]], [[PARTIAL_REDUCE]]
+; CHECK-NOI8MM-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[BIN_RDX]])
 ; CHECK-NOI8MM-NEXT:    br label [[FOR_EXIT:%.*]]
 ; CHECK-NOI8MM:       for.exit:
 ; CHECK-NOI8MM-NEXT:    ret i32 [[TMP15]]
@@ -339,30 +339,30 @@ define i32 @usdot_neon(ptr %a, ptr %b) #1 {
 ; CHECK-NOI8MM-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK-NOI8MM:       vector.body:
 ; CHECK-NOI8MM-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NOI8MM-NEXT:    [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP12:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NOI8MM-NEXT:    [[VEC_PHI1:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP13:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NOI8MM-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NOI8MM-NEXT:    [[VEC_PHI1:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE5:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NOI8MM-NEXT:    [[TMP0:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
 ; CHECK-NOI8MM-NEXT:    [[TMP2:%.*]] = getelementptr i8, ptr [[TMP0]], i64 16
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP0]], align 1
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1
-; CHECK-NOI8MM-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP4:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP5:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
 ; CHECK-NOI8MM-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP5]], i64 16
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP5]], align 1
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
 ; CHECK-NOI8MM-NEXT:    [[TMP8:%.*]] = zext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP9:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP10:%.*]] = mul <16 x i32> [[TMP8]], [[TMP3]]
+; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP10]])
+; CHECK-NOI8MM-NEXT:    [[TMP9:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP4:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]]
-; CHECK-NOI8MM-NEXT:    [[TMP12]] = add <16 x i32> [[TMP10]], [[VEC_PHI]]
-; CHECK-NOI8MM-NEXT:    [[TMP13]] = add <16 x i32> [[TMP11]], [[VEC_PHI1]]
+; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP11]])
 ; CHECK-NOI8MM-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32
 ; CHECK-NOI8MM-NEXT:    [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
 ; CHECK-NOI8MM-NEXT:    br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]
 ; CHECK-NOI8MM:       middle.block:
-; CHECK-NOI8MM-NEXT:    [[BIN_RDX:%.*]] = add <16 x i32> [[TMP13]], [[TMP12]]
-; CHECK-NOI8MM-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[BIN_RDX]])
+; CHECK-NOI8MM-NEXT:    [[BIN_RDX:%.*]] = add <4 x i32> [[PARTIAL_REDUCE5]], [[PARTIAL_REDUCE]]
+; CHECK-NOI8MM-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[BIN_RDX]])
 ; CHECK-NOI8MM-NEXT:    br label [[FOR_EXIT:%.*]]
 ; CHECK-NOI8MM:       for.exit:
 ; CHECK-NOI8MM-NEXT:    ret i32 [[TMP15]]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant