diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp index a7392a3d64a9..f64e419ef906 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp @@ -13,6 +13,7 @@ #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" +#include "iree/compiler/Dialect/Util/Analysis/Explorer.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -26,6 +27,71 @@ namespace mlir::iree_compiler::IREE::Flow { #define GEN_PASS_DEF_FOLDUNITEXTENTDIMSPASS #include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc" +//===----------------------------------------------------------------------===// +// Pass helpers +//===----------------------------------------------------------------------===// + +static LogicalResult +foldUnitDimsOnGlobal(IRRewriter &rewriter, IREE::Util::GlobalOpInterface global, + SmallVector<IREE::Util::GlobalLoadOpInterface> loadOps, + SmallVector<IREE::Util::GlobalStoreOpInterface> storeOps, + SymbolTable moduleSymbols) { + // Create a new transformed GlobalOp. + SmallVector<int64_t> newShape; + auto globalType = cast<RankedTensorType>(global.getGlobalType()); + for (auto size : globalType.getShape()) { + if (size != 1) { + newShape.push_back(size); + } + } + auto newGlobalType = globalType.clone(newShape); + auto initialValue = global.getGlobalInitialValue(); + // TODO: Handle non-uninitialized cases. 
+ auto uninitializedAttr = + llvm::dyn_cast_if_present<IREE::Util::UninitializedAttr>(initialValue); + if (initialValue && !uninitializedAttr) + return success(); + TypedAttr newInitialValue; + if (initialValue) { + newInitialValue = IREE::Util::UninitializedAttr::get(rewriter.getContext(), + newGlobalType); + } + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(global); + auto newGlobal = + clone(rewriter, global, global->getResultTypes(), global->getOperands()); + newGlobal.setGlobalType(newGlobalType); + newGlobal.setGlobalInitialValue(newInitialValue); + + // Rewrite loads and stores to use the new global. + auto expandShapeReInds = + getReassociationIndicesForReshape(globalType, newGlobalType); + if (!expandShapeReInds) { + return failure(); + } + + for (auto load : loadOps) { + rewriter.setInsertionPoint(load); + auto newLoad = clone(rewriter, load, {newGlobalType}, load->getOperands()); + newLoad.setGlobalAttr(FlatSymbolRefAttr::get(newGlobal.getGlobalName())); + rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>( + load, globalType, newLoad->getResult(0), expandShapeReInds.value()); + } + for (auto store : storeOps) { + rewriter.setInsertionPoint(store); + Value collapse = rewriter.create<tensor::CollapseShapeOp>( + store.getLoc(), newGlobalType, store->getOperand(0), + expandShapeReInds.value()); + auto newStore = + clone(rewriter, store, store->getResultTypes(), store->getOperands()); + newStore.setGlobalAttr(FlatSymbolRefAttr::get(newGlobal.getGlobalName())); + newStore.setStoredGlobalValue(collapse); + rewriter.eraseOp(store); + } + rewriter.eraseOp(global); + return success(); +} + namespace { struct FoldUnitExtentDimsPass : public IREE::Flow::impl::FoldUnitExtentDimsPassBase< @@ -35,8 +101,35 @@ struct FoldUnitExtentDimsPass } // namespace void FoldUnitExtentDimsPass::runOnOperation() { - Operation *funcOp = getOperation(); + auto funcOp = getOperation(); MLIRContext *context = &getContext(); + + Explorer explorer(funcOp, TraversalAction::RECURSE); + explorer.initialize(); + IRRewriter rewriter(context); + 
SymbolTable moduleSymbols(funcOp); + + // Fold unit dims of GlobalOpInterface ops. + explorer.forEachGlobal([&](const Explorer::GlobalInfo *globalInfo) { + IREE::Util::GlobalOpInterface global = globalInfo->op; + auto tensorType = dyn_cast<RankedTensorType>(global.getGlobalType()); + if (!tensorType || !global.isGlobalPrivate() || !global.isGlobalMutable()) { + return; + } + if (llvm::none_of(tensorType.getShape(), + [](int64_t size) { return size == 1; })) { + return; + } + SmallVector<IREE::Util::GlobalLoadOpInterface> loadOps = + llvm::to_vector(globalInfo->getLoads()); + SmallVector<IREE::Util::GlobalStoreOpInterface> storeOps = + llvm::to_vector(globalInfo->getStores()); + if (failed(foldUnitDimsOnGlobal(rewriter, global, loadOps, storeOps, + moduleSymbols))) { + return signalPassFailure(); + } + }); + RewritePatternSet foldUnitDimsPatterns(context); linalg::ControlDropUnitDims options; auto defaultFn = options.controlFn; diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td index a7c67b98a55c..2654bed51dfa 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td @@ -132,7 +132,7 @@ def DeduplicateExecutablesPass : } def FoldUnitExtentDimsPass : - InterfacePass<"iree-flow-fold-unit-extent-dims", "mlir::FunctionOpInterface"> { + Pass<"iree-flow-fold-unit-extent-dims", "mlir::ModuleOp"> { let summary = "Fold unit extent dimension of operations."; let description = [{ Imports upstream patterns to fold unit extent dims but with IREE control. 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir index cc9f68499bd1..611c2da1b437 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-fold-unit-extent-dims))" %s | FileCheck %s +// RUN: iree-opt --pass-pipeline="builtin.module(iree-flow-fold-unit-extent-dims)" %s --split-input-file | FileCheck %s util.func public @no_fold_unit_dims_in_dispatches(%arg0 : tensor<1x1x10xf32>) -> tensor<1x1x10xf32> { %0 = tensor.empty() : tensor<1x1x10xf32> @@ -21,3 +21,73 @@ util.func public @no_fold_unit_dims_in_dispatches(%arg0 : tensor<1x1x10xf32>) -> // CHECK-SAME: ins(%[[ARG0]] : tensor<1x1x10xf32>) // CHECK: flow.return %[[GENERIC]] // CHECK: util.return %[[DISPATCH]] + + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (0, 0)> +module @fold_unit_dims { + util.global private mutable @global {inlining_policy = #util.inline.never} = #util.uninitialized : tensor<1x32x1x1x64xf32> + util.global private mutable @unit_global = #util.uninitialized : tensor<1x1xf32> + util.func public @fold_global_unit_dims() -> tensor<32x64xf32> { + %global = util.global.load @global : tensor<1x32x1x1x64xf32> + %unit_global = util.global.load @unit_global : tensor<1x1xf32> + %collapsed = tensor.collapse_shape %global [[0, 1], [2, 3, 4]] : tensor<1x32x1x1x64xf32> into tensor<32x64xf32> + %0 = tensor.empty() : tensor<32x64xf32> + %1 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%collapsed, %unit_global : tensor<32x64xf32>, tensor<1x1xf32>) outs(%0 : tensor<32x64xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %2 = arith.addf %in, %in_0 : f32 + linalg.yield %2 : f32 + } -> tensor<32x64xf32> + 
%expanded = tensor.expand_shape %1 [[0, 1], [2, 3, 4]] output_shape[1, 32, 1, 1, 64] : tensor<32x64xf32> into tensor<1x32x1x1x64xf32> + util.global.store %expanded, @global : tensor<1x32x1x1x64xf32> + util.return %1 : tensor<32x64xf32> + } +} + +// CHECK: module @fold_unit_dims +// CHECK: util.global private mutable @[[GLOBAL:.+]] {inlining_policy = #util.inline.never} = #util.uninitialized : tensor<32x64xf32> +// CHECK: util.global private mutable @[[UNIT_GLOBAL:.+]] = #util.uninitialized : tensor<f32> +// CHECK: util.func public @fold_global_unit_dims +// CHECK: %[[LOAD0:.+]] = util.global.load @[[GLOBAL]] : tensor<32x64xf32> +// CHECK: %[[LOAD1:.+]] = util.global.load @[[UNIT_GLOBAL]] : tensor<f32> +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[LOAD0]], %[[LOAD1]] +// CHECK: util.global.store %[[GENERIC]], @[[GLOBAL]] : tensor<32x64xf32> +// CHECK: util.return %[[GENERIC]] + +// ----- + +module @no_fold_immutable { + util.global private @global : tensor<1x32x1x1x64xf32> + util.func public @no_fold_global_unit_dims() -> tensor<32x64xf32> { + %global = util.global.load @global : tensor<1x32x1x1x64xf32> + %collapsed = tensor.collapse_shape %global [[0, 1], [2, 3, 4]] : tensor<1x32x1x1x64xf32> into tensor<32x64xf32> + util.return %collapsed : tensor<32x64xf32> + } +} + +// CHECK: module @no_fold_immutable +// CHECK: util.global private @[[GLOBAL:.+]] : tensor<1x32x1x1x64xf32> +// CHECK: util.func public @no_fold_global_unit_dims +// CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<1x32x1x1x64xf32> +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[LOAD]] +// CHECK: util.return %[[COLLAPSE]] + +// ----- + +module @no_fold_public { + util.global public mutable @global : tensor<1x32x1x1x64xf32> + util.func public @no_fold_global_unit_dims() -> tensor<32x64xf32> { + %global = util.global.load @global : tensor<1x32x1x1x64xf32> + %collapsed = tensor.collapse_shape %global [[0, 1], [2, 3, 4]] : tensor<1x32x1x1x64xf32> into tensor<32x64xf32> + 
util.return %collapsed : tensor<32x64xf32> + } +} + +// CHECK: module @no_fold_public +// CHECK: util.global public mutable @[[GLOBAL:.+]] : tensor<1x32x1x1x64xf32> +// CHECK: util.func public @no_fold_global_unit_dims +// CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<1x32x1x1x64xf32> +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[LOAD]] diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp index ddabd1cc3d0f..bab60f67bb91 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp @@ -122,8 +122,10 @@ void buildGlobalOptimizationPassPipeline( // dims as the unit dim folding pass updates indexing maps and is better // at working with generics. By this point we have already done any // specialized raising and the op names are no longer useful. - .addPass(createGeneralizeLinalgNamedOpsPass) - .addPass(IREE::Flow::createFoldUnitExtentDimsPass) + .addPass(createGeneralizeLinalgNamedOpsPass); + + mainPassManager.addPass(IREE::Flow::createFoldUnitExtentDimsPass()); + FunctionLikeNest(mainPassManager) .addPredicatedPass(clEnableFuseSiluHorizontalMatmul, createFuseSiluHorizontalMatmulPass) .addPass([&]() { diff --git a/compiler/src/iree/compiler/Preprocessing/Passes.cpp b/compiler/src/iree/compiler/Preprocessing/Passes.cpp index 77f6da5c1e22..c149342051d6 100644 --- a/compiler/src/iree/compiler/Preprocessing/Passes.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Passes.cpp @@ -102,8 +102,8 @@ buildTransposeConvolutionPassPipeline(OpPassManager &passManager, .addPass(GlobalOptimization::createDetachElementwiseFromNamedOpsPass) .addPass(mlir::createLinalgNamedOpConversionPass) .addPass(GlobalOptimization::createConvert1X1FilterConv2DToMatmulPass) - .addPass(createConvertConvToChannelsLastPass) - .addPass(IREE::Flow::createFoldUnitExtentDimsPass); + .addPass(createConvertConvToChannelsLastPass); + 
passManager.addPass(IREE::Flow::createFoldUnitExtentDimsPass()); passManager.addPass(createCanonicalizerPass()); passManager.addPass(createCSEPass()); }