[LLVM][CodeGen][SVE] Lower bfloat ISD::VECREDUCE_SEQ_FADD operations. #199732
Open
paulwalker-arm wants to merge 1 commit into
Open
[LLVM][CodeGen][SVE] Lower bfloat ISD::VECREDUCE_SEQ_FADD operations. #199732 — paulwalker-arm wants to merge 1 commit into
paulwalker-arm wants to merge 1 commit into
Conversation
|
@llvm/pr-subscribers-llvm-selectiondag @llvm/pr-subscribers-backend-aarch64

Author: Paul Walker (paulwalker-arm)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/199732.diff

4 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index b2147d37031ab..1f0b39cd85e1e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -191,6 +191,8 @@ class VectorLegalizer {
/// rounding of the result does not affect its value.
void PromoteFloatVECREDUCE(SDNode *Node, SmallVectorImpl<SDValue> &Results,
bool NonArithmetic);
+ void PromoteFloatVECREDUCE_SEQ(SDNode *Node,
+ SmallVectorImpl<SDValue> &Results);
void PromoteVECTOR_COMPRESS(SDNode *Node, SmallVectorImpl<SDValue> &Results);
@@ -722,6 +724,23 @@ void VectorLegalizer::PromoteFloatVECREDUCE(SDNode *Node,
Results.push_back(Res);
}
+void VectorLegalizer::PromoteFloatVECREDUCE_SEQ( // Promote a sequential FP reduction node to a wider FP type.
+ SDNode *Node, SmallVectorImpl<SDValue> &Results) { // The replacement value is appended to Results.
+ MVT OrigVecVT = Node->getOperand(1).getSimpleValueType(); // Operand 1 is the vector being reduced.
+ assert(OrigVecVT.isFloatingPoint() && "Expected floating point reduction!");
+ MVT VecVT = TLI.getTypeToPromoteTo(Node->getOpcode(), OrigVecVT); // Promoted vector type registered by the target.
+ MVT EltVT = VecVT.getVectorElementType(); // Scalar type used for the promoted reduction's result.
+
+ SDLoc DL(Node);
+ SDValue EltOp = DAG.getNode(ISD::FP_EXTEND, DL, EltVT, Node->getOperand(0)); // Extend operand 0 to the promoted element type.
+ SDValue VecOp = DAG.getNode(ISD::FP_EXTEND, DL, VecVT, Node->getOperand(1)); // Extend the vector operand to the promoted vector type.
+ SDValue Rdx =
+ DAG.getNode(Node->getOpcode(), DL, EltVT, EltOp, VecOp, Node->getFlags()); // Same opcode and flags, now on the promoted types.
+ SDValue Res = DAG.getNode(ISD::FP_ROUND, DL, Node->getValueType(0), Rdx, // Round the result back to the node's original value type.
+ DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
+ Results.push_back(Res);
+}
+
void VectorLegalizer::PromoteVECTOR_COMPRESS(
SDNode *Node, SmallVectorImpl<SDValue> &Results) {
SDLoc DL(Node);
@@ -790,6 +809,9 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
case ISD::VECREDUCE_FMINIMUM:
PromoteFloatVECREDUCE(Node, Results, /*NonArithmetic=*/true);
return;
+ case ISD::VECREDUCE_SEQ_FADD:
+ PromoteFloatVECREDUCE_SEQ(Node, Results);
+ return;
case ISD::VECTOR_COMPRESS:
PromoteVECTOR_COMPRESS(Node, Results);
return;
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 48fe1eed9093f..a365c346d7cfc 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -12659,7 +12659,7 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
}
SDValue TargetLowering::expandVecReduceSeq(SDNode *Node, SelectionDAG &DAG) const {
- SDLoc dl(Node);
+ SDLoc DL(Node);
SDValue AccOp = Node->getOperand(0);
SDValue VecOp = Node->getOperand(1);
SDNodeFlags Flags = Node->getFlags();
@@ -12667,6 +12667,18 @@ SDValue TargetLowering::expandVecReduceSeq(SDNode *Node, SelectionDAG &DAG) cons
EVT VT = VecOp.getValueType();
EVT EltVT = VT.getVectorElementType();
+ if (VT.getVectorElementCount().isKnownMultipleOf(2)) {
+ auto [LoVecVT, HiVecVT] = DAG.GetSplitDestVTs(VT);
+ if (isOperationLegalOrCustomOrPromote(Node->getOpcode(), LoVecVT) &&
+ isOperationLegalOrCustomOrPromote(Node->getOpcode(), HiVecVT)) {
+ auto [LoVecOp, HiVecOp] = DAG.SplitVector(VecOp, DL, LoVecVT, HiVecVT);
+
+ unsigned Opcode = Node->getOpcode();
+ SDValue ReduceLo = DAG.getNode(Opcode, DL, EltVT, AccOp, LoVecOp, Flags);
+ return DAG.getNode(Opcode, DL, EltVT, ReduceLo, HiVecOp, Flags);
+ }
+ }
+
if (VT.isScalableVector())
report_fatal_error(
"Expanding reductions for scalable vectors is undefined.");
@@ -12680,7 +12692,7 @@ SDValue TargetLowering::expandVecReduceSeq(SDNode *Node, SelectionDAG &DAG) cons
SDValue Res = AccOp;
for (unsigned i = 0; i < NumElts; i++)
- Res = DAG.getNode(BaseOpcode, dl, EltVT, Res, Ops[i], Flags);
+ Res = DAG.getNode(BaseOpcode, DL, EltVT, Res, Ops[i], Flags);
return Res;
}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 08f3d4e0d30ac..22cd8d0db512d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2134,6 +2134,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
MVT::v2f32, MVT::v4f32, MVT::v2f64})
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
+ setOperationPromotedToType(ISD::VECREDUCE_SEQ_FADD, MVT::nxv2bf16,
+ MVT::nxv2f32);
+ setOperationPromotedToType(ISD::VECREDUCE_SEQ_FADD, MVT::nxv4bf16,
+ MVT::nxv4f32);
+ setOperationPromotedToType(ISD::VECREDUCE_SEQ_FADD, MVT::nxv8bf16,
+ MVT::nxv8f32);
+
// We can lower types that have <vscale x {2|4}> elements to compact.
for (auto VT :
{MVT::nxv4i32, MVT::nxv2i64, MVT::nxv2f32, MVT::nxv4f32, MVT::nxv2f64})
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-reductions-ordered.ll b/llvm/test/CodeGen/AArch64/sve-bf16-reductions-ordered.ll
new file mode 100644
index 0000000000000..e6bca48283760
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-reductions-ordered.ll
@@ -0,0 +1,49 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define bfloat @fadd_ordered_nxv2bf16(bfloat %a, <vscale x 2 x bfloat> %b) { ; Reduces %b with start value %a via llvm.vector.reduce.fadd; the checks expect a single fadda on .s elements followed by bfcvt.
+; CHECK-LABEL: fadd_ordered_nxv2bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: // kill: def $h0 killed $h0 def $d0
+; CHECK-NEXT: lsl z1.s, z1.s, #16
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: shll v0.4s, v0.4h, #16
+; CHECK-NEXT: fadda s0, p0, s0, z1.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fadd(bfloat %a, <vscale x 2 x bfloat> %b)
+ ret bfloat %res
+}
+
+define bfloat @fadd_ordered_nxv4bf16(bfloat %a, <vscale x 4 x bfloat> %b) { ; Reduces %b with start value %a via llvm.vector.reduce.fadd; the checks expect a single fadda under a ptrue p0.s predicate followed by bfcvt.
+; CHECK-LABEL: fadd_ordered_nxv4bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: // kill: def $h0 killed $h0 def $d0
+; CHECK-NEXT: lsl z1.s, z1.s, #16
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: shll v0.4s, v0.4h, #16
+; CHECK-NEXT: fadda s0, p0, s0, z1.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fadd(bfloat %a, <vscale x 4 x bfloat> %b)
+ ret bfloat %res
+}
+
+define bfloat @fadd_ordered_nxv8bf16(bfloat %a, <vscale x 8 x bfloat> %b) { ; Reduces %b with start value %a via llvm.vector.reduce.fadd; the checks expect %b to be zipped into two .s vectors that are reduced by two chained fadda instructions.
+; CHECK-LABEL: fadd_ordered_nxv8bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v2.2d, #0000000000000000
+; CHECK-NEXT: // kill: def $h0 killed $h0 def $d0
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: shll v0.4s, v0.4h, #16
+; CHECK-NEXT: zip1 z3.h, z2.h, z1.h
+; CHECK-NEXT: zip2 z1.h, z2.h, z1.h
+; CHECK-NEXT: fadda s0, p0, s0, z3.s
+; CHECK-NEXT: fadda s0, p0, s0, z1.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fadd(bfloat %a, <vscale x 8 x bfloat> %b)
+ ret bfloat %res
+}
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
No description provided.