Skip to content

Commit

Permalink
[Global opt] add flag to generalize matmul ops (#17877)
Browse files Browse the repository at this point in the history
Helps when the producer is a broadcast op. After adding the flag to sdxl scripts, I saw a decent decrease in the
number of dispatches.

Initially, I was trying to manually generalize+fuse broadcasts [branch
here](https://github.com/IanWood1/iree/tree/broadcast_matmul), but quinn
saw good results with just this.

---------

Signed-off-by: Ian Wood <[email protected]>
  • Loading branch information
IanWood1 authored Jul 12, 2024
1 parent 05dfe0b commit f0d24cd
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ iree_compiler_cc_library(
"ApplyPDLPatterns.cpp",
"ConvertConv2DToImg2Col.cpp",
"ConvertConvToChannelsLast.cpp",
"GeneralizeLinalgMatMul.cpp",
"InterpreterPass.cpp",
"MakeSingleDispatchForFunction.cpp",
"PadLinalgOps.cpp",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ iree_cc_library(
"ApplyPDLPatterns.cpp"
"ConvertConv2DToImg2Col.cpp"
"ConvertConvToChannelsLast.cpp"
"GeneralizeLinalgMatMul.cpp"
"InterpreterPass.cpp"
"MakeSingleDispatchForFunction.cpp"
"PadLinalgOps.cpp"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"

namespace mlir::iree_compiler::Preprocessing {

#define GEN_PASS_DEF_GENERALIZELINALGMATMULPASS
#include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: export

namespace {

struct GeneralizeLinalgMatMulPass
: public iree_compiler::Preprocessing::impl::GeneralizeLinalgMatMulPassBase<
GeneralizeLinalgMatMulPass> {
using iree_compiler::Preprocessing::impl::GeneralizeLinalgMatMulPassBase<
GeneralizeLinalgMatMulPass>::GeneralizeLinalgMatMulPassBase;
void runOnOperation() override {
auto funcOp = getOperation();
SmallVector<linalg::LinalgOp> namedOpCandidates;
funcOp.walk([&](linalg::LinalgOp linalgOp) {
if (!IREE::Flow::isNonNullAndOutsideDispatch(linalgOp)) {
return;
}
if (isa_and_nonnull<linalg::MatmulOp, linalg::MatmulTransposeBOp,
linalg::BatchMatmulOp,
linalg::BatchMatmulTransposeBOp>(linalgOp)) {
namedOpCandidates.push_back(linalgOp);
}
});

IRRewriter rewriter(&getContext());

for (auto linalgOp : namedOpCandidates) {
rewriter.setInsertionPoint(linalgOp);
FailureOr<linalg::GenericOp> generalizedOp =
linalg::generalizeNamedOp(rewriter, linalgOp);
if (failed(generalizedOp)) {
linalgOp->emitOpError("failed to generalize operation");
return signalPassFailure();
}
}
}
};
} // namespace
} // namespace mlir::iree_compiler::Preprocessing
8 changes: 8 additions & 0 deletions compiler/src/iree/compiler/Preprocessing/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,12 @@ def TransposeMatmulPass : Pass<"iree-preprocessing-transpose-matmul-pass"> {
];
}

def GeneralizeLinalgMatMulPass :
InterfacePass<"iree-preprocessing-generalize-linalg-matmul-experimental", "mlir::FunctionOpInterface"> {
let summary = "Convert linalg matmul ops to linalg.generics.";
let dependentDialects = [
"mlir::linalg::LinalgDialect",
];
}

#endif // IREE_PREPROCESSING_COMMON_PASSES
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ iree_lit_test_suite(
[
"conv2d_to_img2col.mlir",
"conv_to_channels_last.mlir",
"generalize_linalg_matmul.mlir",
"make_single_dispatch_for_function.mlir",
"pad_linalg_ops.mlir",
"pad_to_intrinsics_mfma.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ iree_lit_test_suite(
SRCS
"conv2d_to_img2col.mlir"
"conv_to_channels_last.mlir"
"generalize_linalg_matmul.mlir"
"make_single_dispatch_for_function.mlir"
"pad_linalg_ops.mlir"
"pad_to_intrinsics_mfma.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" --verify-each --split-input-file %s | FileCheck %s

util.func public @generalize_matmul(%arg0: tensor<1x128x128xf32>, %arg1: tensor<1x128x128xf32>) -> tensor<1x128x128xf32> {
%0 = tensor.empty() : tensor<1x128x128xf32>
%1 = linalg.batch_matmul ins(%arg0, %arg1: tensor<1x128x128xf32>, tensor<1x128x128xf32>) outs(%0 : tensor<1x128x128xf32>) -> tensor<1x128x128xf32>
util.return %1 : tensor<1x128x128xf32>
}

// CHECK-LABEL: util.func public @generalize_matmul
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x128x128xf32>, %[[ARG1:.+]]: tensor<1x128x128xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: %[[ARG0]], %[[ARG1]]

0 comments on commit f0d24cd

Please sign in to comment.