[mlir][spirv] Add folding for [S|U]Mod, [S|U]Div, SRem#73341
Merged
Conversation
Add missing constant propogation folder for [S|U]Mod, [S|U]Div, SRem Implement additional folding when rhs is 1 for all ops. This helps for readability of lowered code into SPIR-V. Part of work for llvm#70704
Member
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Finn Plummer (inbelic) ChangesAdd missing constant propogation folder for [S|U]Mod, [S|U]Div, SRem Implement additional folding when rhs is 1 for all ops. This helps for readability of lowered code into SPIR-V. Part of work for #70704 Full diff: https://github.com/llvm/llvm-project/pull/73341.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index c4d1e01f9feef83..16bf173cb7971e0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -534,6 +534,8 @@ def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -573,6 +575,8 @@ def SPIRV_SModOp : SPIRV_ArithmeticBinaryOp<"SMod",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -673,6 +677,8 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -707,6 +713,8 @@ def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -811,6 +819,7 @@ def SPIRV_UModOp : SPIRV_ArithmeticBinaryOp<"UMod",
```
}];
+ let hasFolder = 1;
let hasCanonicalizer = 1;
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af6d..8144a100dab3495 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -69,6 +69,14 @@ static Attribute extractCompositeElement(Attribute composite,
return {};
}
+static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
+ bool div0 = b.isZero();
+
+ bool overflow = a.isMinSignedValue() && b.isAllOnes();
+
+ return div0 || overflow;
+}
+
//===----------------------------------------------------------------------===//
// TableGen'erated canonicalizers
//===----------------------------------------------------------------------===//
@@ -290,6 +298,158 @@ OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
[](APInt a, const APInt &b) { return std::move(a) - b; });
}
+//===----------------------------------------------------------------------===//
+// spirv.SDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
+ // sdiv (x, 1) = x
+ if (matchPattern(getOperand2(), m_One()))
+ return getOperand1();
+
+ // According to the SPIR-V spec:
+ //
+ // Signed-integer division of Operand 1 divided by Operand 2.
+ // Results are computed per component. Behavior is undefined if Operand 2 is
+ // 0. Behavior is undefined if Operand 2 is -1 and Operand 1 is the minimum
+ // representable value for the operands' type, causing signed overflow.
+ //
+ // So don't fold during undefined behaviour.
+ bool div0OrOverflow = false;
+ auto res = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+ if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+ div0OrOverflow = true;
+ return a;
+ }
+ return a.sdiv(b);
+ });
+ return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
+ // smod (x, 1) = 0
+ if (matchPattern(getOperand2(), m_One()))
+ return Builder(getContext()).getZeroAttr(getType());
+
+ // According to SPIR-V spec:
+ //
+ // Signed remainder operation for the remainder whose sign matches the sign
+ // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
+ // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+ // value for the operands' type, causing signed overflow. Otherwise, the
+ // result is the remainder r of Operand 1 divided by Operand 2 where if
+ // r ≠ 0, the sign of r is the same as the sign of Operand 2.
+ //
+ // So don't fold during undefined behaviour
+ bool div0OrOverflow = false;
+ auto res = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+ if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+ div0OrOverflow = true;
+ return a;
+ }
+ APInt c = a.abs().urem(b.abs());
+ if (c.isZero())
+ return c;
+ if (b.isNegative()) {
+ APInt zero = APInt::getZero(c.getBitWidth());
+ return a.isNegative() ? (zero - c) : (b + c);
+ }
+ return a.isNegative() ? (b - c) : c;
+ });
+ return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SRem
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
+ // x % 1 = 0
+ if (matchPattern(getOperand2(), m_One()))
+ return Builder(getContext()).getZeroAttr(getType());
+
+ // According to SPIR-V spec:
+ //
+ // Signed remainder operation for the remainder whose sign matches the sign
+ // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
+ // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+ // value for the operands' type, causing signed overflow. Otherwise, the
+ // result is the remainder r of Operand 1 divided by Operand 2 where if
+ // r ≠ 0, the sign of r is the same as the sign of Operand 1.
+
+ // Don't fold if it would do undefined behaviour.
+ bool div0OrOverflow = false;
+ auto res = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](APInt a, const APInt &b) {
+ if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+ div0OrOverflow = true;
+ return a;
+ }
+ return a.srem(b);
+ });
+ return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
+ // udiv (x, 1) = x
+ if (matchPattern(getOperand2(), m_One()))
+ return getOperand1();
+
+ // According to the SPIR-V spec:
+ //
+ // Unsigned-integer division of Operand 1 divided by Operand 2. Behavior is
+ // undefined if Operand 2 is 0.
+ //
+ // So don't fold during undefined behaviour.
+ bool div0 = false;
+ auto res = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+ if (div0 || b.isZero()) {
+ div0 = true;
+ return a;
+ }
+ return a.udiv(b);
+ });
+ return div0 ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
+ // umod (x, 1) = 0
+ if (matchPattern(getOperand2(), m_One()))
+ return Builder(getContext()).getZeroAttr(getType());
+
+ // According to the SPIR-V spec:
+ //
+ // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
+ // undefined if Operand 2 is 0.
+ //
+ // So don't fold during undefined behaviour.
+ bool div0 = false;
+ auto res = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+ if (div0 || b.isZero()) {
+ div0 = true;
+ return a;
+ }
+ return a.urem(b);
+ });
+ return div0 ? Attribute() : res;
+}
+
//===----------------------------------------------------------------------===//
// spirv.LogicalAnd
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0200805a444397a..7b1163601e1b427 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -462,10 +462,272 @@ func.func @const_fold_vector_isub() -> vector<3xi32> {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.SDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @sdiv_x_1
+func.func @sdiv_x_1(%arg0 : i32) -> i32 {
+ // CHECK-NEXT: return %arg0 : i32
+ %c1 = spirv.Constant 1 : i32
+ %2 = spirv.SDiv %arg0, %c1: i32
+ return %2 : i32
+}
+
+// CHECK-LABEL: @sdiv_div_0_or_overflow
+func.func @sdiv_div_0_or_overflow() -> (i32, i32) {
+ // CHECK: spirv.SDiv
+ // CHECK: spirv.SDiv
+ %c0 = spirv.Constant 0 : i32
+ %cn1 = spirv.Constant -1 : i32
+ %min_i32 = spirv.Constant -2147483648 : i32
+
+ %0 = spirv.SDiv %cn1, %c0 : i32
+ %1 = spirv.SDiv %min_i32, %cn1 : i32
+ return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_sdiv
+func.func @const_fold_scalar_sdiv() -> (i32, i32, i32, i32) {
+ %c56 = spirv.Constant 56 : i32
+ %c7 = spirv.Constant 7 : i32
+ %cn8 = spirv.Constant -8 : i32
+ %c3 = spirv.Constant 3 : i32
+ %cn3 = spirv.Constant -3 : i32
+
+ // CHECK-DAG: spirv.Constant -18
+ // CHECK-DAG: spirv.Constant -2
+ // CHECK-DAG: spirv.Constant -7
+ // CHECK-DAG: spirv.Constant 8
+ %0 = spirv.SDiv %c56, %c7 : i32
+ %1 = spirv.SDiv %c56, %cn8 : i32
+ %2 = spirv.SDiv %cn8, %c3 : i32
+ %3 = spirv.SDiv %c56, %cn3 : i32
+ return %0, %1, %2, %3: i32, i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_sdiv
+func.func @const_fold_vector_sdiv() -> vector<3xi32> {
+ // CHECK: spirv.Constant dense<[0, -1, -3]>
+
+ %cv_num = spirv.Constant dense<[42, 24, -16]> : vector<3xi32>
+ %cv_denom = spirv.Constant dense<[76, -24, 5]> : vector<3xi32>
+ %0 = spirv.SDiv %cv_num, %cv_denom : vector<3xi32>
+ return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SMod
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @smod_x_1
+func.func @smod_x_1(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+ // CHECK-DAG: spirv.Constant 0
+ // CHECK-DAG: spirv.Constant dense<0>
+ %c1 = spirv.Constant 1 : i32
+ %cv1 = spirv.Constant dense<1> : vector<3xi32>
+ %0 = spirv.SMod %arg0, %c1: i32
+ %1 = spirv.SMod %arg1, %cv1: vector<3xi32>
+ return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @smod_div_0_or_overflow
+func.func @smod_div_0_or_overflow() -> (i32, i32) {
+ // CHECK: spirv.SMod
+ // CHECK: spirv.SMod
+ %c0 = spirv.Constant 0 : i32
+ %cn1 = spirv.Constant -1 : i32
+ %min_i32 = spirv.Constant -2147483648 : i32
+
+ %0 = spirv.SMod %cn1, %c0 : i32
+ %1 = spirv.SMod %min_i32, %cn1 : i32
+ return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_smod
+func.func @const_fold_scalar_smod() -> (i32, i32, i32, i32, i32, i32, i32, i32) {
+ %c56 = spirv.Constant 56 : i32
+ %cn56 = spirv.Constant -56 : i32
+ %c59 = spirv.Constant 59 : i32
+ %cn59 = spirv.Constant -59 : i32
+ %c7 = spirv.Constant 7 : i32
+ %cn8 = spirv.Constant -8 : i32
+ %c3 = spirv.Constant 3 : i32
+ %cn3 = spirv.Constant -3 : i32
+
+ // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+ // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+ // CHECK-DAG: %[[FIFTYTHREE:.*]] = spirv.Constant 53 : i32
+ // CHECK-DAG: %[[NFIFTYTHREE:.*]] = spirv.Constant -53 : i32
+ // CHECK-DAG: %[[THREE:.*]] = spirv.Constant 3 : i32
+ // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+ %0 = spirv.SMod %c56, %c7 : i32
+ %1 = spirv.SMod %c56, %cn8 : i32
+ %2 = spirv.SMod %c56, %c3 : i32
+ %3 = spirv.SMod %cn3, %c56 : i32
+ %4 = spirv.SMod %cn3, %cn56 : i32
+ %5 = spirv.SMod %c59, %c56 : i32
+ %6 = spirv.SMod %c59, %cn56 : i32
+ %7 = spirv.SMod %cn59, %cn56 : i32
+
+ // CHECK: return %[[ZERO]], %[[ZERO]], %[[TWO]], %[[FIFTYTHREE]], %[[NTHREE]], %[[THREE]], %[[NFIFTYTHREE]], %[[NTHREE]]
+ return %0, %1, %2, %3, %4, %5, %6, %7 : i32, i32, i32, i32, i32, i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_smod
+func.func @const_fold_vector_smod() -> vector<3xi32> {
+ // CHECK: spirv.Constant dense<[42, -4, 4]>
+
+ %cv = spirv.Constant dense<[42, 24, -16]> : vector<3xi32>
+ %cv_mod = spirv.Constant dense<[76, -7, 5]> : vector<3xi32>
+ %0 = spirv.SMod %cv, %cv_mod : vector<3xi32>
+ return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SRem
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @srem_x_1
+func.func @srem_x_1(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+ // CHECK-DAG: spirv.Constant 0
+ // CHECK-DAG: spirv.Constant dense<0>
+ %c1 = spirv.Constant 1 : i32
+ %cv1 = spirv.Constant dense<1> : vector<3xi32>
+ %0 = spirv.SRem %arg0, %c1: i32
+ %1 = spirv.SRem %arg1, %cv1: vector<3xi32>
+ return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @srem_div_0_or_overflow
+func.func @srem_div_0_or_overflow() -> (i32, i32) {
+ // CHECK: spirv.SRem
+ // CHECK: spirv.SRem
+ %c0 = spirv.Constant 0 : i32
+ %cn1 = spirv.Constant -1 : i32
+ %min_i32 = spirv.Constant -2147483648 : i32
+
+ %0 = spirv.SRem %cn1, %c0 : i32
+ %1 = spirv.SRem %min_i32, %cn1 : i32
+ return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_srem
+func.func @const_fold_scalar_srem() -> (i32, i32, i32, i32, i32) {
+ %c56 = spirv.Constant 56 : i32
+ %c7 = spirv.Constant 7 : i32
+ %cn8 = spirv.Constant -8 : i32
+ %c3 = spirv.Constant 3 : i32
+ %cn3 = spirv.Constant -3 : i32
+
+ // CHECK-DAG: %[[ONE:.*]] = spirv.Constant 1 : i32
+ // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+ // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+ // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+ %0 = spirv.SRem %c56, %c7 : i32
+ %1 = spirv.SRem %c56, %cn8 : i32
+ %2 = spirv.SRem %c56, %c3 : i32
+ %3 = spirv.SRem %cn3, %c56 : i32
+ %4 = spirv.SRem %c7, %cn3 : i32
+ // CHECK: return %[[ZERO]], %[[ZERO]], %[[TWO]], %[[NTHREE]], %[[ONE]]
+ return %0, %1, %2, %3, %4 : i32, i32, i32, i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @udiv_x_1
+func.func @udiv_x_1(%arg0 : i32) -> i32 {
+ // CHECK-NEXT: return %arg0 : i32
+ %c1 = spirv.Constant 1 : i32
+ %2 = spirv.UDiv %arg0, %c1: i32
+ return %2 : i32
+}
+
+// CHECK-LABEL: @udiv_div_0
+func.func @udiv_div_0() -> i32 {
+ // CHECK: spirv.UDiv
+ %c0 = spirv.Constant 0 : i32
+ %cn1 = spirv.Constant -1 : i32
+ %0 = spirv.UDiv %cn1, %c0 : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_udiv
+func.func @const_fold_scalar_udiv() -> (i32, i32, i32) {
+ %c56 = spirv.Constant 56 : i32
+ %c7 = spirv.Constant 7 : i32
+ %cn8 = spirv.Constant -8 : i32
+ %c3 = spirv.Constant 3 : i32
+
+ // CHECK-DAG: spirv.Constant 0
+ // CHECK-DAG: spirv.Constant 1431655762
+ // CHECK-DAG: spirv.Constant 8
+ %0 = spirv.UDiv %c56, %c7 : i32
+ %1 = spirv.UDiv %cn8, %c3 : i32
+ %2 = spirv.UDiv %c56, %cn8 : i32
+ return %0, %1, %2 : i32, i32, i32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.UMod
//===----------------------------------------------------------------------===//
+// CHECK-LABEL: @umod_x_1
+func.func @umod_x_1(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+ // CHECK-DAG: spirv.Constant 0
+ // CHECK-DAG: spirv.Constant dense<0>
+ %c1 = spirv.Constant 1 : i32
+ %cv1 = spirv.Constant dense<1> : vector<3xi32>
+ %0 = spirv.UMod %arg0, %c1: i32
+ %1 = spirv.UMod %arg1, %cv1: vector<3xi32>
+ return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @umod_div_0
+func.func @umod_div_0() -> i32 {
+ // CHECK: spirv.UMod
+ %c0 = spirv.Constant 0 : i32
+ %cn1 = spirv.Constant -1 : i32
+ %0 = spirv.UMod %cn1, %c0 : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_umod
+func.func @const_fold_scalar_umod() -> (i32, i32, i32) {
+ %c56 = spirv.Constant 56 : i32
+ %c7 = spirv.Constant 7 : i32
+ %cn8 = spirv.Constant -8 : i32
+ %c3 = spirv.Constant 3 : i32
+
+ // CHECK-DAG: spirv.Constant 0
+ // CHECK-DAG: spirv.Constant 2
+ // CHECK-DAG: spirv.Constant 56
+ %0 = spirv.UMod %c56, %c7 : i32
+ %1 = spirv.UMod %cn8, %c3 : i32
+ %2 = spirv.UMod %c56, %cn8 : i32
+ return %0, %1, %2 : i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_umod
+func.func @const_fold_vector_umod() -> vector<3xi32> {
+ // CHECK: spirv.Constant dense<[42, 24, 0]>
+
+ %cv = spirv.Constant dense<[42, 24, -16]> : vector<3xi32>
+ %cv_mod = spirv.Constant dense<[76, -7, 5]> : vector<3xi32>
+ %0 = spirv.UMod %cv, %cv_mod : vector<3xi32>
+ return %0 : vector<3xi32>
+}
+
// CHECK-LABEL: @umod_fold
// CHECK-SAME: (%[[ARG:.*]]: i32)
func.func @umod_fold(%arg0: i32) -> (i32, i32) {
|
kuhar
reviewed
Nov 24, 2023
- fix spacing issue - correct spelling - make testcases more strict when matching the return values to ensure proper order
kuhar
approved these changes
Nov 29, 2023
kuhar
left a comment
Member
There was a problem hiding this comment.
LGTM, thanks for the contribution!
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Add missing constant propogation folder for [S|U]Mod, [S|U]Div, SRem
Implement additional folding when rhs is 1 for all ops.
This helps for readability of lowered code into SPIR-V.
Part of work for #70704