Skip to content

[mlir][linalg] Fix crash when folding tensor.cast into unpack using static packed shape for inner tiles#188000

Merged
hockyy merged 1 commit into
llvm:mainfrom
hockyy:bail-tensor-pack
Apr 22, 2026
Merged

[mlir][linalg] Fix crash when folding tensor.cast into unpack using static packed shape for inner tiles#188000
hockyy merged 1 commit into
llvm:mainfrom
hockyy:bail-tensor-pack

Conversation

@hockyy

@hockyy hockyy commented Mar 23, 2026

Copy link
Copy Markdown
Member

This change fixes #187975 and #188405, a crash in Linalg tensor-cast folding for pack/unpack when tile sizes are dynamic or otherwise not provably constant.

Previously, canonicalization could reach getNewMixedTileSizes and unconditionally access getConstantIntValue(tile).value(). For dynamic tile operands, that value can be absent, causing std::bad_optional_access/assert aborts.

When folding tensor.cast into linalg.unpack (and the same helper is used for linalg.pack), mixed inner tile sizes are updated from the refined packed tensor type. Every static trailing packed dimension gets a matching static tile attribute, replacing SSA tile values and overwriting tile constants that disagreed with that type.

Dynamic packed dimensions still keep the original tile operands.

Assisted-by: CLion code completion

@llvmbot

llvmbot commented Mar 23, 2026

Copy link
Copy Markdown
Member

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Hocky Yudhiono (hockyy)

Changes

This change fixes #187975, a crash in Linalg tensor-cast folding for pack/unpack when tile sizes are dynamic or otherwise not provably constant.

Previously, canonicalization could reach getNewMixedTileSizes and unconditionally access getConstantIntValue(tile).value(). For dynamic tile operands, that value can be absent, causing std::bad_optional_access/assert aborts.


Full diff: https://github.com/llvm/llvm-project/pull/188000.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+27-14)
  • (added) mlir/test/Dialect/Linalg/canonicalize-dynamic-unpack-tile.mlir (+71)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ad2909f656eea..95aeb821c51d0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5000,8 +5000,10 @@ template SmallVector<int64_t>
 //  * a dim from newPackedTy is static, and
 //  * the corresponding size from mixedTiles is still dynamic.
 // Otherwise, the original tile size is preserved.
+// Returns failure when a dynamic tile cannot be proven to match the static
+// packed dim.
 // Note - packed-type-dim and mixed-tile-size should always match!
-static SmallVector<OpFoldResult>
+static FailureOr<SmallVector<OpFoldResult>>
 getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
                      SmallVector<OpFoldResult> mixedTiles) {
   SmallVector<OpFoldResult> newMixedTileSizes;
@@ -5015,17 +5017,21 @@ getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
       continue;
     }
 
-    // If the current result dim is static, update the dynamic mixed-size
-    // (provided the original value is dynamic).
+    // If the current result dim is static, update the dynamic mixed-size only
+    // when the original dynamic value is a known constant matching `shape`.
+    // Otherwise, bail out and let the fold fail conservatively.
     OpFoldResult tile = std::get<1>(it);
     if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
       // Already a constant
       newMixedTileSizes.push_back(tile);
     } else {
-      assert(getConstantIntValue(tile).value() == shape &&
-             "tile size and dim size don't match!");
-      newMixedTileSizes.push_back(
-          (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+      std::optional<int64_t> constTile = getConstantIntValue(tile);
+      if (constTile.has_value() && constTile.value() == shape) {
+        newMixedTileSizes.push_back(
+            rewriter.getIntegerAttr(rewriter.getIndexType(), shape));
+      } else {
+        return failure();
+      }
     }
   }
 
@@ -5995,8 +6001,11 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
         tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
 
     // Get the updated mixed-tile-sizes attribute.
-    SmallVector<OpFoldResult> newMixedTileSizes =
+    FailureOr<SmallVector<OpFoldResult>> newMixedTileSizes =
         getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
+    if (failed(newMixedTileSizes))
+      return rewriter.notifyMatchFailure(
+          op, "unable to prove dynamic tile sizes after folding tensor.cast");
 
     // Clone op.
     // TODO: Strictly speaking, discardable attributes should be _discarded_ at
@@ -6004,7 +6013,7 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
     // to preserve. Implement a better abstraction.
     PackOp newOp =
         PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
-                       op.getInnerDimsPos(), newMixedTileSizes,
+                       op.getInnerDimsPos(), newMixedTileSizes.value(),
                        op.getPaddingValue(), op.getOuterDimsPerm());
     newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
 
