[mlir][tosa] Don't fold mul with zero lhs/rhs if resulting type is dynamic#153420
[mlir][tosa] Don't fold mul with zero lhs/rhs if resulting type is dynamic#153420
Conversation
|
@llvm/pr-subscribers-mlir Author: Sayan Saha (sahas3) ChangesCanonicalizing the following IR: resulted in a crash from the folder for Full diff: https://github.com/llvm/llvm-project/pull/153420.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e3cba38871909..fce61f27ca3ea 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1120,13 +1120,14 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
}
if (rhsTy == resultTy) {
- if (isSplatZero(resultETy, lhsAttr))
+ if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
+ // constant values can only be resized if resulting type is static
return lhsAttr.resizeSplat(resultTy);
if (isSplatOne(resultETy, lhsAttr, shift))
return rhs;
}
if (lhsTy == resultTy) {
- if (isSplatZero(resultETy, rhsAttr))
+ if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
return rhsAttr.resizeSplat(resultTy);
if (isSplatOne(resultETy, rhsAttr, shift))
return lhs;
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 5150ee36e9e5e..930bb9fe96811 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -565,6 +565,33 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
// -----
+// CHECK-LABEL: @mul_zero_dynamic_nofold
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x17xf32>) -> tensor<?x17xf32> {
+// CHECK: %[[ZERO:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
+// CHECK: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: %[[MUL:.*]] = tosa.mul %[[ARG0]], %[[ZERO]], %[[SHIFT]] : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
+// CHECK: return %[[MUL]]
+func.func @mul_zero_dynamic_nofold(%arg0: tensor<?x17xf32>) -> tensor<?x17xf32> {
+ %0 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
+ %1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %2 = tosa.mul %arg0, %0, %1 : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
+ return %2 : tensor<?x17xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @mul_one_dynamic_fold
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x17xf32>) -> tensor<?x17xf32> {
+// CHECK: return %[[ARG0]]
+func.func @mul_one_dynamic_fold(%arg0: tensor<?x17xf32>) -> tensor<?x17xf32> {
+ %0 = "tosa.const"() <{values = dense<1.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
+ %1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %2 = tosa.mul %arg0, %0, %1 : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
+ return %2 : tensor<?x17xf32>
+}
+
+// -----
+
// CHECK-LABEL: @select_same_value
func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
%0 = tosa.select %arg0, %arg1, %arg1 : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
|
@llvm/pr-subscribers-mlir-tosa Author: Sayan Saha (sahas3) ChangesCanonicalizing the following IR: resulted in a crash from the folder for Full diff: https://github.com/llvm/llvm-project/pull/153420.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e3cba38871909..fce61f27ca3ea 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1120,13 +1120,14 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
}
if (rhsTy == resultTy) {
- if (isSplatZero(resultETy, lhsAttr))
+ if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
+ // constant values can only be resized if resulting type is static
return lhsAttr.resizeSplat(resultTy);
if (isSplatOne(resultETy, lhsAttr, shift))
return rhs;
}
if (lhsTy == resultTy) {
- if (isSplatZero(resultETy, rhsAttr))
+ if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
return rhsAttr.resizeSplat(resultTy);
if (isSplatOne(resultETy, rhsAttr, shift))
return lhs;
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 5150ee36e9e5e..930bb9fe96811 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -565,6 +565,33 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
// -----
+// CHECK-LABEL: @mul_zero_dynamic_nofold
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x17xf32>) -> tensor<?x17xf32> {
+// CHECK: %[[ZERO:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
+// CHECK: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: %[[MUL:.*]] = tosa.mul %[[ARG0]], %[[ZERO]], %[[SHIFT]] : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
+// CHECK: return %[[MUL]]
+func.func @mul_zero_dynamic_nofold(%arg0: tensor<?x17xf32>) -> tensor<?x17xf32> {
+ %0 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
+ %1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %2 = tosa.mul %arg0, %0, %1 : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
+ return %2 : tensor<?x17xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @mul_one_dynamic_fold
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x17xf32>) -> tensor<?x17xf32> {
+// CHECK: return %[[ARG0]]
+func.func @mul_one_dynamic_fold(%arg0: tensor<?x17xf32>) -> tensor<?x17xf32> {
+ %0 = "tosa.const"() <{values = dense<1.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
+ %1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %2 = tosa.mul %arg0, %0, %1 : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
+ return %2 : tensor<?x17xf32>
+}
+
+// -----
+
// CHECK-LABEL: @select_same_value
func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
%0 = tosa.select %arg0, %arg1, %arg1 : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
sjarus
left a comment
There was a problem hiding this comment.
Dynamic shape handling is nice, thanks!
Canonicalizing the following IR:
resulted in a crash
from the folder for
tosa::mulsince the zero value was being reshaped to?x17size which isn't supported. AFAIK,tosa.constrequires all dimensions to be static. So in this case, the fix is to not to fold the op.