diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index c4d1e01f9feef..16bf173cb7971 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -534,6 +534,8 @@ def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -573,6 +575,8 @@ def SPIRV_SModOp : SPIRV_ArithmeticBinaryOp<"SMod",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -673,6 +677,8 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -707,6 +713,8 @@ def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -811,6 +819,7 @@ def SPIRV_UModOp : SPIRV_ArithmeticBinaryOp<"UMod",
     ```
   }];
 
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af..82af41643edb8 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -69,6 +69,16 @@ static Attribute extractCompositeElement(Attribute composite,
   return {};
 }
 
+/// Returns true when dividing `a` by `b` would be undefined behavior for a
+/// signed SPIR-V division or remainder: either `b` is zero, or `a` is the
+/// minimum signed value and `b` is -1 (all bits set), which overflows.
+static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
+  bool div0 = b.isZero();
+  bool overflow = a.isMinSignedValue() && b.isAllOnes();
+
+  return div0 || overflow;
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'erated canonicalizers
 //===----------------------------------------------------------------------===//
@@ -290,6 +300,168 @@ OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
       [](APInt a, const APInt &b) { return std::move(a) - b; });
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.SDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
+  // sdiv (x, 1) = x
+  if (matchPattern(getOperand2(), m_One()))
+    return getOperand1();
+
+  // According to the SPIR-V spec:
+  //
+  // Signed-integer division of Operand 1 divided by Operand 2.
+  // Results are computed per component. Behavior is undefined if Operand 2 is
+  // 0. Behavior is undefined if Operand 2 is -1 and Operand 1 is the minimum
+  // representable value for the operands' type, causing signed overflow.
+  //
+  // So don't fold during undefined behavior.
+  // The flag is sticky so that a single undefined element of a vector
+  // operand discards the whole folded result.
+  bool div0OrOverflow = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+          div0OrOverflow = true;
+          return a;
+        }
+        return a.sdiv(b);
+      });
+  return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
+  // smod (x, 1) = 0
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to SPIR-V spec:
+  //
+  // Signed remainder operation for the remainder whose sign matches the sign
+  // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
+  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+  // value for the operands' type, causing signed overflow. Otherwise, the
+  // result is the remainder r of Operand 1 divided by Operand 2 where if
+  // r ≠ 0, the sign of r is the same as the sign of Operand 2.
+  //
+  // So don't fold during undefined behavior.
+  // The flag is sticky so that a single undefined element of a vector
+  // operand discards the whole folded result.
+  bool div0OrOverflow = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+          div0OrOverflow = true;
+          return a;
+        }
+        APInt c = a.abs().urem(b.abs());
+        if (c.isZero())
+          return c;
+        if (b.isNegative()) {
+          APInt zero = APInt::getZero(c.getBitWidth());
+          return a.isNegative() ? (zero - c) : (b + c);
+        }
+        return a.isNegative() ? (b - c) : c;
+      });
+  return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SRem
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
+  // x % 1 = 0
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to SPIR-V spec:
+  //
+  // Signed remainder operation for the remainder whose sign matches the sign
+  // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
+  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+  // value for the operands' type, causing signed overflow. Otherwise, the
+  // result is the remainder r of Operand 1 divided by Operand 2 where if
+  // r ≠ 0, the sign of r is the same as the sign of Operand 1.
+
+  // Don't fold if it would do undefined behavior.
+  // The flag is sticky so that a single undefined element of a vector
+  // operand discards the whole folded result.
+  bool div0OrOverflow = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+          div0OrOverflow = true;
+          return a;
+        }
+        return a.srem(b);
+      });
+  return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
+  // udiv (x, 1) = x
+  if (matchPattern(getOperand2(), m_One()))
+    return getOperand1();
+
+  // According to the SPIR-V spec:
+  //
+  // Unsigned-integer division of Operand 1 divided by Operand 2. Behavior is
+  // undefined if Operand 2 is 0.
+  //
+  // So don't fold during undefined behavior.
+  // The flag is sticky so that a single undefined element of a vector
+  // operand discards the whole folded result.
+  bool div0 = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0 || b.isZero()) {
+          div0 = true;
+          return a;
+        }
+        return a.udiv(b);
+      });
+  return div0 ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
+  // umod (x, 1) = 0
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to the SPIR-V spec:
+  //
+  // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
+  // undefined if Operand 2 is 0.
+  //
+  // So don't fold during undefined behavior.
+  // The flag is sticky so that a single undefined element of a vector
+  // operand discards the whole folded result.
+  bool div0 = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0 || b.isZero()) {
+          div0 = true;
+          return a;
+        }
+        return a.urem(b);
+      });
+  return div0 ? Attribute() : res;
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.LogicalAnd
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0200805a44439..6fb5ca5c41839 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -462,10 +462,307 @@ func.func @const_fold_vector_isub() -> vector<3xi32> {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.SDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @sdiv_x_1
+func.func @sdiv_x_1(%arg0 : i32) -> i32 {
+  // CHECK-NEXT: return %arg0 : i32
+  %c1 = spirv.Constant 1 : i32
+  %2 = spirv.SDiv %arg0, %c1: i32
+  return %2 : i32
+}
+
+// CHECK-LABEL: @sdiv_div_0_or_overflow
+func.func @sdiv_div_0_or_overflow() -> (i32, i32) {
+  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
+  // CHECK-DAG: %[[CN1:.*]] = spirv.Constant -1
+  // CHECK-DAG: %[[CNMIN:.*]] = spirv.Constant -2147483648
+
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %min_i32 = spirv.Constant -2147483648 : i32
+
+  // CHECK: %{{.*}} = spirv.SDiv %[[CN1]], %[[C0]]
+  // CHECK: %{{.*}} = spirv.SDiv %[[CNMIN]], %[[CN1]]
+  %0 = spirv.SDiv %cn1, %c0 : i32
+  %1 = spirv.SDiv %min_i32, %cn1 : i32
+  return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_sdiv
+func.func @const_fold_scalar_sdiv() -> (i32, i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+  %cn3 = spirv.Constant -3 : i32
+
+  // CHECK-DAG: %[[CN18:.*]] = spirv.Constant -18
+  // CHECK-DAG: %[[CN2:.*]] = spirv.Constant -2
+  // CHECK-DAG: %[[CN7:.*]] = spirv.Constant -7
+  // CHECK-DAG: %[[C8:.*]] = spirv.Constant 8
+  %0 = spirv.SDiv %c56, %c7 : i32
+  %1 = spirv.SDiv %c56, %cn8 : i32
+  %2 = spirv.SDiv %cn8, %c3 : i32
+  %3 = spirv.SDiv %c56, %cn3 : i32
+
+  // CHECK: return %[[C8]], %[[CN7]], %[[CN2]], %[[CN18]]
+  return %0, %1, %2, %3: i32, i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_sdiv
+func.func @const_fold_vector_sdiv() -> vector<3xi32> {
+  // CHECK: %[[CVEC:.*]] = spirv.Constant dense<[0, -1, -3]>
+
+  %cv_num = spirv.Constant dense<[42, 24, -16]> : vector<3xi32>
+  %cv_denom = spirv.Constant dense<[76, -24, 5]> : vector<3xi32>
+  %0 = spirv.SDiv %cv_num, %cv_denom : vector<3xi32>
+
+  // CHECK: return %[[CVEC]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SMod
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @smod_x_1
+func.func @smod_x_1(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
+  // CHECK-DAG: %[[CVEC0:.*]] = spirv.Constant dense<0>
+  %c1 = spirv.Constant 1 : i32
+  %cv1 = spirv.Constant dense<1> : vector<3xi32>
+  %0 = spirv.SMod %arg0, %c1: i32
+  %1 = spirv.SMod %arg1, %cv1: vector<3xi32>
+
+  // CHECK: return %[[C0]], %[[CVEC0]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @smod_div_0_or_overflow
+func.func @smod_div_0_or_overflow() -> (i32, i32) {
+  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
+  // CHECK-DAG: %[[CN1:.*]] = spirv.Constant -1
+  // CHECK-DAG: %[[CNMIN:.*]] = spirv.Constant -2147483648
+
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %min_i32 = spirv.Constant -2147483648 : i32
+
+  // CHECK: %{{.*}} = spirv.SMod %[[CN1]], %[[C0]]
+  // CHECK: %{{.*}} = spirv.SMod %[[CNMIN]], %[[CN1]]
+  %0 = spirv.SMod %cn1, %c0 : i32
+  %1 = spirv.SMod %min_i32, %cn1 : i32
+  return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_smod
+func.func @const_fold_scalar_smod() -> (i32, i32, i32, i32, i32, i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %cn56 = spirv.Constant -56 : i32
+  %c59 = spirv.Constant 59 : i32
+  %cn59 = spirv.Constant -59 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+  %cn3 = spirv.Constant -3 : i32
+
+  // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+  // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+  // CHECK-DAG: %[[FIFTYTHREE:.*]] = spirv.Constant 53 : i32
+  // CHECK-DAG: %[[NFIFTYTHREE:.*]] = spirv.Constant -53 : i32
+  // CHECK-DAG: %[[THREE:.*]] = spirv.Constant 3 : i32
+  // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+  %0 = spirv.SMod %c56, %c7 : i32
+  %1 = spirv.SMod %c56, %cn8 : i32
+  %2 = spirv.SMod %c56, %c3 : i32
+  %3 = spirv.SMod %cn3, %c56 : i32
+  %4 = spirv.SMod %cn3, %cn56 : i32
+  %5 = spirv.SMod %c59, %c56 : i32
+  %6 = spirv.SMod %c59, %cn56 : i32
+  %7 = spirv.SMod %cn59, %cn56 : i32
+
+  // CHECK: return %[[ZERO]], %[[ZERO]], %[[TWO]], %[[FIFTYTHREE]], %[[NTHREE]], %[[THREE]], %[[NFIFTYTHREE]], %[[NTHREE]]
+  return %0, %1, %2, %3, %4, %5, %6, %7 : i32, i32, i32, i32, i32, i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_smod
+func.func @const_fold_vector_smod() -> vector<3xi32> {
+  // CHECK: %[[CVEC:.*]] = spirv.Constant dense<[42, -4, 4]>
+
+  %cv = spirv.Constant dense<[42, 24, -16]> : vector<3xi32>
+  %cv_mod = spirv.Constant dense<[76, -7, 5]> : vector<3xi32>
+  %0 = spirv.SMod %cv, %cv_mod : vector<3xi32>
+
+  // CHECK: return %[[CVEC]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SRem
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @srem_x_1
+func.func @srem_x_1(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
+  // CHECK-DAG: %[[CVEC0:.*]] = spirv.Constant dense<0>
+  %c1 = spirv.Constant 1 : i32
+  %cv1 = spirv.Constant dense<1> : vector<3xi32>
+  %0 = spirv.SRem %arg0, %c1: i32
+  %1 = spirv.SRem %arg1, %cv1: vector<3xi32>
+
+  // CHECK: return %[[C0]], %[[CVEC0]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @srem_div_0_or_overflow
+func.func @srem_div_0_or_overflow() -> (i32, i32) {
+  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
+  // CHECK-DAG: %[[CN1:.*]] = spirv.Constant -1
+  // CHECK-DAG: %[[CNMIN:.*]] = spirv.Constant -2147483648
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %min_i32 = spirv.Constant -2147483648 : i32
+
+  // CHECK: %{{.*}} = spirv.SRem %[[CN1]], %[[C0]]
+  // CHECK: %{{.*}} = spirv.SRem %[[CNMIN]], %[[CN1]]
+  %0 = spirv.SRem %cn1, %c0 : i32
+  %1 = spirv.SRem %min_i32, %cn1 : i32
+  return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_srem
+func.func @const_fold_scalar_srem() -> (i32, i32, i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+  %cn3 = spirv.Constant -3 : i32
+
+  // CHECK-DAG: %[[ONE:.*]] = spirv.Constant 1 : i32
+  // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+  // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+  // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+  %0 = spirv.SRem %c56, %c7 : i32
+  %1 = spirv.SRem %c56, %cn8 : i32
+  %2 = spirv.SRem %c56, %c3 : i32
+  %3 = spirv.SRem %cn3, %c56 : i32
+  %4 = spirv.SRem %c7, %cn3 : i32
+  // CHECK: return %[[ZERO]], %[[ZERO]], %[[TWO]], %[[NTHREE]], %[[ONE]]
+  return %0, %1, %2, %3, %4 : i32, i32, i32, i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @udiv_x_1
+func.func @udiv_x_1(%arg0 : i32) -> i32 {
+  // CHECK-NEXT: return %arg0 : i32
+  %c1 = spirv.Constant 1 : i32
+  %2 = spirv.UDiv %arg0, %c1: i32
+  return %2 : i32
+}
+
+// CHECK-LABEL: @udiv_div_0
+func.func @udiv_div_0() -> i32 {
+  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
+  // CHECK-DAG: %[[CN1:.*]] = spirv.Constant -1
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+
+  // CHECK: %{{.*}} = spirv.UDiv %[[CN1]], %[[C0]]
+  %0 = spirv.UDiv %cn1, %c0 : i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_udiv
+func.func @const_fold_scalar_udiv() -> (i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+
+  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
+  // CHECK-DAG: %[[CBIG:.*]] = spirv.Constant 1431655762
+  // CHECK-DAG: %[[C8:.*]] = spirv.Constant 8
+  %0 = spirv.UDiv %c56, %c7 : i32
+  %1 = spirv.UDiv %cn8, %c3 : i32
+  %2 = spirv.UDiv %c56, %cn8 : i32
+
+  // CHECK: return %[[C8]], %[[CBIG]], %[[C0]]
+  return %0, %1, %2 : i32, i32, i32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.UMod
 //===----------------------------------------------------------------------===//
 
+// CHECK-LABEL: @umod_x_1
+func.func @umod_x_1(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
+  // CHECK-DAG: %[[CVEC0:.*]] = spirv.Constant dense<0>
+  %c1 = spirv.Constant 1 : i32
+  %cv1 = spirv.Constant dense<1> : vector<3xi32>
+  %0 = spirv.UMod %arg0, %c1: i32
+  %1 = spirv.UMod %arg1, %cv1: vector<3xi32>
+
+  // CHECK: return %[[C0]], %[[CVEC0]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @umod_div_0
+func.func @umod_div_0() -> i32 {
+  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
+  // CHECK-DAG: %[[CN1:.*]] = spirv.Constant -1
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+
+  // CHECK: %{{.*}} = spirv.UMod %[[CN1]], %[[C0]]
+  %0 = spirv.UMod %cn1, %c0 : i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_umod
+func.func @const_fold_scalar_umod() -> (i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+
+  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
+  // CHECK-DAG: %[[C2:.*]] = spirv.Constant 2
+  // CHECK-DAG: %[[C56:.*]] = spirv.Constant 56
+  %0 = spirv.UMod %c56, %c7 : i32
+  %1 = spirv.UMod %cn8, %c3 : i32
+  %2 = spirv.UMod %c56, %cn8 : i32
+
+  // CHECK: return %[[C0]], %[[C2]], %[[C56]]
+  return %0, %1, %2 : i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_umod
+func.func @const_fold_vector_umod() -> vector<3xi32> {
+  // CHECK: %[[CVEC:.*]] = spirv.Constant dense<[42, 24, 0]>
+
+  %cv = spirv.Constant dense<[42, 24, -16]> : vector<3xi32>
+  %cv_mod = spirv.Constant dense<[76, -7, 5]> : vector<3xi32>
+  %0 = spirv.UMod %cv, %cv_mod : vector<3xi32>
+
+  // CHECK: return %[[CVEC]]
+  return %0 : vector<3xi32>
+}
+
 // CHECK-LABEL: @umod_fold
 // CHECK-SAME: (%[[ARG:.*]]: i32)
 func.func @umod_fold(%arg0: i32) -> (i32, i32) {