[mlir][arith] Fix arith.select canonicalization patterns#84685
Conversation
Because `arith.select` does not propagate poison of the second or third operand depending on the condition, some canonicalization patterns were incorrect. This patch removes these incorrect patterns, and adds a new pattern to fix the case of `i1` select with constants. Patterns that are removed: * select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y) * select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y) * select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y) * select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y) * arith.select %arg, %x, %y : i1 => and(%arg, %x) or and(!%arg, %y) The first two patterns are incorrect when `predB` is poison and `predA` is false, as a non-poison `y` gets compiled to `poison`. The next two patterns are incorrect when `predB` is poison and `predA` is true, as a non-poison `x` gets compiled to `poison`. The last pattern is incorrect as it propagates poison from all operands afer compilation.
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-arith Author: Fehr Mathieu (math-fehr) ChangesBecause Patterns that are removed:
Pattern that is added:
The first two patterns are incorrect when Full diff: https://github.com/llvm/llvm-project/pull/84685.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 11c4a29718e1d9..caca2ff81964f7 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -253,9 +253,6 @@ def CmpIExtUI :
// SelectOp
//===----------------------------------------------------------------------===//
-def GetScalarOrVectorTrueAttribute :
- NativeCodeCall<"cast<TypedAttr>(getBoolAttribute($0.getType(), true))">;
-
// select(not(pred), a, b) => select(pred, b, a)
def SelectNotCond :
Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$ones)), $a, $b),
@@ -272,31 +269,12 @@ def RedundantSelectFalse :
Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
(SelectOp $pred, $a, $c)>;
-// select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
-def SelectAndCond :
- Pat<(SelectOp $predA, (SelectOp $predB, $x, $y), $y),
- (SelectOp (Arith_AndIOp $predA, $predB), $x, $y)>;
-
-// select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
-def SelectAndNotCond :
- Pat<(SelectOp $predA, (SelectOp $predB, $y, $x), $y),
- (SelectOp (Arith_AndIOp $predA,
- (Arith_XOrIOp $predB,
- (Arith_ConstantOp (GetScalarOrVectorTrueAttribute $predB)))),
- $x, $y)>;
-
-// select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)
-def SelectOrCond :
- Pat<(SelectOp $predA, $x, (SelectOp $predB, $x, $y)),
- (SelectOp (Arith_OrIOp $predA, $predB), $x, $y)>;
-
-// select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y)
-def SelectOrNotCond :
- Pat<(SelectOp $predA, $x, (SelectOp $predB, $y, $x)),
- (SelectOp (Arith_OrIOp $predA,
- (Arith_XOrIOp $predB,
- (Arith_ConstantOp (GetScalarOrVectorTrueAttribute $predB)))),
- $x, $y)>;
+// select(pred, false, true) => not(pred)
+def SelectI1ToNot :
+ Pat<(SelectOp $pred,
+ (ConstantLikeMatcher ConstantAttr<I1Attr, "0">),
+ (ConstantLikeMatcher ConstantAttr<I1Attr, "1">)),
+ (Arith_XOrIOp $pred, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))>;
//===----------------------------------------------------------------------===//
// IndexCastOp
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 0f71c19c23b654..9f64a07f31e3af 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -969,7 +969,6 @@ OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
[](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
}
-
//===----------------------------------------------------------------------===//
// MaxSIOp
//===----------------------------------------------------------------------===//
@@ -2173,35 +2172,6 @@ void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// SelectOp
//===----------------------------------------------------------------------===//
-// Transforms a select of a boolean to arithmetic operations
-//
-// arith.select %arg, %x, %y : i1
-//
-// becomes
-//
-// and(%arg, %x) or and(!%arg, %y)
-struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
- using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(arith::SelectOp op,
- PatternRewriter &rewriter) const override {
- if (!op.getType().isInteger(1))
- return failure();
-
- Value falseConstant =
- rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
- Value notCondition = rewriter.create<arith::XOrIOp>(
- op.getLoc(), op.getCondition(), falseConstant);
-
- Value trueVal = rewriter.create<arith::AndIOp>(
- op.getLoc(), op.getCondition(), op.getTrueValue());
- Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
- op.getFalseValue());
- rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
- return success();
- }
-};
-
// select %arg, %c1, %c0 => extui %arg
struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
@@ -2238,9 +2208,8 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<RedundantSelectFalse, RedundantSelectTrue, SelectI1Simplify,
- SelectAndCond, SelectAndNotCond, SelectOrCond, SelectOrNotCond,
- SelectNotCond, SelectToExtUI>(context);
+ results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
+ SelectI1ToNot, SelectToExtUI>(context);
}
OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index cb98a10048a309..bdc6c91d926775 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -116,18 +116,6 @@ func.func @selToNot(%arg0: i1) -> i1 {
return %res : i1
}
-// CHECK-LABEL: @selToArith
-// CHECK-NEXT: %[[trueval:.+]] = arith.constant true
-// CHECK-NEXT: %[[notcmp:.+]] = arith.xori %arg0, %[[trueval]] : i1
-// CHECK-NEXT: %[[condtrue:.+]] = arith.andi %arg0, %arg1 : i1
-// CHECK-NEXT: %[[condfalse:.+]] = arith.andi %[[notcmp]], %arg2 : i1
-// CHECK-NEXT: %[[res:.+]] = arith.ori %[[condtrue]], %[[condfalse]] : i1
-// CHECK: return %[[res]]
-func.func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
- %res = arith.select %arg0, %arg1, %arg2 : i1
- return %res : i1
-}
-
// CHECK-LABEL: @redundantSelectTrue
// CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg1, %arg3
// CHECK-NEXT: return %[[res]]
@@ -160,74 +148,6 @@ func.func @selNotCond(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 :
return %res1, %res2 : i32, i32
}
-// CHECK-LABEL: @selAndCond
-// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %arg0
-// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg2, %arg3
-// CHECK-NEXT: return %[[res]]
-func.func @selAndCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
- %sel = arith.select %arg0, %arg2, %arg3 : i32
- %res = arith.select %arg1, %sel, %arg3 : i32
- return %res : i32
-}
-
-// CHECK-LABEL: @selAndNotCond
-// CHECK-NEXT: %[[one:.+]] = arith.constant true
-// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
-// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]]
-// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2
-// CHECK-NEXT: return %[[res]]
-func.func @selAndNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
- %sel = arith.select %arg0, %arg2, %arg3 : i32
- %res = arith.select %arg1, %sel, %arg2 : i32
- return %res : i32
-}
-
-// CHECK-LABEL: @selAndNotCondVec
-// CHECK-NEXT: %[[one:.+]] = arith.constant dense<true> : vector<4xi1>
-// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
-// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]]
-// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2
-// CHECK-NEXT: return %[[res]]
-func.func @selAndNotCondVec(%arg0: vector<4xi1>, %arg1: vector<4xi1>, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> {
- %sel = arith.select %arg0, %arg2, %arg3 : vector<4xi1>, vector<4xi32>
- %res = arith.select %arg1, %sel, %arg2 : vector<4xi1>, vector<4xi32>
- return %res : vector<4xi32>
-}
-
-// CHECK-LABEL: @selOrCond
-// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %arg0
-// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg2, %arg3
-// CHECK-NEXT: return %[[res]]
-func.func @selOrCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
- %sel = arith.select %arg0, %arg2, %arg3 : i32
- %res = arith.select %arg1, %arg2, %sel : i32
- return %res : i32
-}
-
-// CHECK-LABEL: @selOrNotCond
-// CHECK-NEXT: %[[one:.+]] = arith.constant true
-// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
-// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]]
-// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2
-// CHECK-NEXT: return %[[res]]
-func.func @selOrNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
- %sel = arith.select %arg0, %arg2, %arg3 : i32
- %res = arith.select %arg1, %arg3, %sel : i32
- return %res : i32
-}
-
-// CHECK-LABEL: @selOrNotCondVec
-// CHECK-NEXT: %[[one:.+]] = arith.constant dense<true> : vector<4xi1>
-// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
-// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]]
-// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2
-// CHECK-NEXT: return %[[res]]
-func.func @selOrNotCondVec(%arg0: vector<4xi1>, %arg1: vector<4xi1>, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> {
- %sel = arith.select %arg0, %arg2, %arg3 : vector<4xi1>, vector<4xi32>
- %res = arith.select %arg1, %arg3, %sel : vector<4xi1>, vector<4xi32>
- return %res : vector<4xi32>
-}
-
// Test case: Folding of comparisons with equal operands.
// CHECK-LABEL: @cmpi_equal_operands
// CHECK-DAG: %[[T:.*]] = arith.constant true
|
kuhar
left a comment
There was a problem hiding this comment.
Neat, thanks for taking care of these.
Because
arith.selectdoes not propagate poison of the second or third operand depending on the condition, some canonicalization patterns are currently incorrect. This patch removes these incorrect patterns, and adds a new pattern to fix the case ofi1select with constants.Patterns that are removed:
Pattern that is added:
The first two patterns are incorrect when
predBis poison andpredAis false, as a non-poisonygets compiled topoison. The next two patterns are incorrect whenpredBis poison andpredAis true, as a non-poisonxgets compiled topoison. The last pattern is incorrect as it propagates poison from all operands afer compilation.