Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ class VectorLegalizer {
/// rounding of the result does not affect its value.
void PromoteFloatVECREDUCE(SDNode *Node, SmallVectorImpl<SDValue> &Results,
bool NonArithmetic);
void PromoteFloatVECREDUCE_SEQ(SDNode *Node,
SmallVectorImpl<SDValue> &Results);

void PromoteVECTOR_COMPRESS(SDNode *Node, SmallVectorImpl<SDValue> &Results);

Expand Down Expand Up @@ -722,6 +724,23 @@ void VectorLegalizer::PromoteFloatVECREDUCE(SDNode *Node,
Results.push_back(Res);
}

/// Promote a sequential floating-point vector reduction to the wider type the
/// target registered for this opcode and source vector type.
///
/// Operand 0 (the scalar start value) and operand 1 (the vector) are each
/// FP_EXTENDed, the reduction is re-emitted with the node's own opcode and
/// flags on the promoted types, and the scalar result is FP_ROUNDed back to
/// the node's original result type before being appended to \p Results.
void VectorLegalizer::PromoteFloatVECREDUCE_SEQ(
    SDNode *Node, SmallVectorImpl<SDValue> &Results) {
  const unsigned Opcode = Node->getOpcode();
  SDValue Acc = Node->getOperand(0);
  SDValue Vec = Node->getOperand(1);

  // The promoted type is looked up by the *vector* operand's type.
  MVT SrcVecVT = Vec.getSimpleValueType();
  assert(SrcVecVT.isFloatingPoint() && "Expected floating point reduction!");
  MVT WideVecVT = TLI.getTypeToPromoteTo(Opcode, SrcVecVT);
  MVT WideEltVT = WideVecVT.getVectorElementType();

  SDLoc DL(Node);

  // Widen the start value to the promoted element type and the vector to the
  // promoted vector type, then redo the reduction on the wide operands.
  SDValue WideAcc = DAG.getNode(ISD::FP_EXTEND, DL, WideEltVT, Acc);
  SDValue WideVec = DAG.getNode(ISD::FP_EXTEND, DL, WideVecVT, Vec);
  SDValue WideRdx = DAG.getNode(Opcode, DL, WideEltVT, WideAcc, WideVec,
                                Node->getFlags());

  // Narrow the result back to the original result type. The second FP_ROUND
  // operand is a target constant 0.
  SDValue RoundFlag = DAG.getIntPtrConstant(0, DL, /*isTarget=*/true);
  Results.push_back(DAG.getNode(ISD::FP_ROUND, DL, Node->getValueType(0),
                                WideRdx, RoundFlag));
}

void VectorLegalizer::PromoteVECTOR_COMPRESS(
SDNode *Node, SmallVectorImpl<SDValue> &Results) {
SDLoc DL(Node);
Expand Down Expand Up @@ -790,6 +809,9 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
case ISD::VECREDUCE_FMINIMUM:
PromoteFloatVECREDUCE(Node, Results, /*NonArithmetic=*/true);
return;
case ISD::VECREDUCE_SEQ_FADD:
PromoteFloatVECREDUCE_SEQ(Node, Results);
return;
case ISD::VECTOR_COMPRESS:
PromoteVECTOR_COMPRESS(Node, Results);
return;
Expand Down
16 changes: 14 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12659,14 +12659,26 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
}

SDValue TargetLowering::expandVecReduceSeq(SDNode *Node, SelectionDAG &DAG) const {
SDLoc dl(Node);
SDLoc DL(Node);
SDValue AccOp = Node->getOperand(0);
SDValue VecOp = Node->getOperand(1);
SDNodeFlags Flags = Node->getFlags();

EVT VT = VecOp.getValueType();
EVT EltVT = VT.getVectorElementType();

if (VT.getVectorElementCount().isKnownMultipleOf(2)) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe add a little comment here to describe the intent?

I was actually surprised it went the 'expand' route here, rather than through the SplitVector legalisation step. Do you know why that is?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's because the promotion is happening during operation legalisation rather than type legalisation. Essentially LegalizeVectorOps sometimes creates illegal types but they are legalised during the same phase (in this instance by expanding the promoted operation) so the illegal types do not escape.

auto [LoVecVT, HiVecVT] = DAG.GetSplitDestVTs(VT);
if (isOperationLegalOrCustomOrPromote(Node->getOpcode(), LoVecVT) &&
isOperationLegalOrCustomOrPromote(Node->getOpcode(), HiVecVT)) {
auto [LoVecOp, HiVecOp] = DAG.SplitVector(VecOp, DL, LoVecVT, HiVecVT);

unsigned Opcode = Node->getOpcode();
SDValue ReduceLo = DAG.getNode(Opcode, DL, EltVT, AccOp, LoVecOp, Flags);
return DAG.getNode(Opcode, DL, EltVT, ReduceLo, HiVecOp, Flags);
}
}

if (VT.isScalableVector())
report_fatal_error(
"Expanding reductions for scalable vectors is undefined.");
Expand All @@ -12680,7 +12692,7 @@ SDValue TargetLowering::expandVecReduceSeq(SDNode *Node, SelectionDAG &DAG) cons

SDValue Res = AccOp;
for (unsigned i = 0; i < NumElts; i++)
Res = DAG.getNode(BaseOpcode, dl, EltVT, Res, Ops[i], Flags);
Res = DAG.getNode(BaseOpcode, DL, EltVT, Res, Ops[i], Flags);

return Res;
}
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2134,6 +2134,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
MVT::v2f32, MVT::v4f32, MVT::v2f64})
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);

