TosaToLinalgNamed: add option to prefer HWCF kernel layout for Conv2D ops.#70482
Merged
TosaToLinalgNamed: add option to prefer HWCF kernel layout for Conv2D ops.#70482
Conversation
Member
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: None (bjacob) ChangesSwitching to FHWC happened in #68304 and is fine in itself but caused downstream performance regression iree-org/iree#15296 (comment) , so this PR makes this optional. Full diff: https://github.com/llvm/llvm-project/pull/70482.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index f05e5a8ae667dab..336f0d3af951b9a 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1126,6 +1126,12 @@ def TosaToLinalgNamed
Linalg named operations.
}];
+ let options = [
+ Option<"preferConv2DKernelLayoutHWCF", "prefer-conv2d-kernel-layout-hwcf",
+ "bool", /*default=*/"false",
+ "Prefer generating linalg.conv_2d_nhwc_hwcf over linalg.conv_2d_nhwc_fhwc">
+ ];
+
let constructor = "tosa::createTosaToLinalgNamed()";
}
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index b4c4eb8651a6f00..7497a716e048d95 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -26,7 +26,8 @@ namespace mlir {
namespace tosa {
std::unique_ptr<Pass> createTosaToLinalg();
-std::unique_ptr<Pass> createTosaToLinalgNamed();
+std::unique_ptr<Pass> createTosaToLinalgNamed(
+ const TosaToLinalgNamedOptions &options = TosaToLinalgNamedOptions());
/// Populates passes to convert from TOSA to Linalg on buffers. At the end of
/// the pass, the function will only contain linalg ops or standard ops if the
@@ -34,6 +35,8 @@ std::unique_ptr<Pass> createTosaToLinalgNamed();
/// benchmarking performance improvements from the canonicalizations.
void addTosaToLinalgPasses(
OpPassManager &pm, const TosaToLinalgOptions &options,
+ const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions =
+ TosaToLinalgNamedOptions(),
// Note: Default to 'none' level unless otherwise specified.
tosa::TosaValidationOptions const &validationOptions = {
tosa::TosaProfileEnum::Undefined, false, tosa::TosaLevelEnum::None});
@@ -45,8 +48,12 @@ void registerTosaToLinalgPipelines();
/// Populates conversion passes from TOSA dialect to Linalg dialect.
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
+enum class Conv2DKernelLayout { FHWC, HWCF };
+
/// Populates conversion passes from TOSA dialect to Linalg named operations.
-void populateTosaToLinalgNamedConversionPatterns(RewritePatternSet *patterns);
+void populateTosaToLinalgNamedConversionPatterns(
+ RewritePatternSet *patterns,
+ Conv2DKernelLayout conv2DKernelLayout = Conv2DKernelLayout::FHWC);
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index ee8f52deadbd152..ae0b58acfd295b2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -26,6 +26,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <numeric>
+#include <type_traits>
using namespace mlir;
using namespace mlir::tosa;
@@ -248,6 +249,35 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
pad.resize(pad.size() + 2, 0);
input = applyPad(loc, input, pad, zeroAttr, rewriter);
+ if (4 == inputTy.getRank()) {
+ // For 2D convolutions, we need to check if the target convolution op
+ // wants a HWCF kernel layout.
+ bool wantHwcf =
+ isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
+ : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
+ if (wantHwcf) {
+ // Transpose the kernel to match dimension ordering of the linalg
+ // convolution operation.
+ // TODO(suderman): See if this can be efficiently folded - check whether
+ // the input is used anywhere else, if not fold the constant.
+ SmallVector<int64_t> weightPerm;
+ for (int i = 1; i < resultTy.getRank(); i++)
+ weightPerm.push_back(i);
+ weightPerm.push_back(0);
+
+ SmallVector<int64_t> newWeightShape;
+ for (auto dim : weightPerm)
+ newWeightShape.push_back(weightShape[dim]);
+ auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+ Value weightPermValue =
+ rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+ Type newWeightTy =
+ RankedTensorType::get(newWeightShape, weightTy.getElementType());
+ weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
+ weightPermValue);
+ }
+ }
+
// For Conv3D transpose the kernel to match dimension ordering of the linalg
// convolution operation. Conv2D has a 1-1 mapping in linalg so better to
// map directly and then transpose later if desired.
@@ -977,10 +1007,20 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
} // namespace
void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
- RewritePatternSet *patterns) {
+ RewritePatternSet *patterns, Conv2DKernelLayout conv2DKernelLayout) {
+ if (conv2DKernelLayout == Conv2DKernelLayout::FHWC) {
+ patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
+ linalg::Conv2DNhwcFhwcQOp>>(
+ patterns->getContext());
+ } else if (conv2DKernelLayout == Conv2DKernelLayout::HWCF) {
+ patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
+ linalg::Conv2DNhwcHwcfQOp>>(
+ patterns->getContext());
+ } else {
+ assert(false);
+ }
patterns->add<
// clang-format off
- ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcFhwcQOp>,
ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
DepthwiseConvConverter,
MatMulConverter,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index 4c941a109ed845e..e330c9cff141e40 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -37,6 +37,9 @@ namespace {
struct TosaToLinalgNamed
: public impl::TosaToLinalgNamedBase<TosaToLinalgNamed> {
public:
+ TosaToLinalgNamed(const TosaToLinalgNamedOptions &options)
+ : impl::TosaToLinalgNamedBase<TosaToLinalgNamed>(options) {}
+
void getDependentDialects(DialectRegistry ®istry) const override {
registry
.insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
@@ -61,13 +64,18 @@ struct TosaToLinalgNamed
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
FunctionOpInterface func = getOperation();
- mlir::tosa::populateTosaToLinalgNamedConversionPatterns(&patterns);
+ tosa::Conv2DKernelLayout conv2DKernelLayout =
+ preferConv2DKernelLayoutHWCF ? tosa::Conv2DKernelLayout::HWCF
+ : tosa::Conv2DKernelLayout::FHWC;
+ tosa::populateTosaToLinalgNamedConversionPatterns(&patterns,
+ conv2DKernelLayout);
if (failed(applyFullConversion(func, target, std::move(patterns))))
signalPassFailure();
}
};
} // namespace
-std::unique_ptr<Pass> mlir::tosa::createTosaToLinalgNamed() {
- return std::make_unique<TosaToLinalgNamed>();
+std::unique_ptr<Pass>
+mlir::tosa::createTosaToLinalgNamed(const TosaToLinalgNamedOptions &options) {
+ return std::make_unique<TosaToLinalgNamed>(options);
}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index a486e28c50c7129..687477810030d4c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -76,6 +76,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
void mlir::tosa::addTosaToLinalgPasses(
OpPassManager &pm, const TosaToLinalgOptions &options,
+ const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions,
tosa::TosaValidationOptions const &validationOptions) {
// Optional decompositions are designed to benefit linalg.
if (!options.disableTosaDecompositions)
@@ -84,7 +85,8 @@ void mlir::tosa::addTosaToLinalgPasses(
pm.addNestedPass<func::FuncOp>(tosa::createTosaInferShapesPass());
pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
- pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalgNamed());
+ pm.addNestedPass<func::FuncOp>(
+ tosa::createTosaToLinalgNamed(tosaToLinalgNamedOptions));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// TODO: Remove pass that operates on const tensor and enable optionality
pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
@@ -106,7 +108,9 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
"named operations.",
[](OpPassManager &pm) {
TosaToLinalgOptions tosaToLinalgOptions;
+ TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
+ tosaToLinalgNamedOptions,
/* validationOptions = */
{tosa::TosaProfileEnum::BaseInference,
/* StrictOperationSpecAlignment = */ true,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index b601bfb28a4f280..1cf7c8dee606899 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
// CHECK-LABEL: @matmul
func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
@@ -363,11 +364,14 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
// CHECK-LABEL: @conv2d_i8
func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
+ // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
+ // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x1x1x27xi8>, tensor<4xi64>) -> tensor<1x1x27x28xi8>
// CHECK: %[[M_IN:.+]] = tensor.empty()
// CHECK: %[[CST:.+]] = arith.constant 0
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[B_IN:.+]] = tensor.empty()
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+ // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
// CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
// CHECK: arith.extsi
// CHECK: arith.addi
@@ -383,11 +387,14 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
// CHECK-LABEL: @conv2d_f32
func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
+ // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
+ // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
// CHECK: %[[M_IN:.+]] = tensor.empty()
// CHECK: %[[CST:.+]] = arith.constant 0
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[B_IN:.+]] = tensor.empty()
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
+ // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32>
// CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
// CHECK: arith.addf
// CHECK: linalg.yield
|
Contributor
Author
MaheshRavishankar
requested changes
Oct 27, 2023
Contributor
MaheshRavishankar
left a comment
There was a problem hiding this comment.
Have one small comment. Thanks!
MaheshRavishankar
approved these changes
Oct 27, 2023
Contributor
MaheshRavishankar
left a comment
There was a problem hiding this comment.
Looks good to me.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Switching to FHWC happened in #68304 and is fine in itself but caused downstream performance regression iree-org/iree#15296 (comment) , so this PR makes this optional.