[mlir][spirv] Improve folding of MemRef to SPIRV Lowering #85433
Merged
Conversation
Contributor
Author
Investigate the lowering of MemRef Load/Store ops and implement additional folding of created ops. Aims to improve readability of the generated lowered SPIR-V code. Part of work on llvm#70704.
87c7b2e to
adc3bd2
Compare
Contributor
Author
|
Rebased onto required commit now that it is merged |
Member
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu Author: Finn Plummer (inbelic) Changes: Investigate the lowering of MemRef Load/Store ops and implement additional folding of created ops. Aims to improve readability of the generated lowered SPIR-V code. Part of work on llvm#70704. Patch is 42.53 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/85433.diff 8 Files Affected:
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 0acb2142f3f68a..81b9f55cac80f7 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -50,11 +50,12 @@ static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
assert(targetBits % sourceBits == 0);
Type type = srcIdx.getType();
IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
- auto idx = builder.create<spirv::ConstantOp>(loc, type, idxAttr);
+ auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
- auto srcBitsValue = builder.create<spirv::ConstantOp>(loc, type, srcBitsAttr);
- auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
- return builder.create<spirv::IMulOp>(loc, type, m, srcBitsValue);
+ auto srcBitsValue =
+ builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
+ auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
+ return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
}
/// Returns an adjusted spirv::AccessChainOp. Based on the
@@ -74,11 +75,11 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
Value lastDim = op->getOperand(op.getNumOperands() - 1);
Type type = lastDim.getType();
IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
- auto idx = builder.create<spirv::ConstantOp>(loc, type, attr);
+ auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
auto indices = llvm::to_vector<4>(op.getIndices());
// There are two elements if this is a 1-D tensor.
assert(indices.size() == 2);
- indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
+ indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
Type t = typeConverter.convertType(op.getComponentPtr().getType());
return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
}
@@ -91,7 +92,8 @@ static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
return srcBool;
Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
- return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
+ return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
+ zero);
}
/// Returns the `targetBits`-bit value shifted by the given `offset`, and cast
@@ -111,10 +113,10 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
loc, builder.getIntegerType(targetBits), value);
}
- value = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
+ value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
}
- return builder.create<spirv::ShiftLeftLogicalOp>(loc, value.getType(), value,
- offset);
+ return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
+ value, offset);
}
/// Returns true if the allocations of memref `type` generated from `allocOp`
@@ -165,7 +167,7 @@ static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
return srcInt;
auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
- return builder.create<spirv::IEqualOp>(loc, srcInt, one);
+ return builder.createOrFold<spirv::IEqualOp>(loc, srcInt, one);
}
//===----------------------------------------------------------------------===//
@@ -597,13 +599,14 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
// ____XXXX________ -> ____________XXXX
Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
- Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
+ Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
loc, spvLoadOp.getType(), spvLoadOp, offset);
// Apply the mask to extract corresponding bits.
- Value mask = rewriter.create<spirv::ConstantOp>(
+ Value mask = rewriter.createOrFold<spirv::ConstantOp>(
loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
- result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
+ result =
+ rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
// Apply sign extension on the loading value unconditionally. The signedness
// semantic is carried in the operator itself, we relies other pattern to
@@ -611,11 +614,11 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
IntegerAttr shiftValueAttr =
rewriter.getIntegerAttr(dstType, dstBits - srcBits);
Value shiftValue =
- rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
- result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
- shiftValue);
- result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
- shiftValue);
+ rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
+ result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
+ result, shiftValue);
+ result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
+ loc, dstType, result, shiftValue);
rewriter.replaceOp(loadOp, result);
@@ -744,11 +747,12 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
// Create a mask to clear the destination. E.g., if it is the second i8 in
// i32, 0xFFFF00FF is created.
- Value mask = rewriter.create<spirv::ConstantOp>(
+ Value mask = rewriter.createOrFold<spirv::ConstantOp>(
loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
- Value clearBitsMask =
- rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
- clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
+ Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
+ loc, dstType, mask, offset);
+ clearBitsMask =
+ rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
@@ -910,7 +914,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
int64_t attrVal = cast<IntegerAttr>(offset.get<Attribute>()).getInt();
Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
- return rewriter.create<spirv::ConstantOp>(loc, intType, attr);
+ return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
}();
rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2b79c8022b8e85..4072608dc8f873 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -991,15 +991,16 @@ Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
// broken down into progressive small steps so we can have intermediate steps
// using other dialects. At the moment SPIR-V is the final sink.
- Value linearizedIndex = builder.create<spirv::ConstantOp>(
+ Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
loc, integerType, IntegerAttr::get(integerType, offset));
for (const auto &index : llvm::enumerate(indices)) {
- Value strideVal = builder.create<spirv::ConstantOp>(
+ Value strideVal = builder.createOrFold<spirv::ConstantOp>(
loc, integerType,
IntegerAttr::get(integerType, strides[index.index()]));
- Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
+ Value update =
+ builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
linearizedIndex =
- builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
+ builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
}
return linearizedIndex;
}
diff --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
index fa12da8ef9d4ec..4339799ccd5eaf 100644
--- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
@@ -60,13 +60,9 @@ module attributes {
// CHECK: %[[INDEX2:.*]] = spirv.IAdd %[[ARG4]], %[[LOCALINVOCATIONIDX]]
%13 = arith.addi %arg4, %3 : index
// CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
- // CHECK: %[[OFFSET1_0:.*]] = spirv.Constant 0 : i32
// CHECK: %[[STRIDE1_1:.*]] = spirv.Constant 4 : i32
- // CHECK: %[[UPDATE1_1:.*]] = spirv.IMul %[[STRIDE1_1]], %[[INDEX1]] : i32
- // CHECK: %[[OFFSET1_1:.*]] = spirv.IAdd %[[OFFSET1_0]], %[[UPDATE1_1]] : i32
- // CHECK: %[[STRIDE1_2:.*]] = spirv.Constant 1 : i32
- // CHECK: %[[UPDATE1_2:.*]] = spirv.IMul %[[STRIDE1_2]], %[[INDEX2]] : i32
- // CHECK: %[[OFFSET1_2:.*]] = spirv.IAdd %[[OFFSET1_1]], %[[UPDATE1_2]] : i32
+ // CHECK: %[[UPDATE1_1:.*]] = spirv.IMul %[[INDEX1]], %[[STRIDE1_1]] : i32
+ // CHECK: %[[OFFSET1_2:.*]] = spirv.IAdd %[[INDEX2]], %[[UPDATE1_1]] : i32
// CHECK: %[[PTR1:.*]] = spirv.AccessChain %[[ARG0]]{{\[}}%[[ZERO]], %[[OFFSET1_2]]{{\]}}
// CHECK-NEXT: %[[VAL1:.*]] = spirv.Load "StorageBuffer" %[[PTR1]]
%14 = memref.load %arg0[%12, %13] : memref<12x4xf32, #spirv.storage_class<StorageBuffer>>
diff --git a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
index 470c8531e2e0fb..52ed14e8cce233 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
@@ -12,16 +12,10 @@ module attributes {
// CHECK-LABEL: @load_i1
func.func @load_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>) -> i1 {
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
- // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+ // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]]
// CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]]
- // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
- // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
- // CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
// CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
- // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+ // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
// CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
// CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
// CHECK: %[[T4:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
@@ -37,32 +31,20 @@ func.func @load_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>) -> i1
// INDEX64-LABEL: @load_i8
func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 {
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
- // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+ // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]]
// CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]]
- // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
- // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
- // CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
// CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
- // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+ // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
// CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
// CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
// CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
// CHECK: builtin.unrealized_conversion_cast %[[SR]]
// INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
- // INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
- // INDEX64: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
- // INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] : {{.+}}, i64, i64
+ // INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
// INDEX64: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] : i32
- // INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
- // INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
- // INDEX64: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
- // INDEX64: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i64
// INDEX64: %[[MASK:.+]] = spirv.Constant 255 : i32
- // INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+ // INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
// INDEX64: %[[T2:.+]] = spirv.Constant 24 : i32
// INDEX64: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
// INDEX64: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
@@ -76,15 +58,12 @@ func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8
func.func @load_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>, %index : index) -> i16 {
// CHECK: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
- // CHECK: %[[UPDATE:.+]] = spirv.IMul %[[ONE]], %[[ARG1_CAST]] : i32
- // CHECK: %[[FLAT_IDX:.+]] = spirv.IAdd %[[ZERO]], %[[UPDATE]] : i32
// CHECK: %[[TWO:.+]] = spirv.Constant 2 : i32
- // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[FLAT_IDX]], %[[TWO]] : i32
+ // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ARG1_CAST]], %[[TWO]] : i32
// CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
// CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]]
// CHECK: %[[SIXTEEN:.+]] = spirv.Constant 16 : i32
- // CHECK: %[[IDX:.+]] = spirv.UMod %[[FLAT_IDX]], %[[TWO]] : i32
+ // CHECK: %[[IDX:.+]] = spirv.UMod %[[ARG1_CAST]], %[[TWO]] : i32
// CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[SIXTEEN]] : i32
// CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
// CHECK: %[[MASK:.+]] = spirv.Constant 65535 : i32
@@ -110,20 +89,12 @@ func.func @load_f32(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
func.func @store_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>, %value: i1) {
// CHECK: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
- // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
- // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
- // CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32
- // CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
- // CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+ // CHECK: %[[MASK:.+]] = spirv.Constant -256 : i32
// CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
// CHECK: %[[CASTED_ARG1:.+]] = spirv.Select %[[ARG1]], %[[ONE]], %[[ZERO]] : i1, i32
- // CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CASTED_ARG1]], %[[OFFSET]] : i32, i32
- // CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
+ // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]]
// CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
- // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+ // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CASTED_ARG1]]
memref.store %value, %arg0[] : memref<i1, #spirv.storage_class<StorageBuffer>>
return
}
@@ -136,36 +107,22 @@ func.func @store_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>, %val
// CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
// CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
- // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
- // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
// CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32
- // CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
- // CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+ // CHECK: %[[MASK2:.+]] = spirv.Constant -256 : i32
// CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
- // CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
- // CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
- // CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
- // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+ // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]]
+ // CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]]
+ // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CLAMPED_VAL]]
// INDEX64-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
// INDEX64-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
// INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
- // INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
- // INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
- // INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
- // INDEX64: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
// INDEX64: %[[MASK1:.+]] = spirv.Constant 255 : i32
- // INDEX64: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i64
- // INDEX64: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+ // INDEX64: %[[MASK2:.+]] = spirv.Constant -256 : i32
// INDEX64: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
- // INDEX64: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i64
- // INDEX64: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
- // INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] : {{.+}}, i64, i64
- // INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
- // INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+ // INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
+ // INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]]
+ // INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CLAMPED_VAL]]
memref.store %value, %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
return
}
@@ -177,19 +...
[truncated]
|
kuhar
approved these changes
Mar 21, 2024
kuhar
left a comment
Member
There was a problem hiding this comment.
Oh wow, it got so much more concise now. Thanks!
chencha3
pushed a commit
to chencha3/llvm-project
that referenced
this pull request
Mar 23, 2024
Investigate the lowering of MemRef Load/Store ops and implement additional folding of created ops. Aims to improve readability of the generated lowered SPIR-V code. Part of work on llvm#70704.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit. This suggestion is invalid because no changes were made to the code. Suggestions cannot be applied while the pull request is closed. Suggestions cannot be applied while viewing a subset of changes. Only one suggestion per line can be applied in a batch. Add this suggestion to a batch that can be applied as a single commit. Applying suggestions on deleted lines is not supported. You must change the existing code in this line in order to create a valid suggestion. Outdated suggestions cannot be applied. This suggestion has been applied or marked resolved. Suggestions cannot be applied from pending reviews. Suggestions cannot be applied on multi-line comments. Suggestions cannot be applied while the pull request is queued to merge. Suggestion cannot be applied right now. Please check back later.
Investigate the lowering of MemRef Load/Store ops and implement additional folding of created ops
Aims to improve readability of generated lowered SPIR-V code.
Part of work #70704