diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 6979f34c1e047..8def84fc49378 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -1551,6 +1551,19 @@ void SwitchOp::getSuccessorRegions( llvm::append_range(successors, getRegions()); } +/// Returns the int64_t value of an IntegerAttr regardless of whether its type +/// is signless, signed, or unsigned. Returns std::nullopt for unknown types. +static std::optional getIntAttrValue(IntegerAttr attr) { + Type type = attr.getType(); + if (type.isIndex() || type.isSignlessInteger()) + return attr.getInt(); + if (type.isSignedInteger()) + return attr.getSInt(); + if (type.isUnsignedInteger()) + return static_cast(attr.getUInt()); + return std::nullopt; +} + void SwitchOp::getEntrySuccessorRegions( ArrayRef operands, SmallVectorImpl &successors) { @@ -1563,10 +1576,17 @@ void SwitchOp::getEntrySuccessorRegions( return; } + std::optional argValue = getIntAttrValue(arg); + if (!argValue) { + // Unknown type; conservatively treat all regions as possible. + llvm::append_range(successors, getRegions()); + return; + } + // Otherwise, try to find a case with a matching value. If not, the // default region is the only successor. for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) { - if (caseValue == arg.getInt()) { + if (caseValue == *argValue) { successors.emplace_back(&caseRegion); return; } @@ -1583,8 +1603,15 @@ void SwitchOp::getRegionInvocationBounds( return; } + std::optional maybeIntValue = getIntAttrValue(operandValue); + if (!maybeIntValue) { + // Unknown type; conservatively treat all regions as possible. + bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1)); + return; + } + unsigned liveIndex = getNumRegions() - 1; - const auto *iteratorToInt = llvm::find(getCases(), operandValue.getInt()); + const auto *iteratorToInt = llvm::find(getCases(), *maybeIntValue); liveIndex = iteratorToInt != getCases().end() ? std::distance(getCases().begin(), iteratorToInt) diff --git a/mlir/test/Transforms/sccp.mlir b/mlir/test/Transforms/sccp.mlir index 138acbb42d1a9..f5a5183593d03 100644 --- a/mlir/test/Transforms/sccp.mlir +++ b/mlir/test/Transforms/sccp.mlir @@ -286,3 +286,24 @@ func.func @no_crash_acc_kernel_environment(%data: memref<8xi32>) { } return } + +// ----- + +// Regression test for https://github.com/llvm/llvm-project/issues/187973 +// SwitchOp::getEntrySuccessorRegions must not call IntegerAttr::getInt() on +// an unsigned integer type — that function asserts signless/index only. + +// CHECK-LABEL: no_crash_emitc_switch_unsigned_condition +func.func @no_crash_emitc_switch_unsigned_condition() { + // CHECK: emitc.constant + %0 = "emitc.constant"() {value = 1 : ui32} : () -> ui32 + // CHECK: emitc.switch + emitc.switch %0 : ui32 + case 2 { + emitc.yield + } + default { + emitc.yield + } + return +}