Skip to content

[LLVM][CodeGen][SVE] Lower bfloat ISD::VECREDUCE_SEQ_FADD operations.#199732

Open
paulwalker-arm wants to merge 1 commit into
llvm:mainfrom
paulwalker-arm:sve-bfloat-reductions-ordered
Open

[LLVM][CodeGen][SVE] Lower bfloat ISD::VECREDUCE_SEQ_FADD operations.#199732
paulwalker-arm wants to merge 1 commit into
llvm:mainfrom
paulwalker-arm:sve-bfloat-reductions-ordered

Conversation

@paulwalker-arm
Copy link
Copy Markdown
Contributor

No description provided.

@llvmorg-github-actions
Copy link
Copy Markdown

llvmorg-github-actions Bot commented May 26, 2026

@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-aarch64

Author: Paul Walker (paulwalker-arm)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/199732.diff

4 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+22)
  • (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+14-2)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+7)
  • (added) llvm/test/CodeGen/AArch64/sve-bf16-reductions-ordered.ll (+49)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index b2147d37031ab..1f0b39cd85e1e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -191,6 +191,8 @@ class VectorLegalizer {
   /// rounding of the result does not affect its value.
   void PromoteFloatVECREDUCE(SDNode *Node, SmallVectorImpl<SDValue> &Results,
                              bool NonArithmetic);
+  void PromoteFloatVECREDUCE_SEQ(SDNode *Node,
+                                 SmallVectorImpl<SDValue> &Results);
 
   void PromoteVECTOR_COMPRESS(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
@@ -722,6 +724,23 @@ void VectorLegalizer::PromoteFloatVECREDUCE(SDNode *Node,
   Results.push_back(Res);
 }
 
+void VectorLegalizer::PromoteFloatVECREDUCE_SEQ(
+    SDNode *Node, SmallVectorImpl<SDValue> &Results) {
+  MVT OrigVecVT = Node->getOperand(1).getSimpleValueType();
+  assert(OrigVecVT.isFloatingPoint() && "Expected floating point reduction!");
+  MVT VecVT = TLI.getTypeToPromoteTo(Node->getOpcode(), OrigVecVT);
+  MVT EltVT = VecVT.getVectorElementType();
+
+  SDLoc DL(Node);
+  SDValue EltOp = DAG.getNode(ISD::FP_EXTEND, DL, EltVT, Node->getOperand(0));
+  SDValue VecOp = DAG.getNode(ISD::FP_EXTEND, DL, VecVT, Node->getOperand(1));
+  SDValue Rdx =
+      DAG.getNode(Node->getOpcode(), DL, EltVT, EltOp, VecOp, Node->getFlags());
+  SDValue Res = DAG.getNode(ISD::FP_ROUND, DL, Node->getValueType(0), Rdx,
+                            DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
+  Results.push_back(Res);
+}
+
 void VectorLegalizer::PromoteVECTOR_COMPRESS(
     SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   SDLoc DL(Node);
@@ -790,6 +809,9 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   case ISD::VECREDUCE_FMINIMUM:
     PromoteFloatVECREDUCE(Node, Results, /*NonArithmetic=*/true);
     return;
+  case ISD::VECREDUCE_SEQ_FADD:
+    PromoteFloatVECREDUCE_SEQ(Node, Results);
+    return;
   case ISD::VECTOR_COMPRESS:
     PromoteVECTOR_COMPRESS(Node, Results);
     return;
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 48fe1eed9093f..a365c346d7cfc 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -12659,7 +12659,7 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
 }
 
 SDValue TargetLowering::expandVecReduceSeq(SDNode *Node, SelectionDAG &DAG) const {
-  SDLoc dl(Node);
+  SDLoc DL(Node);
   SDValue AccOp = Node->getOperand(0);
   SDValue VecOp = Node->getOperand(1);
   SDNodeFlags Flags = Node->getFlags();
@@ -12667,6 +12667,18 @@ SDValue TargetLowering::expandVecReduceSeq(SDNode *Node, SelectionDAG &DAG) cons
   EVT VT = VecOp.getValueType();
   EVT EltVT = VT.getVectorElementType();
 
+  if (VT.getVectorElementCount().isKnownMultipleOf(2)) {
+    auto [LoVecVT, HiVecVT] = DAG.GetSplitDestVTs(VT);
+    if (isOperationLegalOrCustomOrPromote(Node->getOpcode(), LoVecVT) &&
+        isOperationLegalOrCustomOrPromote(Node->getOpcode(), HiVecVT)) {
+      auto [LoVecOp, HiVecOp] = DAG.SplitVector(VecOp, DL, LoVecVT, HiVecVT);
+
+      unsigned Opcode = Node->getOpcode();
+      SDValue ReduceLo = DAG.getNode(Opcode, DL, EltVT, AccOp, LoVecOp, Flags);
+      return DAG.getNode(Opcode, DL, EltVT, ReduceLo, HiVecOp, Flags);
+    }
+  }
+
   if (VT.isScalableVector())
     report_fatal_error(
         "Expanding reductions for scalable vectors is undefined.");
@@ -12680,7 +12692,7 @@ SDValue TargetLowering::expandVecReduceSeq(SDNode *Node, SelectionDAG &DAG) cons
 
   SDValue Res = AccOp;
   for (unsigned i = 0; i < NumElts; i++)
-    Res = DAG.getNode(BaseOpcode, dl, EltVT, Res, Ops[i], Flags);
+    Res = DAG.getNode(BaseOpcode, DL, EltVT, Res, Ops[i], Flags);
 
   return Res;
 }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 08f3d4e0d30ac..22cd8d0db512d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2134,6 +2134,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                     MVT::v2f32, MVT::v4f32, MVT::v2f64})
       setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
 
+    setOperationPromotedToType(ISD::VECREDUCE_SEQ_FADD, MVT::nxv2bf16,
+                               MVT::nxv2f32);
+    setOperationPromotedToType(ISD::VECREDUCE_SEQ_FADD, MVT::nxv4bf16,
+                               MVT::nxv4f32);
+    setOperationPromotedToType(ISD::VECREDUCE_SEQ_FADD, MVT::nxv8bf16,
+                               MVT::nxv8f32);
+
     // We can lower types that have <vscale x {2|4}> elements to compact.
     for (auto VT :
          {MVT::nxv4i32, MVT::nxv2i64, MVT::nxv2f32, MVT::nxv4f32, MVT::nxv2f64})
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-reductions-ordered.ll b/llvm/test/CodeGen/AArch64/sve-bf16-reductions-ordered.ll
new file mode 100644
index 0000000000000..e6bca48283760
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-reductions-ordered.ll
@@ -0,0 +1,49 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define bfloat @fadd_ordered_nxv2bf16(bfloat %a, <vscale x 2 x bfloat> %b) {
+; CHECK-LABEL: fadd_ordered_nxv2bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $d0
+; CHECK-NEXT:    lsl z1.s, z1.s, #16
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    shll v0.4s, v0.4h, #16
+; CHECK-NEXT:    fadda s0, p0, s0, z1.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fadd(bfloat %a, <vscale x 2 x bfloat> %b)
+  ret bfloat %res
+}
+
+define bfloat @fadd_ordered_nxv4bf16(bfloat %a, <vscale x 4 x bfloat> %b) {
+; CHECK-LABEL: fadd_ordered_nxv4bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $d0
+; CHECK-NEXT:    lsl z1.s, z1.s, #16
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    shll v0.4s, v0.4h, #16
+; CHECK-NEXT:    fadda s0, p0, s0, z1.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fadd(bfloat %a, <vscale x 4 x bfloat> %b)
+  ret bfloat %res
+}
+
+define bfloat @fadd_ordered_nxv8bf16(bfloat %a, <vscale x 8 x bfloat> %b) {
+; CHECK-LABEL: fadd_ordered_nxv8bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $d0
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    shll v0.4s, v0.4h, #16
+; CHECK-NEXT:    zip1 z3.h, z2.h, z1.h
+; CHECK-NEXT:    zip2 z1.h, z2.h, z1.h
+; CHECK-NEXT:    fadda s0, p0, s0, z3.s
+; CHECK-NEXT:    fadda s0, p0, s0, z1.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fadd(bfloat %a, <vscale x 8 x bfloat> %b)
+  ret bfloat %res
+}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

backend:AArch64 llvm:SelectionDAG SelectionDAGISel as well

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant