[AArch64] Lower SUMLA via two udot products on plain dotprod targets.#199761
[AArch64] Lower SUMLA via two udot products on plain dotprod targets.#199761fhahn wants to merge 3 commits into
Conversation
Lower PARTIAL_REDUCE_SUMLA on targets without i8mm using udot via
sum(mul (zext %a, sext %b)) =
sum(mul zext %a, zext (xor %b, 0x80)) -
sum(mul 128, zext %a)
using sext(s) = zext(xor(s, 0x80)) - 128
(https://alive2.llvm.org/ce/z/8G8rer)
Alive2 proofs:
* mul (zext, sext) https://alive2.llvm.org/ce/z/8Av7rG
* mul (sext, zext) https://alive2.llvm.org/ce/z/g-gJGq
|
@llvm/pr-subscribers-backend-aarch64 Author: Florian Hahn (fhahn) ChangesLower PARTIAL_REDUCE_SUMLA on targets without i8mm using udot via sum(mul (zext %a, sext %b)) = using sext(s) = zext(xor(s, 0x80)) - 128 Alive2 proofs:
Patch is 31.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/199761.diff 2 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 08f3d4e0d30ac..e5798ea8fa96a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1507,14 +1507,26 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v16i8, Custom);
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
+ // v2i64/v16i8 SUMLA always reduces to v4i32 SUMLA via
+ // LowerPARTIAL_REDUCE_MLA, regardless of i8mm; v2i32/v16i8 SUMLA goes
+ // through the same widen-to-v4i32 path as the SMLA/UMLA cases above.
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::v2i64,
+ MVT::v16i8, Custom);
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::v2i32,
+ MVT::v16i8, Custom);
+
if (Subtarget->hasMatMulInt8()) {
setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::v4i32,
MVT::v16i8, Legal);
- setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::v2i64,
- MVT::v16i8, Custom);
-
setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::v2i32,
MVT::v8i8, Legal);
+ } else {
+ // Native dotprod without i8mm: lower SUMLA to two UDOT products in
+ // LowerPARTIAL_REDUCE_MLA.
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::v4i32,
+ MVT::v16i8, Custom);
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::v2i32,
+ MVT::v8i8, Custom);
}
}
@@ -33114,14 +33126,12 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
return Scatter;
}
-/// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type pairing
-/// of (nx)v2i64/(nx)v16i8, we cannot directly lower it to a (u|s)dot. We can
-/// however still make use of the dot product instruction by instead
-/// accumulating over two steps: (nx)v16i8 -> (nx)v4i32 -> (nx)v2i64.
-/// If available, make use of the (U|S)ADDW(B|T) instructions, otherwise
-/// the following pattern is emitted:
-/// add(add(Acc, ext(EXTRACT_SUBVECTOR(N, 0)), ext(EXTRACT_SUBVECTOR(N,
-/// NTy/2))))
+/// Lower a PARTIAL_REDUCE_MLA node. Three cases are handled:
+/// 1. (v2i32, v16i8): widen Acc to v4i32 and fold the high half with ADDP.
+/// 2. (nx)v2i64/(nx)v16i8: accumulate in two steps via v4i32, using
+/// (U|S)ADDW(B|T) when available, otherwise add(add(Acc, ext(lo), ext(hi))).
+/// 3. SUMLA on (v4i32, v16i8) or (v2i32, v8i8) without +i8mm: rewrite as two
+/// UDOTs using sext(s) = zext(s ^ 0x80) - 0x80.
SDValue
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
SelectionDAG &DAG) const {
@@ -33145,6 +33155,27 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
return DAG.getExtractSubvector(DL, MVT::v2i32, Reduced, 0);
}
+ // Lower PARTIAL_REDUCE_SUMLA on targets without +i8mm using udot via
+ // sum(zext(RHS) * sext(LHS)) =
+ // sum(zext(RHS) * zext(LHS ^ 0x80)) - sum(zext(RHS) * 0x80)
+ // using sext(s) = zext(s ^ 0x80) - 0x80. LHS=signed, RHS=unsigned.
+ // The (v2i64, v16i8) case is handled by the v4i32 reduction below, which
+ // recursively re-enters this path.
+ if (Op.getOpcode() == ISD::PARTIAL_REDUCE_SUMLA &&
+ ((ResultVT == MVT::v4i32 && OpVT == MVT::v16i8) ||
+ (ResultVT == MVT::v2i32 && OpVT == MVT::v8i8))) {
+ assert(!Subtarget->hasMatMulInt8() && Subtarget->hasDotProd() &&
+ "Custom SUMLA lowering only registered for plain dotprod targets");
+ SDValue SignFlipMask = DAG.getConstant(0x80, DL, OpVT);
+ SDValue BiasedLHS = DAG.getNode(ISD::XOR, DL, OpVT, LHS, SignFlipMask);
+ SDValue BiasedDot = DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, DL, ResultVT, Acc,
+ BiasedLHS, RHS);
+ SDValue BiasCorrection =
+ DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, DL, ResultVT,
+ DAG.getConstant(0, DL, ResultVT), SignFlipMask, RHS);
+ return DAG.getNode(ISD::SUB, DL, ResultVT, BiasedDot, BiasCorrection);
+ }
+
bool ConvertToScalable =
ResultVT.isFixedLengthVector() &&
useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 709d3d387d3a1..b5801f8f48057 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -211,14 +211,12 @@ define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
;
; CHECK-DOT-LABEL: usdot:
; CHECK-DOT: // %bb.0:
-; CHECK-DOT-NEXT: ushll v3.8h, v1.8b, #0
-; CHECK-DOT-NEXT: sshll v4.8h, v2.8b, #0
-; CHECK-DOT-NEXT: ushll2 v1.8h, v1.16b, #0
-; CHECK-DOT-NEXT: sshll2 v2.8h, v2.16b, #0
-; CHECK-DOT-NEXT: smlal v0.4s, v4.4h, v3.4h
-; CHECK-DOT-NEXT: smlal2 v0.4s, v4.8h, v3.8h
-; CHECK-DOT-NEXT: smlal v0.4s, v2.4h, v1.4h
-; CHECK-DOT-NEXT: smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-DOT-NEXT: movi v3.16b, #128
+; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
+; CHECK-DOT-NEXT: eor v2.16b, v2.16b, v3.16b
+; CHECK-DOT-NEXT: udot v4.4s, v3.16b, v1.16b
+; CHECK-DOT-NEXT: udot v0.4s, v2.16b, v1.16b
+; CHECK-DOT-NEXT: sub v0.4s, v0.4s, v4.4s
; CHECK-DOT-NEXT: ret
;
; CHECK-DOT-I8MM-LABEL: usdot:
@@ -258,23 +256,22 @@ define <4 x i32> @usdot_in_loop(ptr %p1, ptr %p2){
;
; CHECK-DOT-LABEL: usdot_in_loop:
; CHECK-DOT: // %bb.0: // %entry
-; CHECK-DOT-NEXT: movi v1.2d, #0000000000000000
+; CHECK-DOT-NEXT: movi v2.2d, #0000000000000000
+; CHECK-DOT-NEXT: movi v1.16b, #128
; CHECK-DOT-NEXT: mov x8, xzr
; CHECK-DOT-NEXT: .LBB6_1: // %vector.body
; CHECK-DOT-NEXT: // =>This Inner Loop Header: Depth=1
-; CHECK-DOT-NEXT: ldr q2, [x0, x8]
-; CHECK-DOT-NEXT: ldr q3, [x1, x8]
-; CHECK-DOT-NEXT: mov v0.16b, v1.16b
+; CHECK-DOT-NEXT: ldr q0, [x0, x8]
+; CHECK-DOT-NEXT: movi v3.2d, #0000000000000000
+; CHECK-DOT-NEXT: eor v4.16b, v0.16b, v1.16b
+; CHECK-DOT-NEXT: mov v0.16b, v2.16b
+; CHECK-DOT-NEXT: ldr q2, [x1, x8]
; CHECK-DOT-NEXT: add x8, x8, #16
-; CHECK-DOT-NEXT: sshll v4.8h, v2.8b, #0
-; CHECK-DOT-NEXT: ushll v5.8h, v3.8b, #0
-; CHECK-DOT-NEXT: sshll2 v2.8h, v2.16b, #0
-; CHECK-DOT-NEXT: ushll2 v3.8h, v3.16b, #0
+; CHECK-DOT-NEXT: udot v3.4s, v1.16b, v2.16b
; CHECK-DOT-NEXT: cmp x8, #16
-; CHECK-DOT-NEXT: smlal v1.4s, v4.4h, v5.4h
-; CHECK-DOT-NEXT: smlal2 v1.4s, v4.8h, v5.8h
-; CHECK-DOT-NEXT: smlal v1.4s, v2.4h, v3.4h
-; CHECK-DOT-NEXT: smlal2 v1.4s, v2.8h, v3.8h
+; CHECK-DOT-NEXT: mov v5.16b, v0.16b
+; CHECK-DOT-NEXT: udot v5.4s, v4.16b, v2.16b
+; CHECK-DOT-NEXT: sub v2.4s, v5.4s, v3.4s
; CHECK-DOT-NEXT: b.ne .LBB6_1
; CHECK-DOT-NEXT: // %bb.2: // %end
; CHECK-DOT-NEXT: ret
@@ -316,6 +313,92 @@ end:
ret <4 x i32> %acc
}
+; Same as @usdot_in_loop, but with a non-zero initial accumulator carried
+; across iterations. Validates that the SUMLA -> 2x udot identity composes
+; correctly when the accumulator is not zero on entry.
+define <4 x i32> @usdot_in_loop_nonzero_acc(ptr %p1, ptr %p2, <4 x i32> %init){
+; CHECK-NODOT-LABEL: usdot_in_loop_nonzero_acc:
+; CHECK-NODOT: // %bb.0: // %entry
+; CHECK-NODOT-NEXT: mov x8, xzr
+; CHECK-NODOT-NEXT: .LBB7_1: // %vector.body
+; CHECK-NODOT-NEXT: // =>This Inner Loop Header: Depth=1
+; CHECK-NODOT-NEXT: ldr q2, [x0, x8]
+; CHECK-NODOT-NEXT: ldr q3, [x1, x8]
+; CHECK-NODOT-NEXT: mov v1.16b, v0.16b
+; CHECK-NODOT-NEXT: add x8, x8, #16
+; CHECK-NODOT-NEXT: sshll v4.8h, v2.8b, #0
+; CHECK-NODOT-NEXT: ushll v5.8h, v3.8b, #0
+; CHECK-NODOT-NEXT: sshll2 v2.8h, v2.16b, #0
+; CHECK-NODOT-NEXT: ushll2 v3.8h, v3.16b, #0
+; CHECK-NODOT-NEXT: cmp x8, #16
+; CHECK-NODOT-NEXT: smlal v0.4s, v4.4h, v5.4h
+; CHECK-NODOT-NEXT: smlal2 v0.4s, v4.8h, v5.8h
+; CHECK-NODOT-NEXT: smlal v0.4s, v2.4h, v3.4h
+; CHECK-NODOT-NEXT: smlal2 v0.4s, v2.8h, v3.8h
+; CHECK-NODOT-NEXT: b.ne .LBB7_1
+; CHECK-NODOT-NEXT: // %bb.2: // %end
+; CHECK-NODOT-NEXT: mov v0.16b, v1.16b
+; CHECK-NODOT-NEXT: ret
+;
+; CHECK-DOT-LABEL: usdot_in_loop_nonzero_acc:
+; CHECK-DOT: // %bb.0: // %entry
+; CHECK-DOT-NEXT: movi v2.16b, #128
+; CHECK-DOT-NEXT: mov x8, xzr
+; CHECK-DOT-NEXT: .LBB7_1: // %vector.body
+; CHECK-DOT-NEXT: // =>This Inner Loop Header: Depth=1
+; CHECK-DOT-NEXT: ldr q1, [x0, x8]
+; CHECK-DOT-NEXT: movi v3.2d, #0000000000000000
+; CHECK-DOT-NEXT: eor v4.16b, v1.16b, v2.16b
+; CHECK-DOT-NEXT: mov v1.16b, v0.16b
+; CHECK-DOT-NEXT: ldr q0, [x1, x8]
+; CHECK-DOT-NEXT: add x8, x8, #16
+; CHECK-DOT-NEXT: udot v3.4s, v2.16b, v0.16b
+; CHECK-DOT-NEXT: cmp x8, #16
+; CHECK-DOT-NEXT: mov v5.16b, v1.16b
+; CHECK-DOT-NEXT: udot v5.4s, v4.16b, v0.16b
+; CHECK-DOT-NEXT: sub v0.4s, v5.4s, v3.4s
+; CHECK-DOT-NEXT: b.ne .LBB7_1
+; CHECK-DOT-NEXT: // %bb.2: // %end
+; CHECK-DOT-NEXT: mov v0.16b, v1.16b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-DOT-I8MM-LABEL: usdot_in_loop_nonzero_acc:
+; CHECK-DOT-I8MM: // %bb.0: // %entry
+; CHECK-DOT-I8MM-NEXT: mov x8, xzr
+; CHECK-DOT-I8MM-NEXT: .LBB7_1: // %vector.body
+; CHECK-DOT-I8MM-NEXT: // =>This Inner Loop Header: Depth=1
+; CHECK-DOT-I8MM-NEXT: ldr q2, [x0, x8]
+; CHECK-DOT-I8MM-NEXT: ldr q3, [x1, x8]
+; CHECK-DOT-I8MM-NEXT: mov v1.16b, v0.16b
+; CHECK-DOT-I8MM-NEXT: add x8, x8, #16
+; CHECK-DOT-I8MM-NEXT: usdot v0.4s, v3.16b, v2.16b
+; CHECK-DOT-I8MM-NEXT: cmp x8, #16
+; CHECK-DOT-I8MM-NEXT: b.ne .LBB7_1
+; CHECK-DOT-I8MM-NEXT: // %bb.2: // %end
+; CHECK-DOT-I8MM-NEXT: mov v0.16b, v1.16b
+; CHECK-DOT-I8MM-NEXT: ret
+entry:
+ br label %vector.body
+
+vector.body:
+ %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+ %acc = phi <4 x i32> [ %init, %entry ], [ %partial.reduce, %vector.body ]
+ %gep1 = getelementptr i8, ptr %p1, i64 %index
+ %load1 = load <16 x i8>, ptr %gep1, align 16
+ %load1.wide = sext <16 x i8> %load1 to <16 x i32>
+ %gep2 = getelementptr i8, ptr %p2, i64 %index
+ %load2 = load <16 x i8>, ptr %gep2, align 16
+ %load2.wide = zext <16 x i8> %load2 to <16 x i32>
+ %mul = mul nuw nsw <16 x i32> %load1.wide, %load2.wide
+ %partial.reduce = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mul)
+ %index.next = add nuw i64 %index, 16
+ %cmp = icmp eq i64 %index.next, 16
+ br i1 %cmp, label %end, label %vector.body
+
+end:
+ ret <4 x i32> %acc
+}
+
define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
; CHECK-NODOT-LABEL: usdot_narrow:
; CHECK-NODOT: // %bb.0:
@@ -336,19 +419,12 @@ define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
;
; CHECK-DOT-LABEL: usdot_narrow:
; CHECK-DOT: // %bb.0:
-; CHECK-DOT-NEXT: ushll v1.8h, v1.8b, #0
-; CHECK-DOT-NEXT: sshll v2.8h, v2.8b, #0
-; CHECK-DOT-NEXT: // kill: def $d0 killed $d0 def $q0
-; CHECK-DOT-NEXT: smull v3.4s, v2.4h, v1.4h
-; CHECK-DOT-NEXT: smlal v0.4s, v2.4h, v1.4h
-; CHECK-DOT-NEXT: mov d4, v1.d[1]
-; CHECK-DOT-NEXT: mov d5, v2.d[1]
-; CHECK-DOT-NEXT: smull2 v1.4s, v2.8h, v1.8h
-; CHECK-DOT-NEXT: mov d3, v3.d[1]
-; CHECK-DOT-NEXT: mov d1, v1.d[1]
-; CHECK-DOT-NEXT: add v0.2s, v3.2s, v0.2s
-; CHECK-DOT-NEXT: smlal v0.4s, v5.4h, v4.4h
-; CHECK-DOT-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-DOT-NEXT: movi v3.8b, #128
+; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
+; CHECK-DOT-NEXT: eor v2.8b, v2.8b, v3.8b
+; CHECK-DOT-NEXT: udot v4.2s, v3.8b, v1.8b
+; CHECK-DOT-NEXT: udot v0.2s, v2.8b, v1.8b
+; CHECK-DOT-NEXT: sub v0.2s, v0.2s, v4.2s
; CHECK-DOT-NEXT: ret
;
; CHECK-DOT-I8MM-LABEL: usdot_narrow:
@@ -377,14 +453,12 @@ define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
;
; CHECK-DOT-LABEL: sudot:
; CHECK-DOT: // %bb.0:
-; CHECK-DOT-NEXT: sshll v3.8h, v1.8b, #0
-; CHECK-DOT-NEXT: ushll v4.8h, v2.8b, #0
-; CHECK-DOT-NEXT: sshll2 v1.8h, v1.16b, #0
-; CHECK-DOT-NEXT: ushll2 v2.8h, v2.16b, #0
-; CHECK-DOT-NEXT: smlal v0.4s, v4.4h, v3.4h
-; CHECK-DOT-NEXT: smlal2 v0.4s, v4.8h, v3.8h
-; CHECK-DOT-NEXT: smlal v0.4s, v2.4h, v1.4h
-; CHECK-DOT-NEXT: smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-DOT-NEXT: movi v3.16b, #128
+; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
+; CHECK-DOT-NEXT: eor v1.16b, v1.16b, v3.16b
+; CHECK-DOT-NEXT: udot v4.4s, v3.16b, v2.16b
+; CHECK-DOT-NEXT: udot v0.4s, v1.16b, v2.16b
+; CHECK-DOT-NEXT: sub v0.4s, v0.4s, v4.4s
; CHECK-DOT-NEXT: ret
;
; CHECK-DOT-I8MM-LABEL: sudot:
@@ -403,7 +477,7 @@ define <4 x i32> @sudot_in_loop(ptr %p1, ptr %p2){
; CHECK-NODOT: // %bb.0: // %entry
; CHECK-NODOT-NEXT: movi v1.2d, #0000000000000000
; CHECK-NODOT-NEXT: mov x8, xzr
-; CHECK-NODOT-NEXT: .LBB9_1: // %vector.body
+; CHECK-NODOT-NEXT: .LBB10_1: // %vector.body
; CHECK-NODOT-NEXT: // =>This Inner Loop Header: Depth=1
; CHECK-NODOT-NEXT: ldr q2, [x0, x8]
; CHECK-NODOT-NEXT: ldr q3, [x1, x8]
@@ -418,30 +492,29 @@ define <4 x i32> @sudot_in_loop(ptr %p1, ptr %p2){
; CHECK-NODOT-NEXT: smlal2 v1.4s, v4.8h, v5.8h
; CHECK-NODOT-NEXT: smlal v1.4s, v2.4h, v3.4h
; CHECK-NODOT-NEXT: smlal2 v1.4s, v2.8h, v3.8h
-; CHECK-NODOT-NEXT: b.ne .LBB9_1
+; CHECK-NODOT-NEXT: b.ne .LBB10_1
; CHECK-NODOT-NEXT: // %bb.2: // %end
; CHECK-NODOT-NEXT: ret
;
; CHECK-DOT-LABEL: sudot_in_loop:
; CHECK-DOT: // %bb.0: // %entry
-; CHECK-DOT-NEXT: movi v1.2d, #0000000000000000
+; CHECK-DOT-NEXT: movi v2.2d, #0000000000000000
+; CHECK-DOT-NEXT: movi v1.16b, #128
; CHECK-DOT-NEXT: mov x8, xzr
-; CHECK-DOT-NEXT: .LBB9_1: // %vector.body
+; CHECK-DOT-NEXT: .LBB10_1: // %vector.body
; CHECK-DOT-NEXT: // =>This Inner Loop Header: Depth=1
+; CHECK-DOT-NEXT: ldr q0, [x1, x8]
+; CHECK-DOT-NEXT: movi v3.2d, #0000000000000000
+; CHECK-DOT-NEXT: eor v4.16b, v0.16b, v1.16b
+; CHECK-DOT-NEXT: mov v0.16b, v2.16b
; CHECK-DOT-NEXT: ldr q2, [x0, x8]
-; CHECK-DOT-NEXT: ldr q3, [x1, x8]
-; CHECK-DOT-NEXT: mov v0.16b, v1.16b
; CHECK-DOT-NEXT: add x8, x8, #16
-; CHECK-DOT-NEXT: ushll v4.8h, v2.8b, #0
-; CHECK-DOT-NEXT: sshll v5.8h, v3.8b, #0
-; CHECK-DOT-NEXT: ushll2 v2.8h, v2.16b, #0
-; CHECK-DOT-NEXT: sshll2 v3.8h, v3.16b, #0
+; CHECK-DOT-NEXT: udot v3.4s, v1.16b, v2.16b
; CHECK-DOT-NEXT: cmp x8, #16
-; CHECK-DOT-NEXT: smlal v1.4s, v4.4h, v5.4h
-; CHECK-DOT-NEXT: smlal2 v1.4s, v4.8h, v5.8h
-; CHECK-DOT-NEXT: smlal v1.4s, v2.4h, v3.4h
-; CHECK-DOT-NEXT: smlal2 v1.4s, v2.8h, v3.8h
-; CHECK-DOT-NEXT: b.ne .LBB9_1
+; CHECK-DOT-NEXT: mov v5.16b, v0.16b
+; CHECK-DOT-NEXT: udot v5.4s, v4.16b, v2.16b
+; CHECK-DOT-NEXT: sub v2.4s, v5.4s, v3.4s
+; CHECK-DOT-NEXT: b.ne .LBB10_1
; CHECK-DOT-NEXT: // %bb.2: // %end
; CHECK-DOT-NEXT: ret
;
@@ -449,7 +522,7 @@ define <4 x i32> @sudot_in_loop(ptr %p1, ptr %p2){
; CHECK-DOT-I8MM: // %bb.0: // %entry
; CHECK-DOT-I8MM-NEXT: movi v1.2d, #0000000000000000
; CHECK-DOT-I8MM-NEXT: mov x8, xzr
-; CHECK-DOT-I8MM-NEXT: .LBB9_1: // %vector.body
+; CHECK-DOT-I8MM-NEXT: .LBB10_1: // %vector.body
; CHECK-DOT-I8MM-NEXT: // =>This Inner Loop Header: Depth=1
; CHECK-DOT-I8MM-NEXT: ldr q2, [x0, x8]
; CHECK-DOT-I8MM-NEXT: ldr q3, [x1, x8]
@@ -457,7 +530,7 @@ define <4 x i32> @sudot_in_loop(ptr %p1, ptr %p2){
; CHECK-DOT-I8MM-NEXT: add x8, x8, #16
; CHECK-DOT-I8MM-NEXT: usdot v1.4s, v2.16b, v3.16b
; CHECK-DOT-I8MM-NEXT: cmp x8, #16
-; CHECK-DOT-I8MM-NEXT: b.ne .LBB9_1
+; CHECK-DOT-I8MM-NEXT: b.ne .LBB10_1
; CHECK-DOT-I8MM-NEXT: // %bb.2: // %end
; CHECK-DOT-I8MM-NEXT: ret
entry:
@@ -502,19 +575,12 @@ define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
;
; CHECK-DOT-LABEL: sudot_narrow:
; CHECK-DOT: // %bb.0:
-; CHECK-DOT-NEXT: sshll v1.8h, v1.8b, #0
-; CHECK-DOT-NEXT: ushll v2.8h, v2.8b, #0
-; CHECK-DOT-NEXT: // kill: def $d0 killed $d0 def $q0
-; CHECK-DOT-NEXT: smull v3.4s, v2.4h, v1.4h
-; CHECK-DOT-NEXT: smlal v0.4s, v2.4h, v1.4h
-; CHECK-DOT-NEXT: mov d4, v1.d[1]
-; CHECK-DOT-NEXT: mov d5, v2.d[1]
-; CHECK-DOT-NEXT: smull2 v1.4s, v2.8h, v1.8h
-; CHECK-DOT-NEXT: mov d3, v3.d[1]
-; CHECK-DOT-NEXT: mov d1, v1.d[1]
-; CHECK-DOT-NEXT: add v0.2s, v3.2s, v0.2s
-; CHECK-DOT-NEXT: smlal v0.4s, v5.4h, v4.4h
-; CHECK-DOT-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-DOT-NEXT: movi v3.8b, #128
+; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
+; CHECK-DOT-NEXT: eor v1.8b, v1.8b, v3.8b
+; CHECK-DOT-NEXT: udot v4.2s, v3.8b, v2.8b
+; CHECK-DOT-NEXT: udot v0.2s, v1.8b, v2.8b
+; CHECK-DOT-NEXT: sub v0.2s, v0.2s, v4.2s
; CHECK-DOT-NEXT: ret
;
; CHECK-DOT-I8MM-LABEL: sudot_narrow:
@@ -641,26 +707,15 @@ define <4 x i64> @usdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
;
; CHECK-DOT-LABEL: usdot_8to64:
; CHECK-DOT: // %bb.0: // %entry
-; CHECK-DOT-NEXT: ushll v4.8h, v2.8b, #0
-; CHECK-DOT-NEXT: ushll2 v2.8h, v2.16b, #0
-; CHECK-DOT-NEXT: sshll v5.8h, v3.8b, #0
-; CHECK-DOT-NEXT: sshll2 v3.8h, v3.16b, #0
-; CHECK-DOT-NEXT: ushll v6.4s, v4.4h, #0
-; CHECK-DOT-NEXT: ushll v7.4s, v2.4h, #0
-; CHECK-DOT-NEXT: sshll v16.4s, v5.4h, #0
-; CHECK-DOT-NEXT: sshll v17.4s, v3.4h, #0
-; CHECK-DOT-NEXT: ushll2 v4.4s, v4.8h, #0
-; CHECK-DOT-NEXT: ushll2 v2.4s, v2.8h, #0
-; CHECK-DOT-NEXT: sshll2 v5.4s, v5.8h, #0
-; CHECK-DOT-NEXT: sshll2 v3.4s, v3.8h, #0
-; CHECK-DOT-NEXT: smlal v0.2d, v6.2s, v16.2s
-; CHECK-DOT-NEXT: smlal v1.2d, v7.2s, v17.2s
-; CHECK-DOT-NEXT: smlal2 v0.2d, v6.4s, v16.4s
-; CHECK-DOT-NEXT: smlal2 v1.2d, v7.4s, v17.4s
-; CHECK-DOT-NEXT: smlal v0.2d, v4.2s, v5.2s
-; CHECK-DOT-NEXT: smlal v1.2d, v2.2s, v3.2s
-; CHECK-DOT-NEXT: smlal2 v0.2d, v4.4s, v5.4s
-; CHECK-DOT-NEXT: smlal2 v1.2d, v2.4s, v3.4s
+; CHECK-DOT-NEXT: movi v4.16b, #128
+; CHECK-DOT-NEXT: movi v5.2d, #0000000000000000
+; CHECK-DOT-NEXT: movi v6.2d, #0000000000000000
+; CHECK-DOT-NEXT: eor v3.16b, v3.16b, v4.16b
+; CHECK-DOT-NEXT: udot v6.4s, v4.16b, v2.16b
+; CHECK-DOT-NEXT: udot v5.4s, v3.16b, v2.16b
+; CHECK-DOT-NEXT: sub v2.4s, v5.4s, v6.4s
+; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v2.2s
+; CHECK-DOT-NEXT: saddw2 v0.2d, v0.2d, v2.4s
; CHECK-DOT-NEXT: ret
;
; CHECK-DOT-I8MM-LABEL: usdot_8to64:
@@ -706,26 +761,15 @@ define <4 x i64> @sudot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
;
; CHECK-DOT-LABEL: sudot_8to64:
; CHECK-DOT: // %bb.0: // %entry
-; CHECK-DOT-NEXT: sshll v4.8h, v2.8b, #0
-; CHECK-DOT-NEXT: sshll2 v2.8h, v2.16b, #0
-; CHECK-DOT-NEXT: ushll v5.8h, v3.8b, #0
-; CHECK-DOT-NEXT: ushll2 v3.8h, v3.16b, #0
-; CHECK-DOT-NEXT: sshll v6.4s, v4.4h, #0
-; CHECK-DOT-NEXT: sshll v7.4s, v2.4h, #0
-; CHECK-DOT-NEXT: ushll v16.4s, v5.4h, #0
-; CHECK-DOT-NEXT: ushll v17.4s, v3.4h, #0
-; CHECK-DOT-NEXT: sshll2 v4.4s, v4.8h, #0
-; CHECK-DOT-NEXT: sshll2 v2.4s, v2.8h, #0
-; CHECK-DOT-NEXT: ushll2 v5.4s, v5.8h, #0
-; CHECK-DOT-NEXT: ushll2 v3.4s, v3.8h, #0
-; CHECK-DOT-NEXT: smlal v0.2d, v6.2s, v16.2s
-; CHECK-DOT-NEXT: smlal v1.2d, v7.2s, v17.2s
-; CHECK-DOT-NEXT: smlal2 v0.2d, v6.4s, v16.4s
-; CHECK-DOT-NEXT: smlal2 v1.2d, v7.4s, v17.4s
-; CHECK-DOT-NEXT: smlal v0.2d, v4.2s, v5.2s
-; CHECK-DOT-NEXT: smlal v1.2d, v2.2s, v3.2s
-; CHECK-DOT-NEXT: smlal2 v0.2d, v4.4s, v5.4s
-; CHECK-DOT-NEXT: smlal2 v1.2d, v2.4s, v3.4s
+; CHECK-DOT-NEXT: movi v4.16b, #128
+; CHECK-DOT-NEXT: movi v5.2d, #0000000000000000
+; CHECK-DOT-NEXT: movi v6.2d, #0000000000000000
+; CHECK-DOT-NEXT: eor v2.16b, v2.16b, v4.16b
+; CHECK-DOT-NEXT: udot v6.4s, v4.16b, v3.16b
+; CHECK-DOT-NEXT: udot v5.4s, v2.16b, v3.16b
+; CHECK-DOT-NEXT: sub v2.4s, v5.4s, v6.4s
+; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v2.2s
+; CHECK-DOT-NEXT: saddw2 v0.2d, v0.2d, v2.4s
; CHECK-DOT-NEXT: ret
;
; CHECK-DOT-I8MM-LABEL: sudot_...
[truncated]
|
🐧 Linux x64 Test Results
Failed Tests(click on a test name to see its output) lldb-apilldb-api.python_api/run_locker/TestRunLocker.pyIf these failures are unrelated to your changes (for example tests are broken or flaky at HEAD), please open an issue at https://github.com/llvm/llvm-project/issues and add the |
🪟 Windows x64 Test Results
✅ The build succeeded and all tests passed. |
sdesmalen-arm
left a comment
There was a problem hiding this comment.
LGTM, but can this be generalised to also work for scalable vectors?
I don't see any reason why not. The main reason I did not include it here is that I though almost all CPUs supporting SVE would also support i8mm. |
It indeed seems that all v9 ones have it (and most of the v8 with SVE), so it's probably not worth it. |
Lower PARTIAL_REDUCE_SUMLA on targets without i8mm using udot via
sum(mul (zext %a, sext %b)) =
sum(mul zext %a, zext (xor %b, 0x80)) -
sum(mul 128, zext %a)
using sext(s) = zext(xor(s, 0x80)) - 128
(https://alive2.llvm.org/ce/z/8G8rer)
Alive2 proofs: