diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 11435b4524a2f..8d511aeec5424 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -5129,7 +5129,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { // Return true if we have a zero-value tile. auto hasZeros = [&](ArrayRef tiles) { - return llvm::any_of(tiles, isZeroInteger); + return llvm::any_of(tiles, [](OpFoldResult tile) { + return isa(tile) && isZeroInteger(tile); + }); }; // Verify that the source and destination are ranked types. @@ -5513,13 +5515,15 @@ bool PackOp::requirePaddingValue(ArrayRef inputShape, if (ShapedType::isDynamic(inputShape[pos])) continue; std::optional constantTile = getConstantIntValue(tileSize); - if (!constantTile) { if (ShapedType::isStatic(outputTileSizes[pos]) && (inputShape[pos] % outputTileSizes[pos] != 0)) return true; - } else if (inputShape[pos] % (*constantTile) != 0) { - return true; + } else { + assert(*constantTile != 0 && "static tile size can't be zero"); + if (inputShape[pos] % (*constantTile) != 0) { + return true; + } } } return false; @@ -5545,6 +5549,7 @@ bool PackOp::requirePaddingValueStrict(ArrayRef inputShape, std::optional constantTile = getConstantIntValue(tileSize); if (!constantTile) return true; + assert(*constantTile != 0 && "static tile size can't be zero"); if (inputShape[pos] % (*constantTile) != 0) return true; } @@ -6014,6 +6019,8 @@ struct FoldTensorCastPackOp : public OpRewritePattern { // Get the updated mixed-tile-sizes attribute. SmallVector newMixedTileSizes = getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles()); + if (llvm::any_of(newMixedTileSizes, isZeroInteger)) + return failure(); // Clone op. // TODO: Strictly speaking, discardable attributes should be _discarded_ at diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 77c1c3da17166..0c5a1c6108ae3 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -511,9 +511,9 @@ func.func @fold_self_copy(%0 : memref<4x16xf32>) { // ----- -// CHECK-LABEL: func @no_fold_fill_like_memref +// CHECK-LABEL: func @negative_fold_fill_like_memref // CHECK-NEXT: linalg.generic -func.func @no_fold_fill_like_memref(%in_out : memref<4x16xf32>, %fill_val : f32) { +func.func @negative_fold_fill_like_memref(%in_out : memref<4x16xf32>, %fill_val : f32) { linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} @@ -527,9 +527,9 @@ func.func @no_fold_fill_like_memref(%in_out : memref<4x16xf32>, %fill_val : f32) // ----- -// CHECK-LABEL: func @no_fold_fill_like_tensor +// CHECK-LABEL: func @negative_fold_fill_like_tensor // CHECK-NEXT: linalg.generic -func.func @no_fold_fill_like_tensor(%in_out : tensor<4x16xf32>, %fill_val : f32) -> tensor<4x16xf32> { +func.func @negative_fold_fill_like_tensor(%in_out : tensor<4x16xf32>, %fill_val : f32) -> tensor<4x16xf32> { %result = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} @@ -589,8 +589,8 @@ func.func @fold_dynamic_pad_fill(%empty: tensor<8x?x16x32xf32>, %low0: index, %l // ----- -// CHECK-LABEL: func @no_fold_pad_fill_value_mismatch -func.func @no_fold_pad_fill_value_mismatch() -> tensor<412x276xf32> { +// CHECK-LABEL: func @negative_fold_pad_fill_value_mismatch +func.func @negative_fold_pad_fill_value_mismatch() -> tensor<412x276xf32> { %f0 = arith.constant 0.0 : f32 %f1 = arith.constant 1.0 : f32 %empty = tensor.empty() : tensor<400x273xf32> @@ -1451,6 +1451,25 @@ func.func @fold_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x // ----- +// A dynamic tile size must not be folded into a static dimension, even when +// the dynamic value is a constant zero at the point of canonicalization. +// CHECK-LABEL: func.func @negative_fold_pack_zero_tile +// CHECK: %[[C0:.*]] = arith.constant 0 +// CHECK: linalg.pack {{.*}}inner_tiles = [%[[C0]], 1] +func.func @negative_fold_pack_zero_tile(%A: tensor<7x16xi32>) -> tensor<1x16x?x1xi32> { + %pad_val = arith.constant 123 : i32 + %tile_size = arith.constant 0 : index + %empty = tensor.empty(%tile_size) : tensor<1x16x?x1xi32> + %pack = linalg.pack %A + padding_value(%pad_val : i32) + inner_dims_pos = [0, 1] + inner_tiles = [%tile_size, 1] + into %empty : tensor<7x16xi32> -> tensor<1x16x?x1xi32> + return %pack : tensor<1x16x?x1xi32> +} + +// ----- + // CHECK-LABEL: func @fold_padding_value_pack_constant_splat // CHECK-NOT: linalg.pack // CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32> @@ -1466,10 +1485,10 @@ func.func @fold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) // ----- -// CHECK-LABEL: func @nofold_padding_value_pack_constant_splat +// CHECK-LABEL: func @negative_fold_padding_value_pack_constant_splat // CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32> // CHECK: linalg.pack -func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> { +func.func @negative_fold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> { %pad = arith.constant 0.0 : f32 %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32> %0 = linalg.pack %cst @@ -1538,13 +1557,13 @@ func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor, %arg1: index) -> tensor<32x7x?x16x1xf32> { +func.func @negative_infer_pack_shape(%arg0: tensor, %arg1: index) -> tensor<32x7x?x16x1xf32> { %cst = arith.constant 0.000000e+00 : f32 %0 = tensor.empty(%arg1) : tensor<32x7x?x16x1xf32> %pack = linalg.pack %arg0 padding_value(%cst : f32) outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [16, 1] into %0 : tensor -> tensor<32x7x?x16x1xf32> return %pack : tensor<32x7x?x16x1xf32> } -// CHECK-LABEL: func.func @no_infer_pack_shape +// CHECK-LABEL: func.func @negative_infer_pack_shape // CHECK-NOT: tensor.cast // ----- @@ -1650,13 +1669,13 @@ func.func @infer_src_shape_unpack(%src: tensor, %dest: tensor<30 // ----- -func.func @no_infer_unpack_shape(%arg1: tensor<32x7x?x16x1xf32>, %arg2: index) -> tensor { +func.func @negative_infer_unpack_shape(%arg1: tensor<32x7x?x16x1xf32>, %arg2: index) -> tensor { %cst = arith.constant 0.000000e+00 : f32 %0 = tensor.empty(%arg2) : tensor %unpack = linalg.unpack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [16, 1] into %0 : tensor<32x7x?x16x1xf32> -> tensor return %unpack : tensor } -// CHECK-LABEL: func.func @no_infer_unpack_shape +// CHECK-LABEL: func.func @negative_infer_unpack_shape // CHECK-NOT: tensor.cast // ----- @@ -1724,10 +1743,10 @@ func.func @unpack_pack(%t: tensor<128x128xf32>, %tile1: index, %tile2: index) -> // ----- -// CHECK: func.func @unpack_pack_with_padding_no_canonicalization( +// CHECK: func.func @negative_unpack_pack_with_padding_no_canonicalization( // CHECK: linalg.pack // CHECK: linalg.unpack -func.func @unpack_pack_with_padding_no_canonicalization(%t: tensor<256x512xbf16>) -> tensor<224x512xbf16> { +func.func @negative_unpack_pack_with_padding_no_canonicalization(%t: tensor<256x512xbf16>) -> tensor<224x512xbf16> { %tensor_empty = tensor.empty() : tensor<4x16x64x32xbf16> %tensor_empty1 = tensor.empty() : tensor<224x512xbf16> %packed = linalg.pack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty : tensor<256x512xbf16> -> tensor<4x16x64x32xbf16> @@ -1982,7 +2001,7 @@ func.func @fold_extract_slice_into_unpack_slicing_dim_1(%src : tensor<28x2x1x16x // The available dimension size is [17, 32], because CeilDiv(%d1, 16) == 2. -func.func @no_fold_extract_slice_into_unpack_artificial_padding(%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>) -> tensor<28x16x15xf32> { +func.func @negative_fold_extract_slice_into_unpack_artificial_padding(%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>) -> tensor<28x16x15xf32> { %unpack = linalg.unpack %src inner_dims_pos = [1, 2] inner_tiles = [16, 16] @@ -1991,13 +2010,13 @@ func.func @no_fold_extract_slice_into_unpack_artificial_padding(%src : tensor<28 [0, 0, 0] [28, 16, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x16x15xf32> return %extracted_slice : tensor<28x16x15xf32> } -// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_artificial_padding +// CHECK-LABEL: func @negative_fold_extract_slice_into_unpack_artificial_padding // CHECK: linalg.unpack // CHECK: tensor.extract_slice // ----- -func.func @no_fold_extract_slice_into_unpack_dynamic( +func.func @negative_fold_extract_slice_into_unpack_dynamic( %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index ) -> tensor<28x28x?xf32> { %unpack = linalg.unpack %src @@ -2009,13 +2028,13 @@ func.func @no_fold_extract_slice_into_unpack_dynamic( [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32> return %extracted_slice : tensor<28x28x?xf32> } -// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_dynamic +// CHECK-LABEL: func @negative_fold_extract_slice_into_unpack_dynamic // CHECK: linalg.unpack // CHECK: tensor.extract_slice // ----- -func.func @no_fold_extract_slice_into_unpack_rank_reducing( +func.func @negative_fold_extract_slice_into_unpack_rank_reducing( %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32> ) -> tensor<28xf32> { %unpack = linalg.unpack %src @@ -2028,7 +2047,7 @@ func.func @no_fold_extract_slice_into_unpack_rank_reducing( return %extracted_slice : tensor<28xf32> } -// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_rank_reducing +// CHECK-LABEL: func @negative_fold_extract_slice_into_unpack_rank_reducing // CHECK-SAME: %[[SRC:.+]]: tensor<28x2x16xf32> // CHECK-SAME: %[[DEST:.+]]: tensor<28x32xf32> // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]] @@ -2038,7 +2057,7 @@ func.func @no_fold_extract_slice_into_unpack_rank_reducing( // ----- -func.func @no_fold_extract_slice_into_unpack_non_zero_offset( +func.func @negative_fold_extract_slice_into_unpack_non_zero_offset( %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32> ) -> tensor<28x28xf32> { %unpack = linalg.unpack %src @@ -2051,7 +2070,7 @@ func.func @no_fold_extract_slice_into_unpack_non_zero_offset( return %extracted_slice : tensor<28x28xf32> } -// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_non_zero_offset +// CHECK-LABEL: func @negative_fold_extract_slice_into_unpack_non_zero_offset // CHECK-SAME: %[[SRC:.+]]: tensor<28x2x16xf32> // CHECK-SAME: %[[DEST:.+]]: tensor<28x32xf32> // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]] @@ -2062,7 +2081,7 @@ func.func @no_fold_extract_slice_into_unpack_non_zero_offset( // ----- // Must not fold because extract_slice cuts the 0'th dimension from 30 to 28. -func.func @no_fold_extract_slice_into_unpack_slice_over_non_tiled_dim( +func.func @negative_fold_extract_slice_into_unpack_slice_over_non_tiled_dim( %src : tensor<30x2x16xf32>, %dest : tensor<30x32xf32> ) -> tensor<28x28xf32> { %unpack = linalg.unpack %src @@ -2074,7 +2093,7 @@ func.func @no_fold_extract_slice_into_unpack_slice_over_non_tiled_dim( return %extracted_slice : tensor<28x28xf32> } -// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_slice_over_non_tiled_dim +// CHECK-LABEL: func @negative_fold_extract_slice_into_unpack_slice_over_non_tiled_dim // CHECK-SAME: %[[SRC:.+]]: tensor<30x2x16xf32> // CHECK-SAME: %[[DEST:.+]]: tensor<30x32xf32> // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]] @@ -2085,7 +2104,7 @@ func.func @no_fold_extract_slice_into_unpack_slice_over_non_tiled_dim( // ----- // Must not fold because extract_slice's effect on the 0'th dimension is unknown. -func.func @no_fold_extract_slice_into_unpack_slice_over_dynamic_dim( +func.func @negative_fold_extract_slice_into_unpack_slice_over_dynamic_dim( %src : tensor, %dest : tensor, %size : index ) -> tensor { %unpack = linalg.unpack %src @@ -2097,7 +2116,7 @@ func.func @no_fold_extract_slice_into_unpack_slice_over_dynamic_dim( return %extracted_slice : tensor } -// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_slice_over_dynamic_dim +// CHECK-LABEL: func @negative_fold_extract_slice_into_unpack_slice_over_dynamic_dim // CHECK-SAME: %[[SRC:.+]]: tensor // CHECK-SAME: %[[DEST:.+]]: tensor // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]] @@ -2140,10 +2159,10 @@ func.func @fold_pack_unpack_tensor(%x: tensor<2x3xf32>) -> tensor<2x3xf32> { // ----- // Test that pack/unpack canonicalization is disabled for memref versions. -// CHECK-LABEL: func.func @pack_unpack_memref_no_canonicalization +// CHECK-LABEL: func.func @negative_pack_unpack_memref_no_canonicalization // CHECK: linalg.pack // CHECK: linalg.unpack -func.func @pack_unpack_memref_no_canonicalization(%source: memref<128x256xf32>, %packed: memref<16x8x8x32xf32>, %dest: memref<128x256xf32>) { +func.func @negative_pack_unpack_memref_no_canonicalization(%source: memref<128x256xf32>, %packed: memref<16x8x8x32xf32>, %dest: memref<128x256xf32>) { linalg.pack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %packed : memref<128x256xf32> -> memref<16x8x8x32xf32> linalg.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32> return @@ -2152,10 +2171,10 @@ func.func @pack_unpack_memref_no_canonicalization(%source: memref<128x256xf32>, // ----- // Test that unpack/pack canonicalization is disabled for memref versions. -// CHECK-LABEL: func.func @unpack_pack_memref_no_canonicalization +// CHECK-LABEL: func.func @negative_unpack_pack_memref_no_canonicalization // CHECK: linalg.unpack // CHECK: linalg.pack -func.func @unpack_pack_memref_no_canonicalization(%packed: memref<16x8x8x32xf32>, %unpacked: memref<128x256xf32>, %dest: memref<16x8x8x32xf32>) { +func.func @negative_unpack_pack_memref_no_canonicalization(%packed: memref<16x8x8x32xf32>, %unpacked: memref<128x256xf32>, %dest: memref<16x8x8x32xf32>) { linalg.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %unpacked : memref<16x8x8x32xf32> -> memref<128x256xf32> linalg.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %dest : memref<128x256xf32> -> memref<16x8x8x32xf32> return