[Flow][Global Opt] Fold global unit dims (#17781)
Currently reverting [7884dc8](7884dc8) to test regressions (there were problems with llama). Issue here: nod-ai/SHARK-ModelDev#756

Couldn't reproduce the issue with llama yet. It might be best to land this, since unit dims should be folded in general; the folding just doesn't play well with this model in particular.

Signed-off-by: Ian Wood <[email protected]>
IanWood1 authored Jul 3, 2024
1 parent 6f26881 commit 698b75c
Showing 5 changed files with 172 additions and 7 deletions.
@@ -13,6 +13,7 @@

#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/Util/Analysis/Explorer.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -26,6 +27,71 @@ namespace mlir::iree_compiler::IREE::Flow {
#define GEN_PASS_DEF_FOLDUNITEXTENTDIMSPASS
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"

//===----------------------------------------------------------------------===//
// Pass helpers
//===----------------------------------------------------------------------===//

static LogicalResult
foldUnitDimsOnGlobal(IRRewriter &rewriter, IREE::Util::GlobalOpInterface global,
                     SmallVector<IREE::Util::GlobalLoadOpInterface> loadOps,
                     SmallVector<IREE::Util::GlobalStoreOpInterface> storeOps,
                     SymbolTable moduleSymbols) {
  // Create a new transformed GlobalOp.
  SmallVector<int64_t> newShape;
  auto globalType = cast<RankedTensorType>(global.getGlobalType());
  for (auto size : globalType.getShape()) {
    if (size != 1) {
      newShape.push_back(size);
    }
  }
  auto newGlobalType = globalType.clone(newShape);
  auto initialValue = global.getGlobalInitialValue();
  // TODO: Handle non-uninitialized cases.
  auto uninitializedAttr =
      llvm::dyn_cast_if_present<IREE::Util::UninitializedAttr>(initialValue);
  if (initialValue && !uninitializedAttr)
    return success();
  TypedAttr newInitialValue;
  if (initialValue) {
    newInitialValue = IREE::Util::UninitializedAttr::get(rewriter.getContext(),
                                                         newGlobalType);
  }
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(global);
  auto newGlobal =
      clone(rewriter, global, global->getResultTypes(), global->getOperands());
  newGlobal.setGlobalType(newGlobalType);
  newGlobal.setGlobalInitialValue(newInitialValue);

  // Rewrite loads and stores to use the new global.
  auto expandShapeReInds =
      getReassociationIndicesForReshape(globalType, newGlobalType);
  if (!expandShapeReInds) {
    return failure();
  }

  for (auto load : loadOps) {
    rewriter.setInsertionPoint(load);
    auto newLoad = clone(rewriter, load, {newGlobalType}, load->getOperands());
    newLoad.setGlobalAttr(FlatSymbolRefAttr::get(newGlobal.getGlobalName()));
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        load, globalType, newLoad->getResult(0), expandShapeReInds.value());
  }
  for (auto store : storeOps) {
    rewriter.setInsertionPoint(store);
    Value collapse = rewriter.create<tensor::CollapseShapeOp>(
        store.getLoc(), newGlobalType, store->getOperand(0),
        expandShapeReInds.value());
    auto newStore =
        clone(rewriter, store, store->getResultTypes(), store->getOperands());
    newStore.setGlobalAttr(FlatSymbolRefAttr::get(newGlobal.getGlobalName()));
    newStore.setStoredGlobalValue(collapse);
    rewriter.eraseOp(store);
  }
  rewriter.eraseOp(global);
  return success();
}

namespace {
struct FoldUnitExtentDimsPass
    : public IREE::Flow::impl::FoldUnitExtentDimsPassBase<
@@ -35,8 +101,35 @@ struct FoldUnitExtentDimsPass
} // namespace

void FoldUnitExtentDimsPass::runOnOperation() {
-  Operation *funcOp = getOperation();
+  auto funcOp = getOperation();
  MLIRContext *context = &getContext();

  Explorer explorer(funcOp, TraversalAction::RECURSE);
  explorer.initialize();
  IRRewriter rewriter(context);
  SymbolTable moduleSymbols(funcOp);

  // Fold unit dims of GlobalOpInterface ops.
  explorer.forEachGlobal([&](const Explorer::GlobalInfo *globalInfo) {
    IREE::Util::GlobalOpInterface global = globalInfo->op;
    auto tensorType = dyn_cast<RankedTensorType>(global.getGlobalType());
    if (!tensorType || !global.isGlobalPrivate() ||
        !global.isGlobalMutable()) {
      return;
    }
    if (llvm::none_of(tensorType.getShape(),
                      [](int64_t size) { return size == 1; })) {
      return;
    }
    SmallVector<IREE::Util::GlobalLoadOpInterface> loadOps =
        llvm::to_vector(globalInfo->getLoads());
    SmallVector<IREE::Util::GlobalStoreOpInterface> storeOps =
        llvm::to_vector(globalInfo->getStores());
    if (failed(foldUnitDimsOnGlobal(rewriter, global, loadOps, storeOps,
                                    moduleSymbols))) {
      return signalPassFailure();
    }
  });

  RewritePatternSet foldUnitDimsPatterns(context);
  linalg::ControlDropUnitDims options;
  auto defaultFn = options.controlFn;
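To illustrate what foldUnitDimsOnGlobal above does, here is a minimal before/after sketch (hypothetical IR with an assumed global @g, not taken from this commit's tests): loads of the shrunken global get a tensor.expand_shape back to the original type, and stored values get a tensor.collapse_shape first.

// Before: a private, mutable, uninitialized global carrying unit dims.
util.global private mutable @g = #util.uninitialized : tensor<1x4x1xf32>
%v = util.global.load @g : tensor<1x4x1xf32>
util.global.store %v, @g : tensor<1x4x1xf32>

// After: the global is re-created without unit dims; accesses reshape around it.
util.global private mutable @g = #util.uninitialized : tensor<4xf32>
%l = util.global.load @g : tensor<4xf32>
%e = tensor.expand_shape %l [[0, 1, 2]] output_shape [1, 4, 1] : tensor<4xf32> into tensor<1x4x1xf32>
%c = tensor.collapse_shape %e [[0, 1, 2]] : tensor<1x4x1xf32> into tensor<4xf32>
util.global.store %c, @g : tensor<4xf32>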
@@ -132,7 +132,7 @@ def DeduplicateExecutablesPass :
}

def FoldUnitExtentDimsPass :
-    InterfacePass<"iree-flow-fold-unit-extent-dims", "mlir::FunctionOpInterface"> {
+    Pass<"iree-flow-fold-unit-extent-dims", "mlir::ModuleOp"> {
  let summary = "Fold unit extent dimensions of operations.";
  let description = [{
    Imports upstream patterns to fold unit extent dims but with IREE control.
@@ -1,4 +1,4 @@
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-fold-unit-extent-dims))" %s | FileCheck %s
// RUN: iree-opt --pass-pipeline="builtin.module(iree-flow-fold-unit-extent-dims)" %s --split-input-file | FileCheck %s

util.func public @no_fold_unit_dims_in_dispatches(%arg0 : tensor<1x1x10xf32>) -> tensor<1x1x10xf32> {
  %0 = tensor.empty() : tensor<1x1x10xf32>
@@ -21,3 +21,73 @@ util.func public @no_fold_unit_dims_in_dispatches(%arg0 : tensor<1x1x10xf32>) ->
// CHECK-SAME: ins(%[[ARG0]] : tensor<1x1x10xf32>)
// CHECK: flow.return %[[GENERIC]]
// CHECK: util.return %[[DISPATCH]]


// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (0, 0)>
module @fold_unit_dims {
  util.global private mutable @global {inlining_policy = #util.inline.never} = #util.uninitialized : tensor<1x32x1x1x64xf32>
  util.global private mutable @unit_global = #util.uninitialized : tensor<1x1xf32>
  util.func public @fold_global_unit_dims() -> tensor<32x64xf32> {
    %global = util.global.load @global : tensor<1x32x1x1x64xf32>
    %unit_global = util.global.load @unit_global : tensor<1x1xf32>
    %collapsed = tensor.collapse_shape %global [[0, 1], [2, 3, 4]] : tensor<1x32x1x1x64xf32> into tensor<32x64xf32>
    %0 = tensor.empty() : tensor<32x64xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%collapsed, %unit_global : tensor<32x64xf32>, tensor<1x1xf32>) outs(%0 : tensor<32x64xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %2 = arith.addf %in, %in_0 : f32
      linalg.yield %2 : f32
    } -> tensor<32x64xf32>
    %expanded = tensor.expand_shape %1 [[0, 1], [2, 3, 4]] output_shape [1, 32, 1, 1, 64] : tensor<32x64xf32> into tensor<1x32x1x1x64xf32>
    util.global.store %expanded, @global : tensor<1x32x1x1x64xf32>
    util.return %1 : tensor<32x64xf32>
  }
}

// CHECK: module @fold_unit_dims
// CHECK: util.global private mutable @[[GLOBAL:.+]] {inlining_policy = #util.inline.never} = #util.uninitialized : tensor<32x64xf32>
// CHECK: util.global private mutable @[[UNIT_GLOBAL:.+]] = #util.uninitialized : tensor<f32>
// CHECK: util.func public @fold_global_unit_dims
// CHECK: %[[LOAD0:.+]] = util.global.load @[[GLOBAL]] : tensor<32x64xf32>
// CHECK: %[[LOAD1:.+]] = util.global.load @[[UNIT_GLOBAL]] : tensor<f32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[LOAD0]], %[[LOAD1]]
// CHECK: util.global.store %[[GENERIC]], @[[GLOBAL]] : tensor<32x64xf32>
// CHECK: util.return %[[GENERIC]]

// -----

module @no_fold_immutable {
  util.global private @global : tensor<1x32x1x1x64xf32>
  util.func public @no_fold_global_unit_dims() -> tensor<32x64xf32> {
    %global = util.global.load @global : tensor<1x32x1x1x64xf32>
    %collapsed = tensor.collapse_shape %global [[0, 1], [2, 3, 4]] : tensor<1x32x1x1x64xf32> into tensor<32x64xf32>
    util.return %collapsed : tensor<32x64xf32>
  }
}

// CHECK: module @no_fold_immutable
// CHECK: util.global private @[[GLOBAL:.+]] : tensor<1x32x1x1x64xf32>
// CHECK: util.func public @no_fold_global_unit_dims
// CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<1x32x1x1x64xf32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[LOAD]]
// CHECK: util.return %[[COLLAPSE]]

// -----

module @no_fold_public {
  util.global public mutable @global : tensor<1x32x1x1x64xf32>
  util.func public @no_fold_global_unit_dims() -> tensor<32x64xf32> {
    %global = util.global.load @global : tensor<1x32x1x1x64xf32>
    %collapsed = tensor.collapse_shape %global [[0, 1], [2, 3, 4]] : tensor<1x32x1x1x64xf32> into tensor<32x64xf32>
    util.return %collapsed : tensor<32x64xf32>
  }
}

// CHECK: module @no_fold_public
// CHECK: util.global public mutable @[[GLOBAL:.+]] : tensor<1x32x1x1x64xf32>
// CHECK: util.func public @no_fold_global_unit_dims
// CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<1x32x1x1x64xf32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[LOAD]]
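A caveat grounded in the TODO inside foldUnitDimsOnGlobal: only globals whose initial value is absent or #util.uninitialized are rewritten, so a hypothetical global with a concrete initializer (not part of this commit's tests) would be left untouched:

module @no_fold_initialized {
  // Not folded: the initial value is a concrete attribute, not #util.uninitialized.
  util.global private mutable @const_global = dense<1.0> : tensor<1x4xf32>
}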
6 changes: 4 additions & 2 deletions compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
@@ -122,8 +122,10 @@ void buildGlobalOptimizationPassPipeline(
// dims as the unit dim folding pass updates indexing maps and is better
// at working with generics. By this point we have already done any
// specialized raising and the op names are no longer useful.
-      .addPass(createGeneralizeLinalgNamedOpsPass)
-      .addPass(IREE::Flow::createFoldUnitExtentDimsPass)
+      .addPass(createGeneralizeLinalgNamedOpsPass);
+
+  mainPassManager.addPass(IREE::Flow::createFoldUnitExtentDimsPass());
+  FunctionLikeNest(mainPassManager)
      .addPredicatedPass(clEnableFuseSiluHorizontalMatmul,
                         createFuseSiluHorizontalMatmulPass)
      .addPass([&]() {
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Preprocessing/Passes.cpp
@@ -102,8 +102,8 @@ buildTransposeConvolutionPassPipeline(OpPassManager &passManager,
      .addPass(GlobalOptimization::createDetachElementwiseFromNamedOpsPass)
      .addPass(mlir::createLinalgNamedOpConversionPass)
      .addPass(GlobalOptimization::createConvert1X1FilterConv2DToMatmulPass)
-      .addPass(createConvertConvToChannelsLastPass)
-      .addPass(IREE::Flow::createFoldUnitExtentDimsPass);
+      .addPass(createConvertConvToChannelsLastPass);
+  passManager.addPass(IREE::Flow::createFoldUnitExtentDimsPass());
  passManager.addPass(createCanonicalizerPass());
  passManager.addPass(createCSEPass());
}
}