@@ -6476,16 +6485,20 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
     Value sourceTensor = newOperands[0];
 
     // Get the updated mixed-tile-sizes attribute.
-    SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
-        rewriter, sourceTensor.getType(), op.getMixedTiles());
+    FailureOr<SmallVector<OpFoldResult>> newMixedTileSizes =
+        getNewMixedTileSizes(rewriter, sourceTensor.getType(), op.getMixedTiles());
+    if (failed(newMixedTileSizes))
+      return rewriter.notifyMatchFailure(
+          op, "unable to prove dynamic tile sizes after folding tensor.cast");
 
     // Clone op.
     // TODO: Strictly speaking, discardable attributes should be _discarded_ at
     // this point. However, in practice, we use them for things that we'd like
     // to preserve. Implement a better abstraction.
-    UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
-                                      newOperands[1], op.getInnerDimsPos(),
-                                      newMixedTileSizes, op.getOuterDimsPerm());
+    UnPackOp newOp =
+        UnPackOp::create(rewriter, op.getLoc(), sourceTensor, newOperands[1],
+                         op.getInnerDimsPos(), newMixedTileSizes.value(),
+                         op.getOuterDimsPerm());
     newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
 
     // Replace op.
diff --git a/mlir/test/Dialect/Linalg/canonicalize-dynamic-unpack-tile.mlir b/mlir/test/Dialect/Linalg/canonicalize-dynamic-unpack-tile.mlir
new file mode 100644
index 0000000000000..48e2ae90111d9
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/canonicalize-dynamic-unpack-tile.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt %s --inline -canonicalize="test-convergence" -split-input-file | FileCheck %s --check-prefixes=CHECK
+
+// CHECK: func.func @dynamic_tile_arg_no_fold
+// CHECK-SAME:  %[[SRC:.+]]: tensor<1x3x8x1xi32>, %[[TILE:.+]]: index
+// CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
+// CHECK-DAG:   %[[CAST:.+]] = tensor.cast %[[SRC]] : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+// CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[CAST]]
+// CHECK-SAME:    inner_dims_pos = [0, 1]
+// CHECK-SAME:    inner_tiles = [%[[TILE]], 1]
+// CHECK-SAME:    into %[[EMPTY]] : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+// CHECK:       return %[[UNPACK]] : tensor<7x3xi32>
+module {
+  func.func @id_index(%arg0: index) -> index {
+    return %arg0 : index
+  }
+  func.func @dynamic_tile_arg_no_fold(%arg0: tensor<1x3x8x1xi32>, %arg1: index) -> tensor<7x3xi32> {
+    %0 = tensor.empty() : tensor<7x3xi32>
+    %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+    %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%arg1, 1] into %0 : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+    return %unpack : tensor<7x3xi32>
+  }
+}
+
+
+// -----
+
+// CHECK-LABEL: func.func @dynamic_tile_from_inlined_mismatch_no_fold
+// CHECK-DAG:   %[[C256:.+]] = arith.constant 256 : index
+// CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
+// CHECK-DAG:   %[[CAST:.+]] = tensor.cast %{{.+}} : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+// CHECK:       %[[UNPACK:.+]] = linalg.unpack %[[CAST]]
+// CHECK-SAME:    inner_dims_pos = [0, 1]
+// CHECK-SAME:    inner_tiles = [%[[C256]], 1]
+// CHECK-SAME:    into %[[EMPTY]] : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+// CHECK:       return %[[UNPACK]] : tensor<7x3xi32>
+module {
+  func.func @get_tile() -> index {
+    %c256 = arith.constant 256 : index
+    return %c256 : index
+  }
+  func.func @dynamic_tile_from_inlined_mismatch_no_fold(%arg0: tensor<1x3x8x1xi32>) -> tensor<7x3xi32> {
+    %0 = call @get_tile() : () -> index
+    %1 = tensor.empty() : tensor<7x3xi32>
+    %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+    %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%0, 1] into %1 : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+    return %unpack : tensor<7x3xi32>
+  }
+}
+
+
+// -----
+
+// CHECK-LABEL: func.func @constant_tile_from_inlined_match_folds
+// CHECK:       %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
+// CHECK-NOT:   tensor.cast
+// CHECK:       %[[UNPACK:.+]] = linalg.unpack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [8, 1]
+// CHECK-SAME:    into %[[EMPTY]] : tensor<1x3x8x1xi32> -> tensor<7x3xi32>
+// CHECK:       return %[[UNPACK]] : tensor<7x3xi32>
+module {
+  func.func @get_tile() -> index {
+    %c8 = arith.constant 8 : index
+    return %c8 : index
+  }
+  func.func @constant_tile_from_inlined_match_folds(%arg0: tensor<1x3x8x1xi32>) -> tensor<7x3xi32> {
+    %0 = call @get_tile() : () -> index
+    %1 = tensor.empty() : tensor<7x3xi32>
+    %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+    %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%0, 1] into %1 : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+    return %unpack : tensor<7x3xi32>
+  }
+}

