diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td index 11c4a29718e1d..caca2ff81964f 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td +++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td @@ -253,9 +253,6 @@ def CmpIExtUI : // SelectOp //===----------------------------------------------------------------------===// -def GetScalarOrVectorTrueAttribute : - NativeCodeCall<"cast(getBoolAttribute($0.getType(), true))">; - // select(not(pred), a, b) => select(pred, b, a) def SelectNotCond : Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$ones)), $a, $b), @@ -272,31 +269,12 @@ def RedundantSelectFalse : Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)), (SelectOp $pred, $a, $c)>; -// select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y) -def SelectAndCond : - Pat<(SelectOp $predA, (SelectOp $predB, $x, $y), $y), - (SelectOp (Arith_AndIOp $predA, $predB), $x, $y)>; - -// select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y) -def SelectAndNotCond : - Pat<(SelectOp $predA, (SelectOp $predB, $y, $x), $y), - (SelectOp (Arith_AndIOp $predA, - (Arith_XOrIOp $predB, - (Arith_ConstantOp (GetScalarOrVectorTrueAttribute $predB)))), - $x, $y)>; - -// select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y) -def SelectOrCond : - Pat<(SelectOp $predA, $x, (SelectOp $predB, $x, $y)), - (SelectOp (Arith_OrIOp $predA, $predB), $x, $y)>; - -// select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y) -def SelectOrNotCond : - Pat<(SelectOp $predA, $x, (SelectOp $predB, $y, $x)), - (SelectOp (Arith_OrIOp $predA, - (Arith_XOrIOp $predB, - (Arith_ConstantOp (GetScalarOrVectorTrueAttribute $predB)))), - $x, $y)>; +// select(pred, false, true) => not(pred) +def SelectI1ToNot : + Pat<(SelectOp $pred, + (ConstantLikeMatcher ConstantAttr), + (ConstantLikeMatcher ConstantAttr)), + (Arith_XOrIOp $pred, (Arith_ConstantOp ConstantAttr))>; //===----------------------------------------------------------------------===// // IndexCastOp diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 0f71c19c23b65..9f64a07f31e3a 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -969,7 +969,6 @@ OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) { [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); }); } - //===----------------------------------------------------------------------===// // MaxSIOp //===----------------------------------------------------------------------===// @@ -2173,35 +2172,6 @@ void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, // SelectOp //===----------------------------------------------------------------------===// -// Transforms a select of a boolean to arithmetic operations -// -// arith.select %arg, %x, %y : i1 -// -// becomes -// -// and(%arg, %x) or and(!%arg, %y) -struct SelectI1Simplify : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(arith::SelectOp op, - PatternRewriter &rewriter) const override { - if (!op.getType().isInteger(1)) - return failure(); - - Value falseConstant = - rewriter.create(op.getLoc(), true, 1); - Value notCondition = rewriter.create( - op.getLoc(), op.getCondition(), falseConstant); - - Value trueVal = rewriter.create( - op.getLoc(), op.getCondition(), op.getTrueValue()); - Value falseVal = rewriter.create(op.getLoc(), notCondition, - op.getFalseValue()); - rewriter.replaceOpWithNewOp(op, trueVal, falseVal); - return success(); - } -}; - // select %arg, %c1, %c0 => extui %arg struct SelectToExtUI : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -2238,9 +2208,8 @@ struct SelectToExtUI : public OpRewritePattern { void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) { diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index cb98a10048a30..bdc6c91d92677 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -116,18 +116,6 @@ func.func @selToNot(%arg0: i1) -> i1 { return %res : i1 } -// CHECK-LABEL: @selToArith -// CHECK-NEXT: %[[trueval:.+]] = arith.constant true -// CHECK-NEXT: %[[notcmp:.+]] = arith.xori %arg0, %[[trueval]] : i1 -// CHECK-NEXT: %[[condtrue:.+]] = arith.andi %arg0, %arg1 : i1 -// CHECK-NEXT: %[[condfalse:.+]] = arith.andi %[[notcmp]], %arg2 : i1 -// CHECK-NEXT: %[[res:.+]] = arith.ori %[[condtrue]], %[[condfalse]] : i1 -// CHECK: return %[[res]] -func.func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 { - %res = arith.select %arg0, %arg1, %arg2 : i1 - return %res : i1 -} - // CHECK-LABEL: @redundantSelectTrue // CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg1, %arg3 // CHECK-NEXT: return %[[res]] @@ -160,74 +148,6 @@ func.func @selNotCond(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : return %res1, %res2 : i32, i32 } -// CHECK-LABEL: @selAndCond -// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %arg0 -// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg2, %arg3 -// CHECK-NEXT: return %[[res]] -func.func @selAndCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 { - %sel = arith.select %arg0, %arg2, %arg3 : i32 - %res = arith.select %arg1, %sel, %arg3 : i32 - return %res : i32 -} - -// CHECK-LABEL: @selAndNotCond -// CHECK-NEXT: %[[one:.+]] = arith.constant true -// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]] -// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]] -// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2 -// CHECK-NEXT: return %[[res]] -func.func @selAndNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 { - %sel = arith.select %arg0, %arg2, %arg3 : i32 - %res = arith.select %arg1, %sel, %arg2 : i32 - return %res : i32 -} - -// CHECK-LABEL: @selAndNotCondVec -// CHECK-NEXT: %[[one:.+]] = arith.constant dense : vector<4xi1> -// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]] -// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]] -// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2 -// CHECK-NEXT: return %[[res]] -func.func @selAndNotCondVec(%arg0: vector<4xi1>, %arg1: vector<4xi1>, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> { - %sel = arith.select %arg0, %arg2, %arg3 : vector<4xi1>, vector<4xi32> - %res = arith.select %arg1, %sel, %arg2 : vector<4xi1>, vector<4xi32> - return %res : vector<4xi32> -} - -// CHECK-LABEL: @selOrCond -// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %arg0 -// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg2, %arg3 -// CHECK-NEXT: return %[[res]] -func.func @selOrCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 { - %sel = arith.select %arg0, %arg2, %arg3 : i32 - %res = arith.select %arg1, %arg2, %sel : i32 - return %res : i32 -} - -// CHECK-LABEL: @selOrNotCond -// CHECK-NEXT: %[[one:.+]] = arith.constant true -// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]] -// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]] -// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2 -// CHECK-NEXT: return %[[res]] -func.func @selOrNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 { - %sel = arith.select %arg0, %arg2, %arg3 : i32 - %res = arith.select %arg1, %arg3, %sel : i32 - return %res : i32 -} - -// CHECK-LABEL: @selOrNotCondVec -// CHECK-NEXT: %[[one:.+]] = arith.constant dense : vector<4xi1> -// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]] -// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]] -// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2 -// CHECK-NEXT: return %[[res]] -func.func @selOrNotCondVec(%arg0: vector<4xi1>, %arg1: vector<4xi1>, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> { - %sel = arith.select %arg0, %arg2, %arg3 : vector<4xi1>, vector<4xi32> - %res = arith.select %arg1, %arg3, %sel : vector<4xi1>, vector<4xi32> - return %res : vector<4xi32> -} - // Test case: Folding of comparisons with equal operands. // CHECK-LABEL: @cmpi_equal_operands // CHECK-DAG: %[[T:.*]] = arith.constant true