setOperationPromotedToType(ISD::VECREDUCE_SEQ_FADD, MVT::nxv2bf16,
MVT::nxv2f32);
setOperationPromotedToType(ISD::VECREDUCE_SEQ_FADD, MVT::nxv4bf16,
MVT::nxv4f32);
setOperationPromotedToType(ISD::VECREDUCE_SEQ_FADD, MVT::nxv8bf16,
MVT::nxv8f32);

// We can lower types that have <vscale x {2|4}> elements to compact.
for (auto VT :
{MVT::nxv4i32, MVT::nxv2i64, MVT::nxv2f32, MVT::nxv4f32, MVT::nxv2f64})
Expand Down
49 changes: 49 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-bf16-reductions-ordered.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s

target triple = "aarch64-unknown-linux-gnu"

; fadd reduction of <vscale x 2 x bfloat> %b with start value %a (the call
; carries no fast-math flags). The assertions below expect both inputs to be
; widened with a 16-bit left shift, a single fadda on .s elements governed by
; a ptrue p0.d predicate, and a bfcvt of the result back to bfloat.
define bfloat @fadd_ordered_nxv2bf16(bfloat %a, <vscale x 2 x bfloat> %b) {
; CHECK-LABEL: fadd_ordered_nxv2bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: // kill: def $h0 killed $h0 def $d0
; CHECK-NEXT: lsl z1.s, z1.s, #16
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: shll v0.4s, v0.4h, #16
; CHECK-NEXT: fadda s0, p0, s0, z1.s
; CHECK-NEXT: bfcvt h0, s0
; CHECK-NEXT: ret
%res = call bfloat @llvm.vector.reduce.fadd(bfloat %a, <vscale x 2 x bfloat> %b)
ret bfloat %res
}

; fadd reduction of <vscale x 4 x bfloat> %b with start value %a (the call
; carries no fast-math flags). The assertions below expect both inputs to be
; widened with a 16-bit left shift, a single fadda on .s elements governed by
; a ptrue p0.s predicate, and a bfcvt of the result back to bfloat.
define bfloat @fadd_ordered_nxv4bf16(bfloat %a, <vscale x 4 x bfloat> %b) {
; CHECK-LABEL: fadd_ordered_nxv4bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: // kill: def $h0 killed $h0 def $d0
; CHECK-NEXT: lsl z1.s, z1.s, #16
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: shll v0.4s, v0.4h, #16
; CHECK-NEXT: fadda s0, p0, s0, z1.s
; CHECK-NEXT: bfcvt h0, s0
; CHECK-NEXT: ret
%res = call bfloat @llvm.vector.reduce.fadd(bfloat %a, <vscale x 4 x bfloat> %b)
ret bfloat %res
}

; fadd reduction of <vscale x 8 x bfloat> %b with start value %a (the call
; carries no fast-math flags). The assertions below expect the vector to be
; separated into two halves by interleaving with zero (zip1/zip2 on .h
; elements), two chained fadda instructions on .s elements (zip1 half first,
; then zip2 half), and a bfcvt of the result back to bfloat.
define bfloat @fadd_ordered_nxv8bf16(bfloat %a, <vscale x 8 x bfloat> %b) {
; CHECK-LABEL: fadd_ordered_nxv8bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: movi v2.2d, #0000000000000000
; CHECK-NEXT: // kill: def $h0 killed $h0 def $d0
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: shll v0.4s, v0.4h, #16
; CHECK-NEXT: zip1 z3.h, z2.h, z1.h
; CHECK-NEXT: zip2 z1.h, z2.h, z1.h
; CHECK-NEXT: fadda s0, p0, s0, z3.s
; CHECK-NEXT: fadda s0, p0, s0, z1.s
; CHECK-NEXT: bfcvt h0, s0
; CHECK-NEXT: ret
%res = call bfloat @llvm.vector.reduce.fadd(bfloat %a, <vscale x 8 x bfloat> %b)
ret bfloat %res
}