[mlir][bufferization] Fix crash with copy-before-write + bufferize-function-boundaries#186446
Conversation
…nction-boundaries When `copy-before-write=1` is combined with `bufferize-function-boundaries=1`, `bufferizeOp` creates a plain `AnalysisState` (not `OneShotAnalysisState`) and passes it to `insertTensorCopies`. Walking `CallOp`s during conflict resolution called `getCalledFunction(callOp, state)`, which unconditionally cast the `AnalysisState` to `OneShotAnalysisState` via `static_cast`, causing UB and a stack overflow crash. Fix by guarding the cast with `isa<OneShotAnalysisState>()` so that when the state is a plain `AnalysisState`, the function falls through to building a fresh `SymbolTableCollection` — the same safe fallback already present. Fixes llvm#163052 Assisted-by: Claude Code
|
@llvm/pr-subscribers-mlir-bufferization @llvm/pr-subscribers-mlir Author: Mehdi Amini (joker-eph) ChangesWhen Fix by guarding the cast with Fixes #163052 Assisted-by: Claude Code Full diff: https://github.com/llvm/llvm-project/pull/186446.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index e43ab54a048b9..3aaa38272935d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -101,12 +101,14 @@ static FuncOp getCalledFunction(CallOpInterface callOp,
/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp,
const AnalysisState &state) {
- auto &oneShotAnalysisState = static_cast<const OneShotAnalysisState &>(state);
-
- if (auto *funcAnalysisState =
- oneShotAnalysisState.getExtension<FuncAnalysisState>()) {
- // Use the cached symbol tables.
- return getCalledFunction(callOp, funcAnalysisState->symbolTables);
+ if (isa<OneShotAnalysisState>(state)) {
+ auto &oneShotAnalysisState =
+ static_cast<const OneShotAnalysisState &>(state);
+ if (auto *funcAnalysisState =
+ oneShotAnalysisState.getExtension<FuncAnalysisState>()) {
+ // Use the cached symbol tables.
+ return getCalledFunction(callOp, funcAnalysisState->symbolTables);
+ }
}
SymbolTableCollection symbolTables;
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-call-copy-before-write.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-call-copy-before-write.mlir
new file mode 100644
index 0000000000000..7addca2c9d6a5
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-call-copy-before-write.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 copy-before-write=1" | FileCheck %s
+
+// Regression test for https://github.com/llvm/llvm-project/issues/163052
+// copy-before-write=1 + bufferize-function-boundaries=1 with a call to a
+// private (declaration-only) function used to crash with a stack overflow due
+// to an invalid cast of AnalysisState to OneShotAnalysisState inside
+// getCalledFunction().
+
+// CHECK-LABEL: func.func private @callee(memref<64xf32
+// CHECK-LABEL: func.func @caller
+// CHECK: call @callee
+func.func private @callee(tensor<64xf32>)
+func.func @caller(%A : tensor<64xf32>) {
+ call @callee(%A) : (tensor<64xf32>) -> ()
+ return
+}
|
|
Ping @matthias-springer |
When
copy-before-write=1is combined withbufferize-function-boundaries=1,bufferizeOpcreates a plainAnalysisState(notOneShotAnalysisState) and passes it toinsertTensorCopies. WalkingCallOps during conflict resolution calledgetCalledFunction(callOp, state), which unconditionally cast theAnalysisStatetoOneShotAnalysisStateviastatic_cast, causing UB and a stack overflow crash.Fix by guarding the cast with
isa<OneShotAnalysisState>()so that when the state is a plainAnalysisState, the function falls through to building a freshSymbolTableCollection— the same safe fallback already present.Fixes #163052
Assisted-by: Claude Code