diff --git a/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp
index 5f30902c7e2f..b74c435f0b95 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp
@@ -14,6 +14,7 @@
 #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
 #include "iree/compiler/GlobalOptimization/PassDetail.h"
 #include "iree/compiler/GlobalOptimization/Passes.h"
+#include "llvm/Support/Casting.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Pass/Pass.h"
@@ -24,6 +25,10 @@ namespace {
 struct GeneralizeLinalgNamedOpsPass
     : public GeneralizeLinalgNamedOpsBase<GeneralizeLinalgNamedOpsPass> {
+  GeneralizeLinalgNamedOpsPass(bool generalizeLinalgMatmulOps) {
+    this->generalizeLinalgMatmulOps = generalizeLinalgMatmulOps;
+  }
+
   void runOnOperation() override;
 };
 } // namespace
@@ -45,6 +50,11 @@ void GeneralizeLinalgNamedOpsPass::runOnOperation() {
             linalgOp.getOperation())) {
       namedOpCandidates.push_back(linalgOp);
     }
+    if (generalizeLinalgMatmulOps &&
+        isa_and_nonnull<linalg::MatmulOp, linalg::BatchMatmulOp>(
+            linalgOp)) {
+      namedOpCandidates.push_back(linalgOp);
+    }
   });
 
   IRRewriter rewriter(&getContext());
@@ -60,8 +70,9 @@ void GeneralizeLinalgNamedOpsPass::runOnOperation() {
 }
 
 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createGeneralizeLinalgNamedOpsPass() {
-  return std::make_unique<GeneralizeLinalgNamedOpsPass>();
+createGeneralizeLinalgNamedOpsPass(bool generalizeLinalgMatmulOps) {
+  return std::make_unique<GeneralizeLinalgNamedOpsPass>(
+      generalizeLinalgMatmulOps);
 }
 
 } // namespace mlir::iree_compiler::GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
index bab60f67bb91..9bbfa7e7baf2 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
@@ -23,6 +23,10 @@ static llvm::cl::opt<bool> clEnableQuantizedMatmulReassociation(
     llvm::cl::desc(
         "Enables reassociation of quantized matmul ops (experimental)."),
     llvm::cl::init(false));
+static llvm::cl::opt<bool>
+    clGeneralizeLinalgMatmulOps("enable-generalize-linalg-matmul-ops",
+                                llvm::cl::desc("Generalize linalg MatMul ops"),
+                                llvm::cl::init(false));
 static llvm::cl::opt<bool> clEnableFuseSiluHorizontalMatmul(
     "iree-global-opt-enable-fuse-silu-horizontal-matmul",
     llvm::cl::desc(
@@ -122,7 +126,9 @@ void buildGlobalOptimizationPassPipeline(
       // dims as the unit dim folding pass updates indexing maps and is better
       // at working with generics. By this point we have already done any
       // specialized raising and the op names are no longer useful.
-      .addPass(createGeneralizeLinalgNamedOpsPass);
+      .addPass([&]() {
+        return createGeneralizeLinalgNamedOpsPass(clGeneralizeLinalgMatmulOps);
+      });
 
   mainPassManager.addPass(IREE::Flow::createFoldUnitExtentDimsPass());
   FunctionLikeNest(mainPassManager)
@@ -178,7 +184,9 @@ void buildGlobalOptimizationPassPipeline(
   }
   // Generalize transposes and any other remaining named linalg ops that can
   // now be represented as generics.
-  FunctionLikeNest(mainPassManager).addPass(createGeneralizeLinalgNamedOpsPass);
+  FunctionLikeNest(mainPassManager).addPass([&]() {
+    return createGeneralizeLinalgNamedOpsPass(clGeneralizeLinalgMatmulOps);
+  });
 
   // Hoist loop invariants (e.g. from scf loops) with zero-trip-check.
   FunctionLikeNest(mainPassManager)
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.h b/compiler/src/iree/compiler/GlobalOptimization/Passes.h
index 6ddbeaba1cb3..8156faa470b3 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.h
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.h
@@ -86,7 +86,7 @@ createFuseSiluHorizontalMatmulPass();
 /// Generalizes some named Linalg ops into `linalg.generic` operations since the
 /// compiler can handle that better.
 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createGeneralizeLinalgNamedOpsPass();
+createGeneralizeLinalgNamedOpsPass(bool generalizeLinalgMatmulOps = false);
 
 /// Infers and inserts util.numeric.optional_narrow ops at points that may be
 /// beneficial.
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.td b/compiler/src/iree/compiler/GlobalOptimization/Passes.td
index 0f3bcd336229..c80a0612154f 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.td
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.td
@@ -105,6 +105,10 @@ def GeneralizeLinalgNamedOps :
     InterfacePass<"iree-global-opt-generalize-linalg-named-ops", "mlir::FunctionOpInterface"> {
   let summary = "Convert some Linalg named ops into linalg.generics.";
   let constructor = "mlir::iree_compiler::GlobalOptimization::createGeneralizeLinalgNamedOpsPass()";
+  let options = [
+    Option<"generalizeLinalgMatmulOps", "enable-generalize-linalg-matmul-ops", "bool",
+           /*default=*/"false", "Generalize linalg batch MatMul ops">,
+  ];
 }
 
 def InferNumericNarrowing :
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir
index 5111152b7b0d..0c371df3ee58 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-generalize-linalg-named-ops))" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-generalize-linalg-named-ops{enable-generalize-linalg-matmul-ops=true}))" --split-input-file %s | FileCheck %s
 
 util.func public @generalize_op(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index
@@ -34,3 +34,16 @@ util.func public @no_generalize_op_within_dispatch(%arg0 : tensor<?x?xf32>, %arg
 // CHECK:         %[[ADD:.+]] = linalg.add
 // CHECK:         flow.return %[[ADD]]
 // CHECK:       util.return %[[DISPATCH]]
+
+// -----
+
+util.func public @generalize_matmul(%arg0: tensor<1x128x128xf32>, %arg1: tensor<1x128x128xf32>) -> tensor<1x128x128xf32> {
+  %0 = tensor.empty() : tensor<1x128x128xf32>
+  %1 = linalg.batch_matmul ins(%arg0, %arg1: tensor<1x128x128xf32>, tensor<1x128x128xf32>) outs(%0 : tensor<1x128x128xf32>) -> tensor<1x128x128xf32>
+  util.return %1 : tensor<1x128x128xf32>
+}
+
+// CHECK-LABEL: util.func public @generalize_matmul
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<1x128x128xf32>, %[[ARG1:.+]]: tensor<1x128x128xf32>
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:      %[[ARG0]], %[[ARG1]]