diff --git a/mlir/lib/AsmParser/AffineParser.cpp b/mlir/lib/AsmParser/AffineParser.cpp index 1797611858c06..0d0c74b965a2b 100644 --- a/mlir/lib/AsmParser/AffineParser.cpp +++ b/mlir/lib/AsmParser/AffineParser.cpp @@ -25,6 +25,7 @@ #include "llvm/Support/raw_ostream.h" #include #include +#include #include using namespace mlir; @@ -345,7 +346,14 @@ AffineExpr AffineParser::parseSymbolSSAIdExpr() { /// affine-expr ::= integer-literal AffineExpr AffineParser::parseIntegerExpr() { auto val = getToken().getUInt64IntegerValue(); - if (!val.has_value() || (int64_t)*val < 0) + // Allow 9223372036854775808 (= 2^63 = |INT64_MIN|) because the printer + // emits it as the magnitude in "... - 9223372036854775808" to represent + // affine expressions containing INT64_MIN (e.g. "d0 + INT64_MIN" is + // printed as "d0 - 9223372036854775808"). The cast to int64_t yields + // INT64_MIN, which is the correct internal representation. + if (!val.has_value() || + (static_cast(*val) < 0 && + *val != static_cast(std::numeric_limits::min()))) return emitError("constant too large for index"), nullptr; consumeToken(Token::integer); diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index da91066815ca4..05d643d36896c 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -906,8 +906,13 @@ AffineExpr AffineExpr::operator-() const { } // Delegate to operator+. -AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); } +AffineExpr AffineExpr::operator-(int64_t v) const { + // Use unsigned negation to avoid signed integer overflow for INT64_MIN. + return *this + static_cast(-static_cast(v)); +} AffineExpr AffineExpr::operator-(AffineExpr other) const { + if (auto constOther = dyn_cast(other)) + return *this - constOther.getValue(); return *this + (-other); } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index b3242f838fc1d..c07970c7261e4 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -3234,7 +3234,9 @@ void AsmPrinter::Impl::printAffineExprInternal( os << " - "; printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, printValueName); - os << " * " << -rrhs.getValue(); + // Use unsigned negation to avoid signed integer overflow for + // INT64_MIN. + os << " * " << -static_cast(rrhs.getValue()); if (enclosingTightness == BindingStrength::Strong) os << ')'; return; @@ -3247,7 +3249,8 @@ void AsmPrinter::Impl::printAffineExprInternal( if (auto rhsConst = dyn_cast(rhsExpr)) { if (rhsConst.getValue() < 0) { printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); - os << " - " << -rhsConst.getValue(); + // Use unsigned negation to avoid signed integer overflow for INT64_MIN. + os << " - " << -static_cast(rhsConst.getValue()); if (enclosingTightness == BindingStrength::Strong) os << ')'; return; diff --git a/mlir/test/IR/affine-map.mlir b/mlir/test/IR/affine-map.mlir index 86bdaafd79f32..a7cba498b4146 100644 --- a/mlir/test/IR/affine-map.mlir +++ b/mlir/test/IR/affine-map.mlir @@ -228,6 +228,8 @@ // CHECK: #map{{[0-9]*}} = affine_map<()[s0, s1] -> (0)> #map69 = affine_map<()[s0, s1] -> ((s0 + s1) mod (s0 + s1))> +// CHECK: #[[INT64_MIN_MAP:map[0-9]*]] = affine_map<(d0) -> (d0 - 9223372036854775808)> + // Single identity maps are removed. // CHECK: @f0(memref<2x4xi8, 1>) func.func private @f0(memref<2x4xi8, #map0, 1>) @@ -448,3 +450,16 @@ func.func private @f56(memref<1x1xi8, #map56>) // CHECK: "f69"() {map = #map{{[0-9]*}}} : () -> () "f69"() {map = #map69} : () -> () + +// Test that affine expressions with INT64_MIN as a constant are printed +// without signed integer overflow (printed as "d0 - 9223372036854775808") +// and can be parsed back (round-trip). The value INT64_MIN arises when +// affine simplification sums two large negative constants. +#map_int64min = affine_map<(d0) -> (d0 + (-4611686018427387904) + (-4611686018427387904))> +// CHECK: "f_int64min"() {map = #[[INT64_MIN_MAP]]} : () -> () +"f_int64min"() {map = #map_int64min} : () -> () + +// Round-trip: parse a map already containing "d0 - 9223372036854775808". +#map_int64min_rt = affine_map<(d0) -> (d0 - 9223372036854775808)> +// CHECK: "f_int64min_rt"() {map = #[[INT64_MIN_MAP]]} : () -> () +"f_int64min_rt"() {map = #map_int64min_rt} : () -> ()