@github-actions

github-actions Bot commented Mar 23, 2026

Copy link
Copy Markdown

✅ With the latest revision this PR passed the C/C++ code formatter.

@hockyy hockyy force-pushed the bail-tensor-pack branch 2 times, most recently from ab62f02 to ddf0ddf Compare March 23, 2026 09:50
@rengolin

Copy link
Copy Markdown
Member

Quick check. Given the time frame (3hs) between submitting the issuer and issuing a fix, and due to the very AI looking shape of the issue itself, I have to ask: did you use AI tools to generate this PR? If so, you must make it clear, according to our AI policy.

@hockyy

hockyy commented Mar 23, 2026

Copy link
Copy Markdown
Member Author

Quick check. Given the time frame (3hs) between submitting the issuer and issuing a fix, and due to the very AI looking shape of the issue itself, I have to ask: did you use AI tools to generate this PR? If so, you must make it clear, according to our AI policy.

Hi, I'm using code completion, but still manually changing the code, testcases, and checks. I will put the assistant in my MR. Thanks

"tile size and dim size don't match!");
newMixedTileSizes.push_back(
(rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
std::optional<int64_t> constTile = getConstantIntValue(tile);

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fuse if with previous else for readability.

@hockyy hockyy Mar 23, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed by directly using getConstantIntValue's int value d5c9b28

Comment thread mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp Outdated
Comment thread mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp Outdated
Comment thread mlir/test/Dialect/Linalg/canonicalize-dynamic-pack-unpack-tile.mlir Outdated
Comment thread mlir/test/Dialect/Linalg/canonicalize-dynamic-pack-unpack-tile.mlir Outdated
Comment thread mlir/test/Dialect/Linalg/canonicalize-dynamic-pack-unpack-tile.mlir Outdated
Comment thread mlir/test/Dialect/Linalg/canonicalize-dynamic-pack-unpack-tile.mlir Outdated
Comment thread mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp Outdated
@hockyy hockyy changed the title [mlir][linalg] Bail out tensor.cast pack/unpack fold on unprovable tile sizes [mlir][linalg] Fold tensor.cast into unpack using static packed shape for inner tiles Mar 26, 2026
@hockyy hockyy requested a review from joker-eph March 26, 2026 02:48

@joker-eph joker-eph left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, but please wait for another approval.

So ping on @llvm/pr-subscribers-mlir-linalg folks?

@joker-eph joker-eph changed the title [mlir][linalg] Fold tensor.cast into unpack using static packed shape for inner tiles [mlir][linalg] Fix crash when folding tensor.cast into unpack using static packed shape for inner tiles Mar 30, 2026
@joker-eph

Copy link
Copy Markdown
Contributor

ping on Linalg folks for this crash fix?
This LGTM but I was seeking another possible approval @zero9178 or @banach-space maybe?

@github-actions

github-actions Bot commented Apr 1, 2026

Copy link
Copy Markdown

🪟 Windows x64 Test Results

  • 3615 tests passed
  • 417 tests skipped

✅ The build succeeded and all tests passed.

@github-actions

github-actions Bot commented Apr 1, 2026

Copy link
Copy Markdown

🐧 Linux x64 Test Results

  • 7801 tests passed
  • 606 tests skipped

✅ The build succeeded and all tests passed.

@hockyy hockyy force-pushed the bail-tensor-pack branch from 96ef7cc to 5338861 Compare April 13, 2026 08:01
@hockyy

hockyy commented Apr 21, 2026

Copy link
Copy Markdown
Member Author

Hi, @adam-smnk @krzysz00 @Groverkss any second opinion on this MR?

@joker-eph

Copy link
Copy Markdown
Contributor

Ping me tomorrow and I'll merge if we can't get second opinion from Linalg folks by then.

@krzysz00 krzysz00 left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see any issues but I don't know this code. This is a very weak approve

@hockyy

hockyy commented Apr 22, 2026

Copy link
Copy Markdown
Member Author

Hello, @joker-eph let's merge this one

@joker-eph

Copy link
Copy Markdown
Contributor

You got commit access now I believe? I'll let you do the honors then :)

@hockyy hockyy merged commit a5f7f49 into llvm:main Apr 22, 2026
10 checks passed
linuxlonelyeagle pushed a commit to linuxlonelyeagle/llvm-project that referenced this pull request Apr 23, 2026
…tatic packed shape for inner tiles (llvm#188000)

This change fixes llvm#187975 and llvm#188405, a crash in Linalg tensor-cast
folding for pack/unpack when tile sizes are dynamic or otherwise not
provably constant.

Previously, canonicalization could reach getNewMixedTileSizes and
unconditionally access `getConstantIntValue(tile).value()`. For dynamic
tile operands, that value can be absent, causing
`std::bad_optional_access/assert` aborts.

When folding `tensor.cast` into `linalg.unpack` (and the same helper is
used for linalg.pack), mixed inner tile sizes are updated from the
refined packed tensor type. Every static trailing packed dimension gets
a matching static tile attribute, replacing SSA tile values and
overwriting tile constants that disagreed with that type.

Dynamic packed dimensions still keep the original tile operands.

Assisted-by: CLion code completion
yingopq pushed a commit to yingopq/llvm-project that referenced this pull request Apr 29, 2026
…tatic packed shape for inner tiles (llvm#188000)

This change fixes llvm#187975 and llvm#188405, a crash in Linalg tensor-cast
folding for pack/unpack when tile sizes are dynamic or otherwise not
provably constant.

Previously, canonicalization could reach getNewMixedTileSizes and
unconditionally access `getConstantIntValue(tile).value()`. For dynamic
tile operands, that value can be absent, causing
`std::bad_optional_access/assert` aborts.

When folding `tensor.cast` into `linalg.unpack` (and the same helper is
used for linalg.pack), mixed inner tile sizes are updated from the
refined packed tensor type. Every static trailing packed dimension gets
a matching static tile attribute, replacing SSA tile values and
overwriting tile constants that disagreed with that type.

Dynamic packed dimensions still keep the original tile operands.

Assisted-by: CLion code completion
KHicketts pushed a commit to KHicketts/llvm-project that referenced this pull request Apr 30, 2026
…tatic packed shape for inner tiles (llvm#188000)

This change fixes llvm#187975 and llvm#188405, a crash in Linalg tensor-cast
folding for pack/unpack when tile sizes are dynamic or otherwise not
provably constant.

Previously, canonicalization could reach getNewMixedTileSizes and
unconditionally access `getConstantIntValue(tile).value()`. For dynamic
tile operands, that value can be absent, causing
`std::bad_optional_access/assert` aborts.

When folding `tensor.cast` into `linalg.unpack` (and the same helper is
used for linalg.pack), mixed inner tile sizes are updated from the
refined packed tensor type. Every static trailing packed dimension gets
a matching static tile attribute, replacing SSA tile values and
overwriting tile constants that disagreed with that type.

Dynamic packed dimensions still keep the original tile operands.

Assisted-by: CLion code completion
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

5 participants