[InstCombine] Transform vector.reduce.add and splat into multiplication#161020
[InstCombine] Transform vector.reduce.add and splat into multiplication#161020
vector.reduce.add and splat into multiplication#161020Conversation
…32 %0, 2` Fixes llvm#160066 Whenever we have a vector with all the same elemnts, created with `insertelement` and `shufflevector` and the result type's element number is a power of two and we sum the vector, we have a multiplication by a power of two, which can be replaced with a left shift.
|
This should not be limited to powers of two. You can just emit a multiply and it will get folded to a shift in the power of two case. |
|
Thank you very much for your review @nikic . I am really happy that you have suggested to optimize the non power of two cases. It was fun implementig those too. :)
I am also open to any further potential improvement idea for this patch. |
vector.reduce.add (splat %0, 4) into shl i32 %0, 2vector.reduce.add and splat into multiplication
|
There is one lldb failure on Linux. I think that is just a flaky test case, which isn't caused by this PR. I will retrigger the CI. |
|
@llvm/pr-subscribers-llvm-transforms Author: Gábor Spaits (spaits) ChangesFixes #160066 Whenever we have a vector with all the same elemnts, created with Full diff: https://github.com/llvm/llvm-project/pull/161020.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 6ad493772d170..74c263e86f4a4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -64,6 +64,7 @@
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/KnownFPClass.h"
#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/TypeSize.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
@@ -3761,6 +3762,41 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
return replaceInstUsesWith(CI, Res);
}
}
+
+ // Handle the case where a value is multiplied by a power of two.
+ // For example:
+ // %2 = insertelement <4 x i32> poison, i32 %0, i64 0
+ // %3 = shufflevector <4 x i32> %2, poison, <4 x i32> zeroinitializer
+ // %4 = tail call i32 @llvm.vector.reduce.add.v4i32(%3)
+ // =>
+ // %2 = shl i32 %0, 2
+ assert(Arg->getType()->isVectorTy() &&
+ "The vector.reduce.add intrinsic's argument must be a vector!");
+
+ if (Value *Splat = getSplatValue(Arg)) {
+ // It is only a multiplication if we add the same element over and over.
+ ElementCount ReducedVectorElementCount =
+ static_cast<VectorType *>(Arg->getType())->getElementCount();
+ if (ReducedVectorElementCount.isFixed()) {
+ unsigned VectorSize = ReducedVectorElementCount.getFixedValue();
+ Type *SplatType = Splat->getType();
+ unsigned SplatTypeWidth = SplatType->getIntegerBitWidth();
+ Value *Res;
+ // Power of two is a special case. We can just use a left shif here.
+ if (isPowerOf2_32(VectorSize)) {
+ unsigned Pow2 = Log2_32(VectorSize);
+ Res = Builder.CreateShl(
+ Splat, Constant::getIntegerValue(SplatType,
+ APInt(SplatTypeWidth, Pow2)));
+ return replaceInstUsesWith(CI, Res);
+ }
+ // Otherwise just multiply.
+ Res = Builder.CreateMul(
+ Splat, Constant::getIntegerValue(
+ SplatType, APInt(SplatTypeWidth, VectorSize)));
+ return replaceInstUsesWith(CI, Res);
+ }
+ }
}
[[fallthrough]];
}
diff --git a/llvm/test/Transforms/InstCombine/vector-reductions.ll b/llvm/test/Transforms/InstCombine/vector-reductions.ll
index 10f4aca72dbc7..e071415d2d6c1 100644
--- a/llvm/test/Transforms/InstCombine/vector-reductions.ll
+++ b/llvm/test/Transforms/InstCombine/vector-reductions.ll
@@ -308,3 +308,93 @@ define i32 @diff_of_sums_type_mismatch2(<8 x i32> %v0, <4 x i32> %v1) {
%r = sub i32 %r0, %r1
ret i32 %r
}
+
+define i32 @constant_multiplied_at_0(i32 %0) {
+; CHECK-LABEL: @constant_multiplied_at_0(
+; CHECK-NEXT: [[TMP2:%.*]] = shl i32 [[TMP0:%.*]], 2
+; CHECK-NEXT: ret i32 [[TMP2]]
+;
+ %2 = insertelement <4 x i32> poison, i32 %0, i64 0
+ %3 = shufflevector <4 x i32> %2, <4 x i32> poison, <4 x i32> zeroinitializer
+ %4 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %3)
+ ret i32 %4
+}
+
+define i64 @constant_multiplied_at_0_64bits(i64 %0) {
+; CHECK-LABEL: @constant_multiplied_at_0_64bits(
+; CHECK-NEXT: [[TMP2:%.*]] = shl i64 [[TMP0:%.*]], 2
+; CHECK-NEXT: ret i64 [[TMP2]]
+;
+ %2 = insertelement <4 x i64> poison, i64 %0, i64 0
+ %3 = shufflevector <4 x i64> %2, <4 x i64> poison, <4 x i32> zeroinitializer
+ %4 = tail call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> %3)
+ ret i64 %4
+}
+
+define i32 @constant_multiplied_at_0_two_pow8(i32 %0) {
+; CHECK-LABEL: @constant_multiplied_at_0_two_pow8(
+; CHECK-NEXT: [[TMP2:%.*]] = shl i32 [[TMP0:%.*]], 3
+; CHECK-NEXT: ret i32 [[TMP2]]
+;
+ %2 = insertelement <4 x i32> poison, i32 %0, i64 0
+ %3 = shufflevector <4 x i32> %2, <4 x i32> poison, <8 x i32> zeroinitializer
+ %4 = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %3)
+ ret i32 %4
+}
+
+
+define i32 @constant_multiplied_at_0_two_pow16(i32 %0) {
+; CHECK-LABEL: @constant_multiplied_at_0_two_pow16(
+; CHECK-NEXT: [[TMP2:%.*]] = shl i32 [[TMP0:%.*]], 4
+; CHECK-NEXT: ret i32 [[TMP2]]
+;
+ %2 = insertelement <4 x i32> poison, i32 %0, i64 0
+ %3 = shufflevector <4 x i32> %2, <4 x i32> poison, <16 x i32> zeroinitializer
+ %4 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %3)
+ ret i32 %4
+}
+
+
+define i32 @constant_multiplied_at_1(i32 %0) {
+; CHECK-LABEL: @constant_multiplied_at_1(
+; CHECK-NEXT: [[TMP2:%.*]] = shl i32 [[TMP0:%.*]], 2
+; CHECK-NEXT: ret i32 [[TMP2]]
+;
+ %2 = insertelement <4 x i32> poison, i32 %0, i64 1
+ %3 = shufflevector <4 x i32> %2, <4 x i32> poison,
+ <4 x i32> <i32 1, i32 1, i32 1, i32 1>
+ %4 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %3)
+ ret i32 %4
+}
+
+define i32 @negative_constant_multiplied_at_1(i32 %0) {
+; CHECK-LABEL: @negative_constant_multiplied_at_1(
+; CHECK-NEXT: ret i32 poison
+;
+ %2 = insertelement <4 x i32> poison, i32 %0, i64 1
+ %3 = shufflevector <4 x i32> %2, <4 x i32> poison, <4 x i32> zeroinitializer
+ %4 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %3)
+ ret i32 %4
+}
+
+define i32 @constant_multiplied_non_power_of_2(i32 %0) {
+; CHECK-LABEL: @constant_multiplied_non_power_of_2(
+; CHECK-NEXT: [[TMP2:%.*]] = mul i32 [[TMP0:%.*]], 6
+; CHECK-NEXT: ret i32 [[TMP2]]
+;
+ %2 = insertelement <4 x i32> poison, i32 %0, i64 0
+ %3 = shufflevector <4 x i32> %2, <4 x i32> poison, <6 x i32> zeroinitializer
+ %4 = tail call i32 @llvm.vector.reduce.add.v6i32(<6 x i32> %3)
+ ret i32 %4
+}
+
+define i64 @constant_multiplied_non_power_of_2_i64(i64 %0) {
+; CHECK-LABEL: @constant_multiplied_non_power_of_2_i64(
+; CHECK-NEXT: [[TMP2:%.*]] = mul i64 [[TMP0:%.*]], 6
+; CHECK-NEXT: ret i64 [[TMP2]]
+;
+ %2 = insertelement <4 x i64> poison, i64 %0, i64 0
+ %3 = shufflevector <4 x i64> %2, <4 x i64> poison, <6 x i32> zeroinitializer
+ %4 = tail call i64 @llvm.vector.reduce.add.v6i64(<6 x i64> %3)
+ ret i64 %4
+}
|
|
Thank you very much for your review @XChy . I have addressed your comments. |
|
@zyw-bot mfuzz |
| assert(Arg->getType()->isVectorTy() && | ||
| "The vector.reduce.add intrinsic's argument must be a vector!"); | ||
| ElementCount ReducedVectorElementCount = | ||
| static_cast<VectorType *>(Arg->getType())->getElementCount(); |
There was a problem hiding this comment.
| static_cast<VectorType *>(Arg->getType())->getElementCount(); | |
| cast<VectorType>(Arg->getType())->getElementCount(); |
And remove the assert.
| Value *Res = | ||
| Builder.CreateMul(Splat, ConstantInt::get(SplatType, VectorSize)); | ||
| return replaceInstUsesWith(CI, Res); |
There was a problem hiding this comment.
| Value *Res = | |
| Builder.CreateMul(Splat, ConstantInt::get(SplatType, VectorSize)); | |
| return replaceInstUsesWith(CI, Res); | |
| return BinaryOperator::CreateMul(Splat, ConstantInt::get(SplatType, VectorSize)); |
| ; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <2 x i1> [[TMP2]], <2 x i1> poison, <2 x i32> zeroinitializer | ||
| ; CHECK-NEXT: [[TMP4:%.*]] = bitcast <2 x i1> [[TMP3]] to i2 | ||
| ; CHECK-NEXT: [[TMP5:%.*]] = call range(i2 0, -1) i2 @llvm.ctpop.i2(i2 [[TMP4]]) | ||
| ; CHECK-NEXT: [[TMP6:%.*]] = trunc i2 [[TMP5]] to i1 |
There was a problem hiding this comment.
No need for so many i1 tests that don't hit this code path anyway. I'd suggest adding additional i2 tests instead, which make it a bit clearer what is going on (e.g. v5i2 and v6i2).
| ret i2 %4 | ||
| } | ||
|
|
||
| define i2 @constant_multiplied_5xi2(i2 %0) { |
There was a problem hiding this comment.
| ret i2 %4 | ||
| } | ||
|
|
||
| define i2 @constant_multiplied_7xi2(i2 %0) { |
There was a problem hiding this comment.
| ret i2 %4 | ||
| } | ||
|
|
||
| define i2 @constant_multiplied_6xi2(i2 %0) { |
There was a problem hiding this comment.
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/76/builds/13019 Here is the relevant piece of the build log for the reference |
…cation (llvm#161020) Fixes llvm#160066 Whenever we have a vector with all the same elemnts, created with `insertelement` and `shufflevector` and we sum the vector, we have a multiplication.
```llvm define i1 @src(i1 %0) { %2 = insertelement <8 x i1> poison, i1 %0, i32 0 %3 = shufflevector <8 x i1> %2, <8 x i1> poison, <8 x i32> zeroinitializer %4 = tail call i1 @llvm.vector.reduce.add.v8i1(<8 x i1> %3) ret i1 %4 } define i1 @tgt(i1 %0) { ret i1 0 } ``` alive2: https://alive2.llvm.org/ce/z/vejxot `vector_reduce_add(<n x i1>)` to `Trunc(ctpop(bitcast <n x i1> to in))` interferes with the `vector_reduce_add(<splat>)` to `mul`, so I exchanged their order. Relevant PR: #161020
…182213) ```llvm define i1 @src(i1 %0) { %2 = insertelement <8 x i1> poison, i1 %0, i32 0 %3 = shufflevector <8 x i1> %2, <8 x i1> poison, <8 x i32> zeroinitializer %4 = tail call i1 @llvm.vector.reduce.add.v8i1(<8 x i1> %3) ret i1 %4 } define i1 @tgt(i1 %0) { ret i1 0 } ``` alive2: https://alive2.llvm.org/ce/z/vejxot `vector_reduce_add(<n x i1>)` to `Trunc(ctpop(bitcast <n x i1> to in))` interferes with the `vector_reduce_add(<splat>)` to `mul`, so I exchanged their order. Relevant PR: llvm#161020
…182213) ```llvm define i1 @src(i1 %0) { %2 = insertelement <8 x i1> poison, i1 %0, i32 0 %3 = shufflevector <8 x i1> %2, <8 x i1> poison, <8 x i32> zeroinitializer %4 = tail call i1 @llvm.vector.reduce.add.v8i1(<8 x i1> %3) ret i1 %4 } define i1 @tgt(i1 %0) { ret i1 0 } ``` alive2: https://alive2.llvm.org/ce/z/vejxot `vector_reduce_add(<n x i1>)` to `Trunc(ctpop(bitcast <n x i1> to in))` interferes with the `vector_reduce_add(<splat>)` to `mul`, so I exchanged their order. Relevant PR: llvm#161020
Fixes #160066
Whenever we have a vector with all the same elemnts, created with
insertelementandshufflevectorand we sum the vector, we have a multiplication.