[MLIR][SPIRV] Add conversion for spirv.SpecConstant and spirv.mlir.referenceof to LLVM#188746
[MLIR][SPIRV] Add conversion for spirv.SpecConstant and spirv.mlir.referenceof to LLVM#188746joker-eph wants to merge 1 commit into
Conversation
…ferenceof to LLVM `-convert-spirv-to-llvm` failed with "failed to legalize operation 'spirv.SpecConstant'" because no conversion pattern existed for `spirv.SpecConstant` or `spirv.mlir.referenceof`. Add two new conversion patterns: - `SpecConstantPattern`: converts `spirv.SpecConstant` to an `llvm.mlir.global` constant (private linkage) using the spec constant's default value as the initializer. Signed/unsigned integer types are converted to signless integers, consistent with `ConstantScalarAndVectorPattern`. - `ReferenceOfPattern`: converts `spirv.mlir.referenceof` to an `llvm.mlir.addressof` + `llvm.load` pair, loading the value from the corresponding global constant. Since LLVM IR has no notion of specialization constants, the default value is used unconditionally when lowering. Fixes llvm#159485 Assisted-by: Claude Code
|
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Mehdi Amini (joker-eph) Changes
Add two new conversion patterns:
Since LLVM IR has no notion of specialization constants, the default value is used unconditionally when lowering. Fixes #159485 Assisted-by: Claude Code Full diff: https://github.com/llvm/llvm-project/pull/188746.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index d9144d0c5e228..5908db3c04817 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -441,6 +441,63 @@ class ConstantScalarAndVectorPattern
}
};
+/// Converts `spirv.SpecConstant` to an `llvm.mlir.global` constant with a
+/// private linkage, using the spec constant's default value as the initializer.
+/// When lowering to LLVM there is no notion of specialization, so the default
+/// value is used unconditionally.
+class SpecConstantPattern
+ : public SPIRVToLLVMConversion<spirv::SpecConstantOp> {
+public:
+ using SPIRVToLLVMConversion<spirv::SpecConstantOp>::SPIRVToLLVMConversion;
+
+ LogicalResult
+ matchAndRewrite(spirv::SpecConstantOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto defaultValue = cast<TypedAttr>(op.getDefaultValue());
+ auto srcType = defaultValue.getType();
+ auto dstType = getTypeConverter()->convertType(srcType);
+ if (!dstType)
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+ // Handle signed/unsigned integers: strip the sign by converting to a
+ // signless integer type (analogous to ConstantScalarAndVectorPattern).
+ Attribute initializer = defaultValue;
+ if (isSignedIntegerOrVector(srcType) ||
+ isUnsignedIntegerOrVector(srcType)) {
+ auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
+ dstType = getTypeConverter()->convertType(signlessType);
+ initializer = rewriter.getIntegerAttr(
+ signlessType, cast<IntegerAttr>(defaultValue).getValue());
+ }
+
+ rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
+ op, dstType, /*isConstant=*/true, LLVM::Linkage::Private,
+ op.getSymName(), initializer);
+ return success();
+ }
+};
+
+/// Converts `spirv.mlir.referenceof` (referencing a `spirv.SpecConstant`) to a
+/// load from the corresponding `llvm.mlir.global`.
+class ReferenceOfPattern : public SPIRVToLLVMConversion<spirv::ReferenceOfOp> {
+public:
+ using SPIRVToLLVMConversion<spirv::ReferenceOfOp>::SPIRVToLLVMConversion;
+
+ LogicalResult
+ matchAndRewrite(spirv::ReferenceOfOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto dstType = getTypeConverter()->convertType(op.getType());
+ if (!dstType)
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+ auto ptrType = LLVM::LLVMPointerType::get(op.getContext());
+ Value addr = LLVM::AddressOfOp::create(rewriter, op.getLoc(), ptrType,
+ op.getSpecConst());
+ rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dstType, addr);
+ return success();
+ }
+};
+
class BitFieldSExtractPattern
: public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
public:
@@ -1850,8 +1907,8 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
- // Constant op
- ConstantScalarAndVectorPattern,
+ // Constant ops
+ ConstantScalarAndVectorPattern, SpecConstantPattern, ReferenceOfPattern,
// Control Flow ops
BranchConversionPattern, BranchConditionalConversionPattern,
diff --git a/mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir
index 2d74022b34406..7066b9fe1f388 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir
@@ -59,3 +59,31 @@ spirv.func @float_constant_vector() "None" {
%0 = spirv.Constant dense<[2.000000e+00, 3.000000e+00]> : vector<2xf32>
spirv.Return
}
+
+//===----------------------------------------------------------------------===//
+// spirv.SpecConstant and spirv.mlir.referenceof
+//===----------------------------------------------------------------------===//
+
+// CHECK: llvm.mlir.global private constant @sc_int(-5 : i32) {{.*}} : i32
+// CHECK: llvm.mlir.global private constant @sc_signed(-5 : i32) {{.*}} : i32
+// CHECK: llvm.mlir.global private constant @sc_unsigned(10 : i16) {{.*}} : i16
+// CHECK: llvm.mlir.global private constant @sc_float(3.140000e+00 : f32) {{.*}} : f32
+// CHECK: llvm.mlir.global private constant @sc_bool(true) {{.*}} : i1
+spirv.module Logical GLSL450 {
+ spirv.SpecConstant @sc_int = -5 : i32
+ spirv.SpecConstant @sc_signed = -5 : si32
+ spirv.SpecConstant @sc_unsigned = 10 : ui16
+ spirv.SpecConstant @sc_float = 3.14 : f32
+ spirv.SpecConstant @sc_bool = true
+}
+
+// CHECK-LABEL: @use_spec_consts
+spirv.module Logical GLSL450 {
+ spirv.SpecConstant @sc = 42 : i32
+ spirv.func @use_spec_consts() -> i32 "None" {
+ // CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @sc : !llvm.ptr
+ // CHECK: llvm.load %[[ADDR]] : !llvm.ptr -> i32
+ %0 = spirv.mlir.referenceof @sc : i32
+ spirv.ReturnValue %0 : i32
+ }
+}
|
IgWod
left a comment
There was a problem hiding this comment.
I wonder whether we should translate them into calls to __spirv_SpecConstant which seems to be a LLVM IR way of encoding SpecConstants: https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/SPIRV/constant/spec-constant.ll ?
|
@IgWod : yes this would probably be better (to more faithfully target the LLVM SPIRV backend), but for just general lowering to LLVM I would think this lowering would work? The problem of the intrinsics you mention is two-fold I think:
|
IgWod
left a comment
There was a problem hiding this comment.
@IgWod : yes this would probably be better (to more faithfully target the LLVM SPIRV backend), but for just general lowering to LLVM I would think this lowering would work?
Yes, that's reasonable, thanks! I'm happy with the proposed conversion, but I'd prefer to wait for @kuhar to approve this one.
I also added few suggestions for autos, but otherwise LGTM.
| matchAndRewrite(spirv::SpecConstantOp op, OpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| auto defaultValue = cast<TypedAttr>(op.getDefaultValue()); | ||
| auto srcType = defaultValue.getType(); |
There was a problem hiding this comment.
| auto srcType = defaultValue.getType(); | |
| Type srcType = defaultValue.getType(); |
| ConversionPatternRewriter &rewriter) const override { | ||
| auto defaultValue = cast<TypedAttr>(op.getDefaultValue()); | ||
| auto srcType = defaultValue.getType(); | ||
| auto dstType = getTypeConverter()->convertType(srcType); |
There was a problem hiding this comment.
| auto dstType = getTypeConverter()->convertType(srcType); | |
| Type dstType = getTypeConverter()->convertType(srcType); |
| LogicalResult | ||
| matchAndRewrite(spirv::ReferenceOfOp op, OpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| auto dstType = getTypeConverter()->convertType(op.getType()); |
There was a problem hiding this comment.
| auto dstType = getTypeConverter()->convertType(op.getType()); | |
| Type dstType = getTypeConverter()->convertType(op.getType()); |
-convert-spirv-to-llvmfailed with "failed to legalize operation 'spirv.SpecConstant'" because no conversion pattern existed forspirv.SpecConstantorspirv.mlir.referenceof.Add two new conversion patterns:
SpecConstantPattern: convertsspirv.SpecConstantto anllvm.mlir.globalconstant (private linkage) using the spec constant's default value as the initializer. Signed/unsigned integer types are converted to signless integers, consistent withConstantScalarAndVectorPattern.ReferenceOfPattern: convertsspirv.mlir.referenceofto anllvm.mlir.addressof+llvm.loadpair, loading the value from the corresponding global constant.Since LLVM IR has no notion of specialization constants, the default value is used unconditionally when lowering.
Fixes #159485
Assisted-by: Claude Code