diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index 286f4de6f90f6..e19bd640075c1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -457,6 +457,8 @@ def SPIRV_ShiftLeftLogicalOp : SPIRV_ShiftOp<"ShiftLeftLogical",
     %5 = spirv.ShiftLeftLogical %3, %4 : vector<3xi32>, vector<3xi16>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -499,6 +501,8 @@ def SPIRV_ShiftRightArithmeticOp : SPIRV_ShiftOp<"ShiftRightArithmetic",
     %5 = spirv.ShiftRightArithmetic %3, %4 : vector<3xi32>, vector<3xi16>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -542,6 +546,8 @@ def SPIRV_ShiftRightLogicalOp : SPIRV_ShiftOp<"ShiftRightLogical",
     %5 = spirv.ShiftRightLogical %3, %4 : vector<3xi32>, vector<3xi16>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af..528d6a5d483aa 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -356,6 +356,108 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
   return Attribute();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.ShiftLeftLogical
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ShiftLeftLogicalOp::fold(
+    spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
+  // x << 0 -> x
+  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
+    return getOperand1();
+  }
+
+  // A zero Base can't be folded to 0: an oversized Shift is still undefined.
+
+  // According to the SPIR-V spec:
+  //
+  // Type is a scalar or vector of integer type.
+  // Results are computed per component, and within each component, per bit...
+  //
+  // The result is undefined if Shift is greater than or equal to the bit width
+  // of the components of Base.
+  //
+  // So we can use the APInt << method, but don't fold if undefined behaviour.
+  bool shiftTooLarge = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (shiftTooLarge || b.uge(a.getBitWidth())) {
+          shiftTooLarge = true;
+          return a;
+        }
+        return a << b;
+      });
+  return shiftTooLarge ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightArithmetic
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ShiftRightArithmeticOp::fold(
+    spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
+  // x >> 0 -> x
+  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
+    return getOperand1();
+  }
+
+  // A 0 or -1 Base can't be folded: an oversized Shift is still undefined.
+
+  // According to the SPIR-V spec:
+  //
+  // Type is a scalar or vector of integer type.
+  // Results are computed per component, and within each component, per bit...
+  //
+  // The result is undefined if Shift is greater than or equal to the bit width
+  // of the components of Base.
+  //
+  // So we can use the APInt ashr method, but don't fold if undefined behaviour.
+  bool shiftTooLarge = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (shiftTooLarge || b.uge(a.getBitWidth())) {
+          shiftTooLarge = true;
+          return a;
+        }
+        return a.ashr(b);
+      });
+  return shiftTooLarge ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightLogical
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ShiftRightLogicalOp::fold(
+    spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
+  // x >> 0 -> x
+  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
+    return getOperand1();
+  }
+
+  // A zero Base can't be folded to 0: an oversized Shift is still undefined.
+
+  // According to the SPIR-V spec:
+  //
+  // Type is a scalar or vector of integer type.
+  // Results are computed per component, and within each component, per bit...
+  //
+  // The result is undefined if Shift is greater than or equal to the bit width
+  // of the components of Base.
+  //
+  // So we can use the APInt lshr method, but don't fold if undefined behaviour.
+  bool shiftTooLarge = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (shiftTooLarge || b.uge(a.getBitWidth())) {
+          shiftTooLarge = true;
+          return a;
+        }
+        return a.lshr(b);
+      });
+  return shiftTooLarge ? Attribute() : res;
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.selection
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0200805a44439..3919a051fc875 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -660,6 +660,184 @@ func.func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.ShiftLeftLogical
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @lsl_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @lsl_x_0(%arg0 : i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %c0 = spirv.Constant 0 : i32
+  %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+  %0 = spirv.ShiftLeftLogical %arg0, %c0 : i32, i32
+  %1 = spirv.ShiftLeftLogical %arg1, %cv0 : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @lsl_shift_overflow
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @lsl_shift_overflow(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C32:.*]] = spirv.Constant 32
+  // CHECK-DAG: %[[CV:.*]] = spirv.Constant dense<[6, 18, 128]>
+  %c32 = spirv.Constant 32 : i32
+  %cv = spirv.Constant dense<[6, 18, 128]> : vector<3xi32>
+
+  // CHECK: %{{.*}} = spirv.ShiftLeftLogical %[[ARG0]], %[[C32]]
+  // CHECK: %{{.*}} = spirv.ShiftLeftLogical %[[ARG1]], %[[CV]]
+  %0 = spirv.ShiftLeftLogical %arg0, %c32 : i32, i32
+  %1 = spirv.ShiftLeftLogical %arg1, %cv : vector<3xi32>, vector<3xi32>
+
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_lsl
+func.func @const_fold_scalar_lsl() -> i32 {
+  %c1 = spirv.Constant 65535 : i32  // 0x0000 ffff
+  %c2 = spirv.Constant 17 : i32
+
+  // CHECK: %[[RET:.*]] = spirv.Constant -131072
+  // 0x0000 ffff << 17 -> 0xfffe 0000
+  %0 = spirv.ShiftLeftLogical %c1, %c2 : i32, i32
+
+  // CHECK: return %[[RET]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_lsl
+func.func @const_fold_vector_lsl() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[1, -1, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[31, 16, 13]> : vector<3xi32>
+
+  // CHECK: %[[RET:.*]] = spirv.Constant dense<[-2147483648, -65536, 1040384]>
+  %0 = spirv.ShiftLeftLogical %c1, %c2 : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[RET]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightArithmetic
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @asr_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @asr_x_0(%arg0 : i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %c0 = spirv.Constant 0 : i32
+  %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+  %0 = spirv.ShiftRightArithmetic %arg0, %c0 : i32, i32
+  %1 = spirv.ShiftRightArithmetic %arg1, %cv0 : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @asr_shift_overflow
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @asr_shift_overflow(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C32:.*]] = spirv.Constant 32
+  // CHECK-DAG: %[[CV:.*]] = spirv.Constant dense<[6, 18, 128]>
+  %c32 = spirv.Constant 32 : i32
+  %cv = spirv.Constant dense<[6, 18, 128]> : vector<3xi32>
+
+  // CHECK: %{{.*}} = spirv.ShiftRightArithmetic %[[ARG0]], %[[C32]]
+  // CHECK: %{{.*}} = spirv.ShiftRightArithmetic %[[ARG1]], %[[CV]]
+  %0 = spirv.ShiftRightArithmetic %arg0, %c32 : i32, i32
+  %1 = spirv.ShiftRightArithmetic %arg1, %cv : vector<3xi32>, vector<3xi32>
+
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_asr
+func.func @const_fold_scalar_asr() -> i32 {
+  %c1 = spirv.Constant -131072 : i32  // 0xfffe 0000
+  %c2 = spirv.Constant 17 : i32
+  // 0xfffe 0000 ashr 17 -> 0xffff ffff
+  // CHECK: %[[RET:.*]] = spirv.Constant -1
+  %0 = spirv.ShiftRightArithmetic %c1, %c2 : i32, i32
+
+  // CHECK: return %[[RET]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_asr
+func.func @const_fold_vector_asr() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[-2147483648, 239847, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[31, 16, 13]> : vector<3xi32>
+
+  // CHECK: %[[RET:.*]] = spirv.Constant dense<[-1, 3, 0]>
+  %0 = spirv.ShiftRightArithmetic %c1, %c2 : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[RET]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightLogical
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @lsr_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @lsr_x_0(%arg0 : i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %c0 = spirv.Constant 0 : i32
+  %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+  %0 = spirv.ShiftRightLogical %arg0, %c0 : i32, i32
+  %1 = spirv.ShiftRightLogical %arg1, %cv0 : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @lsr_shift_overflow
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @lsr_shift_overflow(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C32:.*]] = spirv.Constant 32
+  // CHECK-DAG: %[[CV:.*]] = spirv.Constant dense<[6, 18, 128]>
+  %c32 = spirv.Constant 32 : i32
+  %cv = spirv.Constant dense<[6, 18, 128]> : vector<3xi32>
+
+  // CHECK: %{{.*}} = spirv.ShiftRightLogical %[[ARG0]], %[[C32]]
+  // CHECK: %{{.*}} = spirv.ShiftRightLogical %[[ARG1]], %[[CV]]
+  %0 = spirv.ShiftRightLogical %arg0, %c32 : i32, i32
+  %1 = spirv.ShiftRightLogical %arg1, %cv : vector<3xi32>, vector<3xi32>
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_lsr
+func.func @const_fold_scalar_lsr() -> i32 {
+  %c1 = spirv.Constant -131072 : i32  // 0xfffe 0000
+  %c2 = spirv.Constant 17 : i32
+
+  // 0xfffe 0000 >> 17 -> 0x0000 7fff
+  // CHECK: %[[RET:.*]] = spirv.Constant 32767
+  %0 = spirv.ShiftRightLogical %c1, %c2 : i32, i32
+
+  // CHECK: return %[[RET]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_lsr
+func.func @const_fold_vector_lsr() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[-2147483648, -1, -127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[31, 16, 13]> : vector<3xi32>
+
+  // CHECK: %[[RET:.*]] = spirv.Constant dense<[1, 65535, 524287]>
+  %0 = spirv.ShiftRightLogical %c1, %c2 : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[RET]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.selection
 //===----------------------------------------------------------------------===//