diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index d9144d0c5e228..5908db3c04817 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -441,6 +441,63 @@ class ConstantScalarAndVectorPattern } }; +/// Converts `spirv.SpecConstant` to an `llvm.mlir.global` constant with a +/// private linkage, using the spec constant's default value as the initializer. +/// When lowering to LLVM there is no notion of specialization, so the default +/// value is used unconditionally. +class SpecConstantPattern + : public SPIRVToLLVMConversion { +public: + using SPIRVToLLVMConversion::SPIRVToLLVMConversion; + + LogicalResult + matchAndRewrite(spirv::SpecConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto defaultValue = cast(op.getDefaultValue()); + auto srcType = defaultValue.getType(); + auto dstType = getTypeConverter()->convertType(srcType); + if (!dstType) + return rewriter.notifyMatchFailure(op, "type conversion failed"); + + // Handle signed/unsigned integers: strip the sign by converting to a + // signless integer type (analogous to ConstantScalarAndVectorPattern). + Attribute initializer = defaultValue; + if (isSignedIntegerOrVector(srcType) || + isUnsignedIntegerOrVector(srcType)) { + auto signlessType = rewriter.getIntegerType(getBitWidth(srcType)); + dstType = getTypeConverter()->convertType(signlessType); + initializer = rewriter.getIntegerAttr( + signlessType, cast(defaultValue).getValue()); + } + + rewriter.replaceOpWithNewOp( + op, dstType, /*isConstant=*/true, LLVM::Linkage::Private, + op.getSymName(), initializer); + return success(); + } +}; + +/// Converts `spirv.mlir.referenceof` (referencing a `spirv.SpecConstant`) to a +/// load from the corresponding `llvm.mlir.global`. +class ReferenceOfPattern : public SPIRVToLLVMConversion { +public: + using SPIRVToLLVMConversion::SPIRVToLLVMConversion; + + LogicalResult + matchAndRewrite(spirv::ReferenceOfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto dstType = getTypeConverter()->convertType(op.getType()); + if (!dstType) + return rewriter.notifyMatchFailure(op, "type conversion failed"); + + auto ptrType = LLVM::LLVMPointerType::get(op.getContext()); + Value addr = LLVM::AddressOfOp::create(rewriter, op.getLoc(), ptrType, + op.getSpecConst()); + rewriter.replaceOpWithNewOp(op, dstType, addr); + return success(); + } +}; + class BitFieldSExtractPattern : public SPIRVToLLVMConversion { public: @@ -1850,8 +1907,8 @@ void mlir::populateSPIRVToLLVMConversionPatterns( IComparePattern, IComparePattern, - // Constant op - ConstantScalarAndVectorPattern, + // Constant ops + ConstantScalarAndVectorPattern, SpecConstantPattern, ReferenceOfPattern, // Control Flow ops BranchConversionPattern, BranchConditionalConversionPattern, diff --git a/mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir index 2d74022b34406..7066b9fe1f388 100644 --- a/mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir @@ -59,3 +59,31 @@ spirv.func @float_constant_vector() "None" { %0 = spirv.Constant dense<[2.000000e+00, 3.000000e+00]> : vector<2xf32> spirv.Return } + +//===----------------------------------------------------------------------===// +// spirv.SpecConstant and spirv.mlir.referenceof +//===----------------------------------------------------------------------===// + +// CHECK: llvm.mlir.global private constant @sc_int(-5 : i32) {{.*}} : i32 +// CHECK: llvm.mlir.global private constant @sc_signed(-5 : i32) {{.*}} : i32 +// CHECK: llvm.mlir.global private constant @sc_unsigned(10 : i16) {{.*}} : i16 +// CHECK: llvm.mlir.global private constant @sc_float(3.140000e+00 : f32) {{.*}} : f32 +// CHECK: llvm.mlir.global private constant @sc_bool(true) {{.*}} : i1 +spirv.module Logical GLSL450 { + spirv.SpecConstant @sc_int = -5 : i32 + spirv.SpecConstant @sc_signed = -5 : si32 + spirv.SpecConstant @sc_unsigned = 10 : ui16 + spirv.SpecConstant @sc_float = 3.14 : f32 + spirv.SpecConstant @sc_bool = true +} + +// CHECK-LABEL: @use_spec_consts +spirv.module Logical GLSL450 { + spirv.SpecConstant @sc = 42 : i32 + spirv.func @use_spec_consts() -> i32 "None" { + // CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @sc : !llvm.ptr + // CHECK: llvm.load %[[ADDR]] : !llvm.ptr -> i32 + %0 = spirv.mlir.referenceof @sc : i32 + spirv.ReturnValue %0 : i32 + } +}