diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td index e48a56f0625d3..3ee239d6e1e3e 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td @@ -800,6 +800,8 @@ def SPIRV_SelectOp : SPIRV_Op<"Select", // These ops require dynamic availability specification based on operand and // result types. bit autogenAvailability = 0; + + let hasFolder = 1; } // ----- diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp index 4c62289a1e945..ff4bace9a4d88 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -797,6 +797,49 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) { return Attribute(); } +//===----------------------------------------------------------------------===// +// spirv.SelectOp +//===----------------------------------------------------------------------===// + +OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) { + // spirv.Select _ x x -> x + Value trueVals = getTrueValue(); + Value falseVals = getFalseValue(); + if (trueVals == falseVals) + return trueVals; + + ArrayRef operands = adaptor.getOperands(); + + // spirv.Select true x y -> x + // spirv.Select false x y -> y + if (auto boolAttr = getScalarOrSplatBoolAttr(operands[0])) + return *boolAttr ? trueVals : falseVals; + + // Check that all the operands are constant + if (!operands[0] || !operands[1] || !operands[2]) + return Attribute(); + + // Note: getScalarOrSplatBoolAttr will always return a boolAttr if we are in + // the scalar case. Hence, we are only required to consider the case of + // DenseElementsAttr in foldSelectOp. + auto condAttrs = dyn_cast(operands[0]); + auto trueAttrs = dyn_cast(operands[1]); + auto falseAttrs = dyn_cast(operands[2]); + if (!condAttrs || !trueAttrs || !falseAttrs) + return Attribute(); + + auto elementResults = llvm::to_vector<4>(trueAttrs.getValues()); + auto iters = llvm::zip_equal(elementResults, condAttrs.getValues(), + falseAttrs.getValues()); + for (auto [result, cond, falseRes] : iters) { + if (!cond.getValue()) + result = falseRes; + } + + auto resultType = trueAttrs.getType(); + return DenseElementsAttr::get(cast(resultType), elementResults); +} + //===----------------------------------------------------------------------===// // spirv.IEqualOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir index 9fe1e532dfc77..31da59dcdc726 100644 --- a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir @@ -43,18 +43,18 @@ spirv.func @composite_insert_vector(%arg0: vector<3xf32>, %arg1: f32) "None" { //===----------------------------------------------------------------------===// // CHECK-LABEL: @select_scalar -spirv.func @select_scalar(%arg0: i1, %arg1: vector<3xi32>, %arg2: f32) "None" { +spirv.func @select_scalar(%arg0: i1, %arg1: vector<3xi32>, %arg2: vector<3xi32>, %arg3: f32, %arg4: f32) "None" { // CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : i1, vector<3xi32> - %0 = spirv.Select %arg0, %arg1, %arg1 : i1, vector<3xi32> + %0 = spirv.Select %arg0, %arg1, %arg2 : i1, vector<3xi32> // CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : i1, f32 - %1 = spirv.Select %arg0, %arg2, %arg2 : i1, f32 + %1 = spirv.Select %arg0, %arg3, %arg4 : i1, f32 spirv.Return } // CHECK-LABEL: @select_vector -spirv.func @select_vector(%arg0: vector<2xi1>, %arg1: vector<2xi32>) "None" { +spirv.func @select_vector(%arg0: vector<2xi1>, %arg1: vector<2xi32>, %arg2: vector<2xi32>) "None" { // CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : vector<2xi1>, vector<2xi32> - %0 = spirv.Select %arg0, %arg1, %arg1 : vector<2xi1>, vector<2xi32> + %0 = spirv.Select %arg0, %arg1, %arg2 : vector<2xi1>, vector<2xi32> spirv.Return } diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir index 1cb69891a70ed..de21d114e9fc4 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir @@ -1346,6 +1346,52 @@ func.func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3 // ----- +//===----------------------------------------------------------------------===// +// spirv.Select +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @convert_select_scalar +// CHECK-SAME: %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32 +func.func @convert_select_scalar(%arg1: i32, %arg2: i32) -> (i32, i32) { + %true = spirv.Constant true + %false = spirv.Constant false + %0 = spirv.Select %true, %arg1, %arg2 : i1, i32 + %1 = spirv.Select %false, %arg1, %arg2 : i1, i32 + + // CHECK: return %[[ARG1]], %[[ARG2]] + return %0, %1 : i32, i32 +} + +// CHECK-LABEL: @convert_select_vector +// CHECK-SAME: %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<3xi32> +func.func @convert_select_vector(%arg1: vector<3xi32>, %arg2: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>) { + %true = spirv.Constant dense : vector<3xi1> + %false = spirv.Constant dense : vector<3xi1> + %0 = spirv.Select %true, %arg1, %arg2 : vector<3xi1>, vector<3xi32> + %1 = spirv.Select %false, %arg1, %arg2 : vector<3xi1>, vector<3xi32> + + // CHECK: return %[[ARG1]], %[[ARG2]] + return %0, %1: vector<3xi32>, vector<3xi32> +} + +// CHECK-LABEL: @convert_select_vector_extra +// CHECK-SAME: %[[CONDITIONS:.+]]: vector<2xi1>, %[[ARG1:.+]]: vector<2xi32> +func.func @convert_select_vector_extra(%conditions: vector<2xi1>, %arg1: vector<2xi32>) -> (vector<2xi32>, vector<2xi32>) { + %true_false = spirv.Constant dense<[true, false]> : vector<2xi1> + %cvec_1 = spirv.Constant dense<[42, -132]> : vector<2xi32> + %cvec_2 = spirv.Constant dense<[0, 42]> : vector<2xi32> + + // CHECK: %[[RES:.+]] = spirv.Constant dense<42> + %0 = spirv.Select %true_false, %cvec_1, %cvec_2: vector<2xi1>, vector<2xi32> + + %1 = spirv.Select %conditions, %arg1, %arg1 : vector<2xi1>, vector<2xi32> + + // CHECK: return %[[RES]], %[[ARG1]] + return %0, %1: vector<2xi32>, vector<2xi32> +} + +// ----- + //===----------------------------------------------------------------------===// // spirv.IEqual //===----------------------------------------------------------------------===//