[MLIR][EmitC] Fix crash in SwitchOp::getEntrySuccessorRegions on unsigned integer type#188546
Conversation
|
@llvm/pr-subscribers-mlir-emitc @llvm/pr-subscribers-mlir Author: Mehdi Amini (joker-eph) ChangesSwitchOp::getEntrySuccessorRegions and getRegionInvocationBounds called IntegerAttr::getInt() to retrieve the constant switch argument, but getInt() asserts that the attribute type must be a signless integer or index. For unsigned integer types (e.g. ui32), this assertion fired and crashed the process. Fix by selecting the appropriate accessor based on the attribute type: getInt() for signless/index, getSInt() for signed, and getUInt() (cast to int64_t) for unsigned integer types. Unknown types fall back to the conservative "all regions possible" path. The same fix is applied to getRegionInvocationBounds, which had an identical call to getInt(). Fixes #187973 Assisted-by: Claude Code Full diff: https://github.com/llvm/llvm-project/pull/188546.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 6979f34c1e047..09fd04c5a2d99 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -1563,10 +1563,27 @@ void SwitchOp::getEntrySuccessorRegions(
return;
}
+ // Compute the integer value of the argument. Case labels are stored as
+ // int64_t; compute the arg value using the appropriate accessor to avoid
+ // asserting on signed or unsigned integer types.
+ int64_t argValue;
+ Type argType = arg.getType();
+ if (argType.isIndex() || argType.isSignlessInteger())
+ argValue = arg.getInt();
+ else if (argType.isSignedInteger())
+ argValue = arg.getSInt();
+ else if (argType.isUnsignedInteger())
+ argValue = static_cast<int64_t>(arg.getUInt());
+ else {
+ // Unknown type; conservatively treat all regions as possible.
+ llvm::append_range(successors, getRegions());
+ return;
+ }
+
// Otherwise, try to find a case with a matching value. If not, the
// default region is the only successor.
for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
- if (caseValue == arg.getInt()) {
+ if (caseValue == argValue) {
successors.emplace_back(&caseRegion);
return;
}
@@ -1583,8 +1600,24 @@ void SwitchOp::getRegionInvocationBounds(
return;
}
+ // Compute the integer value of the operand using the appropriate accessor.
+ Type operandType = operandValue.getType();
+ std::optional<int64_t> maybeIntValue;
+ if (operandType.isIndex() || operandType.isSignlessInteger())
+ maybeIntValue = operandValue.getInt();
+ else if (operandType.isSignedInteger())
+ maybeIntValue = operandValue.getSInt();
+ else if (operandType.isUnsignedInteger())
+ maybeIntValue = static_cast<int64_t>(operandValue.getUInt());
+
+ if (!maybeIntValue) {
+ // Unknown type; conservatively treat all regions as possible.
+ bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
+ return;
+ }
+
unsigned liveIndex = getNumRegions() - 1;
- const auto *iteratorToInt = llvm::find(getCases(), operandValue.getInt());
+ const auto *iteratorToInt = llvm::find(getCases(), *maybeIntValue);
liveIndex = iteratorToInt != getCases().end()
? std::distance(getCases().begin(), iteratorToInt)
diff --git a/mlir/test/Transforms/sccp.mlir b/mlir/test/Transforms/sccp.mlir
index c78c8594c0ba5..949e5ad186168 100644
--- a/mlir/test/Transforms/sccp.mlir
+++ b/mlir/test/Transforms/sccp.mlir
@@ -255,3 +255,24 @@ func.func @no_crash_with_different_source_type() {
%1 = vector.broadcast %0 : i64 to vector<128xi64>
llvm.return
}
+
+// -----
+
+// Regression test for https://github.com/llvm/llvm-project/issues/187973
+// SwitchOp::getEntrySuccessorRegions must not call IntegerAttr::getInt() on
+// an unsigned integer type — that function asserts signless/index only.
+
+// CHECK-LABEL: no_crash_emitc_switch_unsigned_condition
+func.func @no_crash_emitc_switch_unsigned_condition() {
+ // CHECK: emitc.constant
+ %0 = "emitc.constant"() {value = 1 : ui32} : () -> ui32
+ // CHECK: emitc.switch
+ emitc.switch %0 : ui32
+ case 2 {
+ emitc.yield
+ }
+ default {
+ emitc.yield
+ }
+ return
+}
|
| if (operandType.isIndex() || operandType.isSignlessInteger()) | ||
| maybeIntValue = operandValue.getInt(); | ||
| else if (operandType.isSignedInteger()) | ||
| maybeIntValue = operandValue.getSInt(); | ||
| else if (operandType.isUnsignedInteger()) | ||
| maybeIntValue = static_cast<int64_t>(operandValue.getUInt()); |
There was a problem hiding this comment.
Can the two instances of this logic be outlined into a function?
…gned integer type SwitchOp::getEntrySuccessorRegions and getRegionInvocationBounds called IntegerAttr::getInt() to retrieve the constant switch argument, but getInt() asserts that the attribute type must be a signless integer or index. For unsigned integer types (e.g. ui32), this assertion fired and crashed the process. Fix by selecting the appropriate accessor based on the attribute type: getInt() for signless/index, getSInt() for signed, and getUInt() (cast to int64_t) for unsigned integer types. Unknown types fall back to the conservative "all regions possible" path. The same fix is applied to getRegionInvocationBounds, which had an identical call to getInt(). Fixes llvm#187973 Assisted-by: Claude Code
21722f2 to
ededfe2
Compare
…gned integer type (llvm#188546) SwitchOp::getEntrySuccessorRegions and getRegionInvocationBounds called IntegerAttr::getInt() to retrieve the constant switch argument, but getInt() asserts that the attribute type must be a signless integer or index. For unsigned integer types (e.g. ui32), this assertion fired and crashed the process. Fix by selecting the appropriate accessor based on the attribute type: getInt() for signless/index, getSInt() for signed, and getUInt() (cast to int64_t) for unsigned integer types. Unknown types fall back to the conservative "all regions possible" path. The same fix is applied to getRegionInvocationBounds, which had an identical call to getInt(). Fixes llvm#187973 Assisted-by: Claude Code
…gned integer type (llvm#188546) SwitchOp::getEntrySuccessorRegions and getRegionInvocationBounds called IntegerAttr::getInt() to retrieve the constant switch argument, but getInt() asserts that the attribute type must be a signless integer or index. For unsigned integer types (e.g. ui32), this assertion fired and crashed the process. Fix by selecting the appropriate accessor based on the attribute type: getInt() for signless/index, getSInt() for signed, and getUInt() (cast to int64_t) for unsigned integer types. Unknown types fall back to the conservative "all regions possible" path. The same fix is applied to getRegionInvocationBounds, which had an identical call to getInt(). Fixes llvm#187973 Assisted-by: Claude Code
SwitchOp::getEntrySuccessorRegions and getRegionInvocationBounds called IntegerAttr::getInt() to retrieve the constant switch argument, but getInt() asserts that the attribute type must be a signless integer or index. For unsigned integer types (e.g. ui32), this assertion fired and crashed the process.
Fix by selecting the appropriate accessor based on the attribute type: getInt() for signless/index, getSInt() for signed, and getUInt() (cast to int64_t) for unsigned integer types. Unknown types fall back to the conservative "all regions possible" path.
The same fix is applied to getRegionInvocationBounds, which had an identical call to getInt().
Fixes #187973
Assisted-by: Claude Code