[LV] Bundle sub reductions into VPExpressionRecipe #147255
Conversation
|
@llvm/pr-subscribers-vectorizers @llvm/pr-subscribers-backend-arm Author: Sam Tebbs (SamTebbs33) Changes: This PR bundles sub reductions into the VPExpressionRecipe class and adjusts the cost functions to take the negation into account. Patch is 23.85 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147255.diff 14 Files Affected:
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index c43870392361d..3cc0ea01953c3 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1645,8 +1645,10 @@ class TargetTransformInfo {
/// extensions. This is the cost of as:
/// ResTy vecreduce.add(mul (A, B)).
/// ResTy vecreduce.add(mul(ext(Ty A), ext(Ty B)).
+ /// The multiply can optionally be negated, which signifies that it is a sub
+ /// reduction.
LLVM_ABI InstructionCost getMulAccReductionCost(
- bool IsUnsigned, Type *ResTy, VectorType *Ty,
+ bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
/// Calculate the cost of an extended reduction pattern, similar to
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 12f87226c5f57..fd22981a5dbf3 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -960,7 +960,7 @@ class TargetTransformInfoImplBase {
virtual InstructionCost
getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
- TTI::TargetCostKind CostKind) const {
+ bool Negated, TTI::TargetCostKind CostKind) const {
return 1;
}
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index bf958e100f2ac..a9c9fa6d1db0d 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -3116,7 +3116,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
InstructionCost
getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
+ bool Negated,
TTI::TargetCostKind CostKind) const override {
+ if (Negated)
+ return InstructionCost::getInvalid(CostKind);
// Without any native support, this is equivalent to the cost of
// vecreduce.add(mul(ext(Ty A), ext(Ty B))) or
// vecreduce.add(mul(A, B)).
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 3ebd9d487ba04..ba0d070bffe6d 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1274,9 +1274,10 @@ InstructionCost TargetTransformInfo::getExtendedReductionCost(
}
InstructionCost TargetTransformInfo::getMulAccReductionCost(
- bool IsUnsigned, Type *ResTy, VectorType *Ty,
+ bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
TTI::TargetCostKind CostKind) const {
- return TTIImpl->getMulAccReductionCost(IsUnsigned, ResTy, Ty, CostKind);
+ return TTIImpl->getMulAccReductionCost(IsUnsigned, ResTy, Ty, Negated,
+ CostKind);
}
InstructionCost
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 380faa6cf6939..d9a367535baf4 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5316,8 +5316,10 @@ InstructionCost AArch64TTIImpl::getExtendedReductionCost(
InstructionCost
AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
- VectorType *VecTy,
+ VectorType *VecTy, bool Negated,
TTI::TargetCostKind CostKind) const {
+ if (Negated)
+ return InstructionCost::getInvalid(CostKind);
EVT VecVT = TLI->getValueType(DL, VecTy);
EVT ResVT = TLI->getValueType(DL, ResTy);
@@ -5332,7 +5334,8 @@ AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
return LT.first + 2;
}
- return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, VecTy, CostKind);
+ return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, VecTy, Negated,
+ CostKind);
}
InstructionCost
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 9ada70bd7086a..8bb31d2a3dac5 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -447,7 +447,7 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
TTI::TargetCostKind CostKind) const override;
InstructionCost getMulAccReductionCost(
- bool IsUnsigned, Type *ResTy, VectorType *Ty,
+ bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const override;
InstructionCost
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 203fb76d7be86..27eb22b5f9986 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1884,8 +1884,10 @@ InstructionCost ARMTTIImpl::getExtendedReductionCost(
InstructionCost
ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
- VectorType *ValTy,
+ VectorType *ValTy, bool Negated,
TTI::TargetCostKind CostKind) const {
+ if (Negated)
+ return InstructionCost::getInvalid(CostKind);
EVT ValVT = TLI->getValueType(DL, ValTy);
EVT ResVT = TLI->getValueType(DL, ResTy);
@@ -1906,7 +1908,8 @@ ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
return ST->getMVEVectorCostFactor(CostKind) * LT.first;
}
- return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, ValTy, CostKind);
+ return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, ValTy, Negated,
+ CostKind);
}
InstructionCost
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index ca06b9e3cb661..43f47f3e7aa6f 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -299,6 +299,7 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
TTI::TargetCostKind CostKind) const override;
InstructionCost
getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *ValTy,
+ bool Negated,
TTI::TargetCostKind CostKind) const override;
InstructionCost
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 1cfbcf1336620..0adff8d957e98 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5538,7 +5538,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
TTI::CastContextHint::None, CostKind, RedOp);
InstructionCost RedCost = TTI.getMulAccReductionCost(
- IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind);
+ IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, false, CostKind);
if (RedCost.isValid() &&
RedCost < ExtCost * 2 + MulCost + Ext2Cost + BaseCost)
@@ -5583,7 +5583,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
InstructionCost RedCost = TTI.getMulAccReductionCost(
- IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind);
+ IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, false, CostKind);
InstructionCost ExtraExtCost = 0;
if (Op0Ty != LargestOpTy || Op1Ty != LargestOpTy) {
Instruction *ExtraExtOp = (Op0Ty != LargestOpTy) ? Op0 : Op1;
@@ -5602,7 +5602,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
InstructionCost RedCost = TTI.getMulAccReductionCost(
- true, RdxDesc.getRecurrenceType(), VectorTy, CostKind);
+ true, RdxDesc.getRecurrenceType(), VectorTy, false, CostKind);
if (RedCost.isValid() && RedCost < MulCost + BaseCost)
return I == RetI ? RedCost : 0;
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index d460573f5bec6..1bc926db301d8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2757,6 +2757,12 @@ class VPExpressionRecipe : public VPSingleDefRecipe {
/// vector operands, performing a reduction.add on the result, and adding
/// the scalar result to a chain.
MulAccReduction,
+ /// Represent an inloop multiply-accumulate reduction, multiplying the
+ /// extended vector operands, negating the multiplication, performing a
+ /// reduction.add
+ /// on the result, and adding
+ /// the scalar result to a chain.
+ ExtNegatedMulAccReduction,
};
/// Type of the expression.
@@ -2780,6 +2786,11 @@ class VPExpressionRecipe : public VPSingleDefRecipe {
VPWidenRecipe *Mul, VPReductionRecipe *Red)
: VPExpressionRecipe(ExpressionTypes::ExtMulAccReduction,
{Ext0, Ext1, Mul, Red}) {}
+ VPExpressionRecipe(VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1,
+ VPWidenRecipe *Mul, VPWidenRecipe *Sub,
+ VPReductionRecipe *Red)
+ : VPExpressionRecipe(ExpressionTypes::ExtNegatedMulAccReduction,
+ {Ext0, Ext1, Mul, Sub, Red}) {}
~VPExpressionRecipe() override {
for (auto *R : reverse(ExpressionRecipes))
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 318e8171e098d..c20b1920c3791 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2672,13 +2672,17 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
RedTy, SrcVecTy, std::nullopt, Ctx.CostKind);
}
case ExpressionTypes::MulAccReduction:
- return Ctx.TTI.getMulAccReductionCost(false, RedTy, SrcVecTy, Ctx.CostKind);
+ return Ctx.TTI.getMulAccReductionCost(false, RedTy, SrcVecTy, false,
+ Ctx.CostKind);
- case ExpressionTypes::ExtMulAccReduction:
+ case ExpressionTypes::ExtNegatedMulAccReduction:
+ case ExpressionTypes::ExtMulAccReduction: {
+ bool Negated = ExpressionType == ExpressionTypes::ExtNegatedMulAccReduction;
return Ctx.TTI.getMulAccReductionCost(
cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
Instruction::ZExt,
- RedTy, SrcVecTy, Ctx.CostKind);
+ RedTy, SrcVecTy, Negated, Ctx.CostKind);
+ }
}
llvm_unreachable("Unknown VPExpressionRecipe::ExpressionTypes enum");
}
@@ -2725,6 +2729,31 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
O << ")";
break;
}
+ case ExpressionTypes::ExtNegatedMulAccReduction: {
+ getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
+ O << " + ";
+ O << "reduce."
+ << Instruction::getOpcodeName(
+ RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()))
+ << " (sub (0, mul";
+ auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]);
+ Mul->printFlags(O);
+ O << "(";
+ getOperand(0)->printAsOperand(O, SlotTracker);
+ auto *Ext0 = cast<VPWidenCastRecipe>(ExpressionRecipes[0]);
+ O << " " << Instruction::getOpcodeName(Ext0->getOpcode()) << " to "
+ << *Ext0->getResultType() << "), (";
+ getOperand(1)->printAsOperand(O, SlotTracker);
+ auto *Ext1 = cast<VPWidenCastRecipe>(ExpressionRecipes[1]);
+ O << " " << Instruction::getOpcodeName(Ext1->getOpcode()) << " to "
+ << *Ext1->getResultType() << ")";
+ if (Red->isConditional()) {
+ O << ", ";
+ Red->getCondOp()->printAsOperand(O, SlotTracker);
+ }
+ O << "))";
+ break;
+ }
case ExpressionTypes::MulAccReduction:
case ExpressionTypes::ExtMulAccReduction: {
getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 931d4d42f56e4..a09d2037e97b4 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2908,16 +2908,17 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
// Clamp the range if using multiply-accumulate-reduction is profitable.
auto IsMulAccValidAndClampRange =
- [&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
- VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool {
+ [&](bool IsZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
+ VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt,
+ bool Negated = false) -> bool {
return LoopVectorizationPlanner::getDecisionAndClampRange(
[&](ElementCount VF) {
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
Type *SrcTy =
Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
- InstructionCost MulAccCost =
- Ctx.TTI.getMulAccReductionCost(isZExt, RedTy, SrcVecTy, CostKind);
+ InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
+ IsZExt, RedTy, SrcVecTy, Negated, CostKind);
InstructionCost MulCost = Mul->computeCost(VF, Ctx);
InstructionCost RedCost = Red->computeCost(VF, Ctx);
InstructionCost ExtCost = 0;
@@ -2935,14 +2936,22 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
};
VPValue *VecOp = Red->getVecOp();
+ VPValue *Mul = nullptr;
+ VPValue *Sub = nullptr;
VPValue *A, *B;
+ // Sub reductions will have a sub between the add reduction and vec op.
+ if (match(VecOp,
+ m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(Mul))))
+ Sub = VecOp;
+ else
+ Mul = VecOp;
// Try to match reduce.add(mul(...)).
- if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
+ if (match(Mul, m_Mul(m_VPValue(A), m_VPValue(B)))) {
auto *RecipeA =
dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
auto *RecipeB =
dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
- auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
+ auto *MulR = cast<VPWidenRecipe>(Mul->getDefiningRecipe());
// Match reduce.add(mul(ext, ext)).
if (RecipeA && RecipeB &&
@@ -2951,12 +2960,16 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
Instruction::CastOps::ZExt,
- Mul, RecipeA, RecipeB, nullptr)) {
- return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red);
+ MulR, RecipeA, RecipeB, nullptr, Sub)) {
+ if (Sub)
+ return new VPExpressionRecipe(
+ RecipeA, RecipeB, MulR,
+ cast<VPWidenRecipe>(Sub->getDefiningRecipe()), Red);
+ return new VPExpressionRecipe(RecipeA, RecipeB, MulR, Red);
}
// Match reduce.add(mul).
- if (IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr))
- return new VPExpressionRecipe(Mul, Red);
+ if (IsMulAccValidAndClampRange(true, MulR, nullptr, nullptr, nullptr, Sub))
+ return new VPExpressionRecipe(MulR, Red);
}
// Match reduce.add(ext(mul(ext(A), ext(B)))).
// All extend recipes must have same opcode or A == B
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index b2fced47b9527..7953aec48c8b0 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1401,8 +1401,8 @@ static void analyzeCostOfVecReduction(const IntrinsicInst &II,
TTI::CastContextHint::None, CostKind, RedOp);
CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
- CostAfterReduction =
- TTI.getMulAccReductionCost(IsUnsigned, II.getType(), ExtType, CostKind);
+ CostAfterReduction = TTI.getMulAccReductionCost(IsUnsigned, II.getType(),
+ ExtType, false, CostKind);
return;
}
CostAfterReduction = TTI.getArithmeticReductionCost(ReductionOpc, VecRedTy,
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
index 4af3fa9202c77..8059ac12ecd2e 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
@@ -416,3 +416,146 @@ exit:
%r.0.lcssa = phi i64 [ %rdx.next, %loop ]
ret i64 %r.0.lcssa
}
+
+define i32 @print_mulacc_sub(ptr %a, ptr %b) {
+; CHECK: VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: Live-in vp<%0> = VF
+; CHECK-NEXT: Live-in vp<%1> = VF * UF
+; CHECK-NEXT: Live-in vp<%2> = vector-trip-count
+; CHECK-NEXT: Live-in ir<1024> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<entry>:
+; CHECK-NEXT: Successor(s): scalar.ph, vector.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.ph:
+; CHECK-NEXT: EMIT vp<%3> = reduction-start-vector ir<0>, ir<0>, ir<1>
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT: vector.body:
+; CHECK-NEXT: EMIT vp<%4> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
+; CHECK-NEXT: WIDEN-REDUCTION-PHI ir<%accum> = phi vp<%3>, vp<%8>
+; CHECK-NEXT: vp<%5> = SCALAR-STEPS vp<%4>, ir<1>, vp<%0>
+; CHECK-NEXT: CLONE ir<%gep.a> = getelementptr ir<%a>, vp<%5>
+; CHECK-NEXT: vp<%6> = vector-pointer ir<%gep.a>
+; CHECK-NEXT: WIDEN ir<%load.a> = load vp<%6>
+; CHECK-NEXT: CLONE ir<%gep.b> = getelementptr ir<%b>, vp<%5>
+; CHECK-NEXT: vp<%7> = vector-pointer ir<%gep.b>
+; CHECK-NEXT: WIDEN ir<%load.b> = load vp<%7>
+; CHECK-NEXT: EXPRESSION vp<%8> = ir<%accum> + reduce.add (sub (0, mul (ir<%load.b> zext to i32), (ir<%load.a> zext to i32)))
+; CHECK-NEXT: EMIT vp<%index.next> = add nuw vp<%4>, vp<%1>
+; CHECK-NEXT: EMIT branch-on-count vp<%index.next>, vp<%2>
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+; CHECK-NEXT: Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT: EMIT vp<%10> = compute-reduction-result ir<%accum>, vp<%8>
+; CHECK-NEXT: EMIT vp<%cmp.n> = icmp eq ir<1024>, vp<%2>
+; CHECK-NEXT: EMIT branch-on-cond vp<%cmp.n>
+; CHECK-NEXT: Successor(s): ir-bb<for.exit>, scalar.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.exit>:
+; CHECK-NEXT: IR %add.lcssa = phi i32 [ %add, %for.body ] (extra operand: vp<%10> from middle.block)
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: scalar.ph:
+; CHECK-NEXT: EMIT-SCALAR vp<%bc.resume.val> = phi [ vp<%2>, middle.block ], [ ir<0>, ir-bb<entry> ]
+; CHECK-NEXT: EMIT-SCALAR vp<%bc.merge.rdx> = phi [ vp<%10>, middle.block ], [ ir<0>, ir-bb<entry> ]
+; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body>:
+; CHECK-NEXT: IR %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ] (extra operand: vp<%bc.resume.val> from scalar.ph)
+; CHECK-NEXT: IR %accum = phi i32 [ 0, %entry ], [ %add, %for.body ] (extra operand: vp<%bc.merge.rdx> from scalar.ph)
+; CHECK-NEXT: IR %gep.a = getelementptr i8, ptr %a, i64 %iv
+; CHECK-NEXT: IR %load.a = load i8, ptr %gep.a, align 1
+; CHECK-NEXT: IR %ext.a = zext i8 %load.a to i32
+; CHECK-NEXT: IR %gep.b = getelementptr i8, ptr %b, i64 %iv
+; CHECK-NEXT: IR %load.b = load i8, ptr %gep.b, align 1
+; CHECK-NEXT: IR %ext.b = zext i8 %load.b to i32
+; CHECK-NEXT: IR %mul = mul i32 %ext.b, %ext.a
+; CHECK-NEXT: IR %add = sub i32 %accum, %mul
+; CHECK-NEXT: IR %iv.next = add i64 %iv, 1
+; CHECK-NEXT: IR %exitcond.not = icmp eq i64 %iv.next, 1024
+; CHECK-NEXT: No successors
+; CH...
[truncated]
|
| O << ")"; | ||
| break; | ||
| } | ||
| case ExpressionTypes::ExtNegatedMulAccReduction: { |
There was a problem hiding this comment.
Is there a way to commonise this with the ExtMulAccReduction case if the only difference is a negate?
There was a problem hiding this comment.
That was my initial approach but it required checking the number of operands to know if there was a sub or not, and I was asked to create an expression type to not rely on operand ordering being stable.
There was a problem hiding this comment.
I think you still could re-use the code for printing, by just checking the expression type to decide whether to print the sub or not.
There was a problem hiding this comment.
Not sure if you've seen the comment above?
There was a problem hiding this comment.
I did miss this, thanks. It should be irrelevant now that that expression type is gone.
| CostAfterReduction = | ||
| TTI.getMulAccReductionCost(IsUnsigned, II.getType(), ExtType, CostKind); | ||
| CostAfterReduction = TTI.getMulAccReductionCost(IsUnsigned, II.getType(), | ||
| ExtType, false, CostKind); |
There was a problem hiding this comment.
nit: Probably better written as /*Negated=*/false
| /// reduction. | ||
| LLVM_ABI InstructionCost getMulAccReductionCost( | ||
| bool IsUnsigned, Type *ResTy, VectorType *Ty, | ||
| bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated, |
There was a problem hiding this comment.
Is it worth keeping the booleans together, i.e. next to IsUnsigned?
| getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty, | ||
| bool Negated, | ||
| TTI::TargetCostKind CostKind) const override { | ||
| if (Negated) |
There was a problem hiding this comment.
Why can't we add a cost for this?
There was a problem hiding this comment.
Thanks, I've added a cost for the sub.
|
|
||
| InstructionCost RedCost = TTI.getMulAccReductionCost( | ||
| IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind); | ||
| IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, false, CostKind); |
There was a problem hiding this comment.
nit: /*Negated=*/false and same for other below.
| /// Represent an inloop multiply-accumulate reduction, multiplying the | ||
| /// extended vector operands, negating the multiplication, performing a | ||
| /// reduction.add | ||
| /// on the result, and adding |
There was a problem hiding this comment.
Formatting of the comment looks a bit odd - can you fix it?
This PR bundles partial reductions inside the VPExpressionRecipe class. Depends on llvm#147255 .
39f3dab to
0b93c24
Compare
b7c9820 to
fa30b51
Compare
|
Ping |
This PR allows the loop vectorizer to handle in-loop sub reductions by forming a normal in-loop add reduction with a negated input. Stacked PRs: 1. -> llvm/llvm-project#147026 2. llvm/llvm-project#147255 3. llvm/llvm-project#147302 4. llvm/llvm-project#147513
There was a problem hiding this comment.
nit: The wording makes it seem like the optional negation only applies to the second form.
fa30b51 to
1d7cb25
Compare
There was a problem hiding this comment.
| for.body: ; preds = %for.body, %entry | |
| loop: |
nit: consistency with other functions in file
There was a problem hiding this comment.
| for.exit: ; preds = %for.body | |
| exit: |
nit: consistency with other functions in file
There was a problem hiding this comment.
Can you also add a test that checks the generated code? IIUC there should also be changes in the costing/vectorization factors we choose, right?
There was a problem hiding this comment.
Ah, the codegen changes will be covered by the existing tests, just curious if it would be possible to add a test that benefits from the cost changes?
There was a problem hiding this comment.
I've tried but haven't been able to come up with a test that is different in the VF chosen without these changes. I reckon that it will be easier once perhaps the AArch64 or ARM getMulAccReductionCost functions accept the sub version.
| O << ")"; | ||
| break; | ||
| } | ||
| case ExpressionTypes::ExtNegatedMulAccReduction: { |
There was a problem hiding this comment.
Not sure if you've seen the comment above?
|
(not sure why, but it looks like the precommit tests on Linux/Windows have not been triggered for some reason, but the libc++ ones have; maybe this can be solved by updating to the latest main again?) |
1d7cb25 to
229331e
Compare
|
Thanks for the review @fhahn , I'll have a look at it tomorrow. I've just rebased the patch on top of main after the sub reduction patch was merged. This involved removing the negated expression type since a (non-chained) sub reduction is now represented with a subtraction, rather than an addition with a negated input. |
There was a problem hiding this comment.
nit:
| /// ResTy vecreduce.add/sub(mul (A, B)). | |
| /// ResTy vecreduce.add/sub(mul(A, B)). |
There was a problem hiding this comment.
IsNegated no longer exists.
please also add an assert that RedOpcode is either an add or a sub.
There was a problem hiding this comment.
Why is this returning an invalid cost, rather than adding the cost of a negation of the operand?
There was a problem hiding this comment.
I was going with the most conservative approach at first, but I've now allowed subs here (but made sure that it's an add in the UDOT case below). We don't need to consider a negation of the operand since this function isn't used for the chained add+sub case at the moment.
There was a problem hiding this comment.
This looks like a partially NFC change, and I'd prefer the use of VecOp as it was before this change. The reason for this is that in the case that VecOp is not a multiply, Mul is still defined (to the value of VecOp), which I don't think is right.
There was a problem hiding this comment.
Agreed, I think this is left over from when we were checking for a negation. Done.
|
It looks like one of the vectorizer tests is failing in precommit and possibly needs updating? |
This PR bundles sub reductions into the VPExpressionRecipe class and adjusts the cost functions to take the negation into account.
7d40358 to
33236a3
Compare
Yeah it needed a rebase, done. |
fhahn
left a comment
There was a problem hiding this comment.
LGTM with a few inline comments remaining, thanks!
| /// ResTy vecreduce.add/sub(mul (A, B)). | ||
| /// ResTy vecreduce.add/sub(mul(ext(Ty A), ext(Ty B)). |
There was a problem hiding this comment.
| /// ResTy vecreduce.add/sub(mul (A, B)). | |
| /// ResTy vecreduce.add/sub(mul(ext(Ty A), ext(Ty B)). | |
| /// * ResTy vecreduce.add/sub(mul (A, B)) or, | |
| /// * ResTy vecreduce.add/sub(mul(ext(Ty A), ext(Ty B)). |
| VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt, | ||
| unsigned Opcode) -> bool { |
There was a problem hiding this comment.
| VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt, | |
| unsigned Opcode) -> bool { | |
| VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt | |
| ) -> bool { |
Can we just use the captured Opcode?
| CostAfterReduction = TTI.getMulAccReductionCost( | ||
| IsUnsigned, ReductionOpc, II.getType(), ExtType, CostKind); |
There was a problem hiding this comment.
it would be nice to have a test for this, but not sure if that's possible.
There was a problem hiding this comment.
I've been trying to make a test but I don't think this code is ever reached. The RedOp && match(RedOp, m_ZExtOrSExt(m_Value())) check above fully (AFAIK) encompasses this check, so that code path is always followed instead. If I move this if-statement block above the one above, then the compiler fails the assertion at Type.cpp:805. This happens on main as well.
Thank you! |
sdesmalen-arm
left a comment
There was a problem hiding this comment.
LGTM with nits addressed.
| auto *RecipeB = | ||
| dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe()); | ||
| auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe()); | ||
| auto *MulR = cast<VPWidenRecipe>(VecOp->getDefiningRecipe()); |
There was a problem hiding this comment.
This rename is NFC, maybe remove it from this PR?
| // Clamp the range if using multiply-accumulate-reduction is profitable. | ||
| auto IsMulAccValidAndClampRange = | ||
| [&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0, | ||
| [&](bool IsZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0, |
There was a problem hiding this comment.
this rename is NFC, maybe remove it from this PR?
This PR bundles sub reductions into the VPExpressionRecipe class and adjusts the cost functions to take the negation into account. Stacked PRs: 1. llvm/llvm-project#147026 2. -> llvm/llvm-project#147255 3. llvm/llvm-project#147302 4. llvm/llvm-project#147513
This PR bundles partial reductions inside the VPExpressionRecipe class. Depends on llvm#147255 .
This PR bundles partial reductions inside the VPExpressionRecipe class. Depends on llvm#147255 .
This PR bundles sub reductions into the VPExpressionRecipe class and adjusts the cost functions to take the negation into account.
Stacked PRs: