Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][affine] Add pass --affine-raise-from-memref #114032

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Affine/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ namespace mlir {
namespace func {
class FuncOp;
} // namespace func
namespace memref {
class MemRefDialect;
} // namespace memref

namespace affine {
class AffineForOp;
Expand All @@ -48,6 +51,9 @@ createAffineLoopInvariantCodeMotionPass();
/// ops.
std::unique_ptr<OperationPass<func::FuncOp>> createAffineParallelizePass();

/// Creates a pass that raises memref.load and memref.store operations to
/// affine.load and affine.store respectively, for the accesses whose indices
/// can be expressed as affine expressions of valid dimensions and symbols.
/// Accesses that cannot be expressed that way are left as memref operations.
std::unique_ptr<OperationPass<func::FuncOp>> createRaiseMemrefToAffine();

/// Apply normalization transformations to affine loop-like ops. If
/// `promoteSingleIter` is true, single iteration loops are promoted (i.e., the
/// loop is replaced by its loop body).
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/Affine/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,18 @@ def LoopCoalescing : Pass<"affine-loop-coalescing", "func::FuncOp"> {
let dependentDialects = ["affine::AffineDialect","arith::ArithDialect"];
}

// Declares the `affine-raise-from-memref` pass, which runs on `func::FuncOp`.
// The pass object is built by `mlir::affine::createRaiseMemrefToAffine()`, and
// the affine dialect is listed as a dependent dialect because the pass creates
// affine operations.
def RaiseMemrefDialect : Pass<"affine-raise-from-memref", "func::FuncOp"> {
let summary = "Turn some memref operators to affine operators where supported";
let description = [{
Raise memref.load and memref.store to affine.store and affine.load, inferring
the affine map of those operators if needed. This allows passes like --affine-scalrep
to optimize those loads and stores (forwarding them or eliminating them).
They can be turned back to memref dialect ops with --lower-affine.
}];
let constructor = "mlir::affine::createRaiseMemrefToAffine()";
let dependentDialects = ["affine::AffineDialect"];
}

def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> {
let summary = "Simplify affine expressions in maps/sets and normalize "
"memrefs";
Expand Down
13 changes: 7 additions & 6 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,10 +284,12 @@ bool mlir::affine::isValidDim(Value value) {
return isValidDim(value, getAffineScope(defOp));

// This value has to be a block argument for an op that has the
// `AffineScope` trait or for an affine.for or affine.parallel.
// `AffineScope` trait or an induction var of an affine.for or
// affine.parallel.
if (isAffineInductionVar(value))
return true;
auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
isa<AffineForOp, AffineParallelOp>(parentOp));
return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
}

// Value can be used as a dimension id iff it meets one of the following
Expand All @@ -306,10 +308,9 @@ bool mlir::affine::isValidDim(Value value, Region *region) {

auto *op = value.getDefiningOp();
if (!op) {
// This value has to be a block argument for an affine.for or an
// This value has to be an induction var for an affine.for or an
// affine.parallel.
auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
return isa<AffineForOp, AffineParallelOp>(parentOp);
return isAffineInductionVar(value);
}

// Affine apply operation is ok if all of its operands are ok.
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
LoopUnroll.cpp
LoopUnrollAndJam.cpp
PipelineDataTransfer.cpp
RaiseMemrefDialect.cpp
ReifyValueBounds.cpp
SuperVectorize.cpp
SimplifyAffineStructures.cpp
Expand Down
172 changes: 172 additions & 0 deletions mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
//===- RaiseMemrefDialect.cpp - raise memref.store and load to affine ops -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functionality to convert memref load and store ops to
// the corresponding affine ops, inferring the affine map as needed.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace affine {
#define GEN_PASS_DEF_RAISEMEMREFDIALECT
#include "mlir/Dialect/Affine/Passes.h.inc"
} // namespace affine
} // namespace mlir

#define DEBUG_TYPE "raise-memref-to-affine"

using namespace mlir;
using namespace mlir::affine;

namespace {

/// Find the index of the given value in the `dims` list,
/// and append it if it was not already in the list. The
/// dims list is a list of symbols or dimensions of the
/// affine map. Within the results of an affine map, they
/// are identified by their index, which is why we need
/// this function.
static std::optional<size_t>
oowekyala marked this conversation as resolved.
Show resolved Hide resolved
findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims,
function_ref<bool(Value)> isValidElement) {

Value *loopIV = std::find(dims.begin(), dims.end(), value);
if (loopIV != dims.end()) {
// We found an IV that already has an index, return that index.
return {std::distance(dims.begin(), loopIV)};
}
if (isValidElement(value)) {
// This is a valid element for the dim/symbol list, push this as a
// parameter.
size_t idx = dims.size();
dims.push_back(value);
return idx;
}
return std::nullopt;
}

/// Convert `value` to an affine expression if possible, registering any new
/// dimension or symbol operand it depends on.
///
/// The conversion handles, in order:
///  - integer constants, which become constant expressions;
///  - `arith.addi` / `arith.muli` whose two operands convert recursively. A
///    product is only accepted when at least one factor is made of symbols
///    and constants exclusively: a product of two expressions that both
///    involve dimensions (e.g. `%i * %j`) is not an affine expression;
///  - values accepted by `affine::isValidSymbol`, which become symbols;
///  - values accepted by `affine::isValidDim`, which become dimensions.
///
/// Symbols are tried before dimensions.
///
/// \param value         the index value to convert.
/// \param affineDims    operands backing the map dimensions; may be extended.
/// \param affineSymbols operands backing the map symbols; may be extended.
/// \returns the affine expression, or a null `AffineExpr` when `value` cannot
///          be expressed. On failure the two lists may still have been
///          extended by the sub-expressions that were visited.
static AffineExpr toAffineExpr(Value value,
                               llvm::SmallVectorImpl<Value> &affineDims,
                               llvm::SmallVectorImpl<Value> &affineSymbols) {
  using namespace matchers;
  IntegerAttr::ValueType cst;
  if (matchPattern(value, m_ConstantInt(&cst)))
    return getAffineConstantExpr(cst.getSExtValue(), value.getContext());

  Value lhs;
  Value rhs;
  const bool isAdd =
      matchPattern(value, m_Op<arith::AddIOp>(m_Any(&lhs), m_Any(&rhs)));
  const bool isMul =
      !isAdd &&
      matchPattern(value, m_Op<arith::MulIOp>(m_Any(&lhs), m_Any(&rhs)));
  if (isAdd || isMul) {
    // The right-hand side is only visited when the left-hand side converted.
    AffineExpr lhsE = toAffineExpr(lhs, affineDims, affineSymbols);
    AffineExpr rhsE =
        lhsE ? toAffineExpr(rhs, affineDims, affineSymbols) : AffineExpr();
    if (lhsE && rhsE) {
      if (isAdd)
        return getAffineBinaryOpExpr(mlir::AffineExprKind::Add, lhsE, rhsE);

      // Multiplying two expressions that both depend on dimensions yields a
      // non-affine expression, so such a product must not be raised. When
      // that happens we fall through and let the checks below decide whether
      // the product as a whole is usable as an opaque symbol or dimension.
      if (lhsE.isSymbolicOrConstant() || rhsE.isSymbolicOrConstant())
        return getAffineBinaryOpExpr(mlir::AffineExprKind::Mul, lhsE, rhsE);
    }
  }

  if (auto symIx = findInListOrAdd(value, affineSymbols, [](Value v) {
        return affine::isValidSymbol(v);
      }))
    return getAffineSymbolExpr(*symIx, value.getContext());

  if (auto dimIx = findInListOrAdd(
          value, affineDims, [](Value v) { return affine::isValidDim(v); }))
    return getAffineDimExpr(*dimIx, value.getContext());

  return {};
}

/// Build the affine map and operand list that describe the access `indices`.
///
/// Every index is converted with `toAffineExpr`. If any of them cannot be
/// converted the function returns failure and leaves `map` and `mapArgs`
/// unchanged. On success `map` has one result per index, and `mapArgs`
/// contains the dimension operands followed by the symbol operands.
static LogicalResult
computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map,
                        llvm::SmallVectorImpl<Value> &mapArgs) {
  SmallVector<Value> dimOperands;
  SmallVector<Value> symbolOperands;
  SmallVector<AffineExpr> exprs;

  for (Value index : indices) {
    AffineExpr expr = toAffineExpr(index, dimOperands, symbolOperands);
    if (!expr)
      return failure();
    exprs.push_back(expr);
  }

  map = AffineMap::get(dimOperands.size(), symbolOperands.size(), exprs, ctx);

  // Map operands are ordered with all dimensions first, then all symbols.
  mapArgs.assign(dimOperands.begin(), dimOperands.end());
  mapArgs.append(symbolOperands.begin(), symbolOperands.end());
  return success();
}

/// Pass that rewrites `memref.load` / `memref.store` operations into
/// `affine.load` / `affine.store` whenever an affine map can be computed for
/// their indices. Operations whose indices cannot be expressed are left
/// untouched and reported through the debug stream.
struct RaiseMemrefDialect
    : public affine::impl::RaiseMemrefDialectBase<RaiseMemrefDialect> {

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    Operation *root = getOperation();
    IRRewriter rewriter(ctx);
    // Reused across visited operations; both are overwritten on each
    // successful call to computeAffineMapAndArgs.
    AffineMap map;
    SmallVector<Value> mapArgs;

    root->walk([&](Operation *op) {
      if (auto store = llvm::dyn_cast<memref::StoreOp>(op)) {
        if (succeeded(computeAffineMapAndArgs(ctx, store.getIndices(), map,
                                              mapArgs))) {
          // The replacement is created at the rewriter's insertion point, so
          // place it right where the original operation is.
          rewriter.setInsertionPoint(op);
          rewriter.replaceOpWithNewOp<AffineStoreOp>(
              op, store.getValueToStore(), store.getMemRef(), map, mapArgs);
          return;
        }

        // Stream the operation itself, not its address.
        LLVM_DEBUG(llvm::dbgs()
                   << "[affine] Cannot raise memref op: " << *op << "\n");
        return;
      }

      if (auto load = llvm::dyn_cast<memref::LoadOp>(op)) {
        if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map,
                                              mapArgs))) {
          rewriter.setInsertionPoint(op);
          rewriter.replaceOpWithNewOp<AffineLoadOp>(op, load.getMemRef(), map,
                                                    mapArgs);
          return;
        }

        // Stream the operation itself, not its address.
        LLVM_DEBUG(llvm::dbgs()
                   << "[affine] Cannot raise memref op: " << *op << "\n");
      }
    });
  }
};

} // namespace

/// Factory for the pass that raises memref loads and stores to their affine
/// counterparts; ownership of the new pass object is handed to the caller.
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::affine::createRaiseMemrefToAffine() {
  std::unique_ptr<OperationPass<func::FuncOp>> pass =
      std::make_unique<RaiseMemrefDialect>();
  return pass;
}
118 changes: 118 additions & 0 deletions mlir/test/Dialect/Affine/raise-memref.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -affine-raise-from-memref --canonicalize | FileCheck %s

// CHECK-LABEL: func @reduce_window_max(
func.func @reduce_window_max() {
%cst = arith.constant 0.000000e+00 : f32
%0 = memref.alloc() : memref<1x8x8x64xf32>
%1 = memref.alloc() : memref<1x18x18x64xf32>
affine.for %arg0 = 0 to 1 {
affine.for %arg1 = 0 to 8 {
affine.for %arg2 = 0 to 8 {
affine.for %arg3 = 0 to 64 {
memref.store %cst, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
}
}
}
}
affine.for %arg0 = 0 to 1 {
affine.for %arg1 = 0 to 8 {
affine.for %arg2 = 0 to 8 {
affine.for %arg3 = 0 to 64 {
affine.for %arg4 = 0 to 1 {
affine.for %arg5 = 0 to 3 {
affine.for %arg6 = 0 to 3 {
affine.for %arg7 = 0 to 1 {
%2 = memref.load %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
%21 = arith.addi %arg0, %arg4 : index
%22 = arith.constant 2 : index
%23 = arith.muli %arg1, %22 : index
%24 = arith.addi %23, %arg5 : index
%25 = arith.muli %arg2, %22 : index
%26 = arith.addi %25, %arg6 : index
%27 = arith.addi %arg3, %arg7 : index
%3 = memref.load %1[%21, %24, %26, %27] : memref<1x18x18x64xf32>
%4 = arith.cmpf ogt, %2, %3 : f32
%5 = arith.select %4, %2, %3 : f32
memref.store %5, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
}
}
}
}
}
}
}
}
return
}

// CHECK: %[[cst:.*]] = arith.constant 0
// CHECK: %[[v0:.*]] = memref.alloc() : memref<1x8x8x64xf32>
// CHECK: %[[v1:.*]] = memref.alloc() : memref<1x18x18x64xf32>
// CHECK: affine.for %[[arg0:.*]] =
// CHECK: affine.for %[[arg1:.*]] =
// CHECK: affine.for %[[arg2:.*]] =
// CHECK: affine.for %[[arg3:.*]] =
// CHECK: affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] :
// CHECK: affine.for %[[a0:.*]] =
// CHECK: affine.for %[[a1:.*]] =
// CHECK: affine.for %[[a2:.*]] =
// CHECK: affine.for %[[a3:.*]] =
// CHECK: affine.for %[[a4:.*]] =
// CHECK: affine.for %[[a5:.*]] =
// CHECK: affine.for %[[a6:.*]] =
// CHECK: affine.for %[[a7:.*]] =
// CHECK: %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] :
// CHECK: %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] :
// CHECK: %[[res:.*]] = arith.cmpf ogt, %[[lhs]], %[[rhs]] : f32
// CHECK: %[[sel:.*]] = arith.select %[[res]], %[[lhs]], %[[rhs]] : f32
// CHECK: affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] :

// CHECK-LABEL: func @symbols(
func.func @symbols(%N : index) {
oowekyala marked this conversation as resolved.
Show resolved Hide resolved
%0 = memref.alloc() : memref<1024x1024xf32>
%1 = memref.alloc() : memref<1024x1024xf32>
%2 = memref.alloc() : memref<1024x1024xf32>
%cst1 = arith.constant 1 : index
%cst2 = arith.constant 2 : index
affine.for %i = 0 to %N {
affine.for %j = 0 to %N {
%7 = memref.load %2[%i, %j] : memref<1024x1024xf32>
%10 = affine.for %k = 0 to %N iter_args(%ax = %cst1) -> index {
%12 = arith.muli %N, %cst2 : index
%13 = arith.addi %12, %cst1 : index
%14 = arith.addi %13, %j : index
%5 = memref.load %0[%i, %12] : memref<1024x1024xf32>
%6 = memref.load %1[%14, %j] : memref<1024x1024xf32>
%8 = arith.mulf %5, %6 : f32
%9 = arith.addf %7, %8 : f32
%4 = arith.addi %N, %cst1 : index
%11 = arith.addi %ax, %cst1 : index
memref.store %9, %2[%i, %4] : memref<1024x1024xf32> // this uses an expression of the symbol
memref.store %9, %2[%i, %11] : memref<1024x1024xf32> // this uses an iter_args and cannot be raised
%something = "ab.v"() : () -> index
memref.store %9, %2[%i, %something] : memref<1024x1024xf32> // this cannot be raised
affine.yield %11 : index
}
}
}
return
}

// CHECK: %[[cst1:.*]] = arith.constant 1 : index
// CHECK: %[[v0:.*]] = memref.alloc() : memref<
// CHECK: %[[v1:.*]] = memref.alloc() : memref<
// CHECK: %[[v2:.*]] = memref.alloc() : memref<
// CHECK: affine.for %[[a1:.*]] = 0 to %arg0 {
// CHECK: affine.for %[[a2:.*]] = 0 to %arg0 {
// CHECK: %[[lhs:.*]] = affine.load %{{.*}}[%[[a1]], %[[a2]]] : memref<1024x1024xf32>
// CHECK: affine.for %[[a3:.*]] = 0 to %arg0 iter_args(%[[a4:.*]] = %[[cst1]]) -> (index) {
// CHECK: %[[lhs2:.*]] = affine.load %{{.*}}[%[[a1]], symbol(%arg0) * 2] :
// CHECK: %[[lhs3:.*]] = affine.load %{{.*}}[%[[a2]] + symbol(%arg0) * 2 + 1, %[[a2]]] :
// CHECK: %[[lhs4:.*]] = arith.mulf %[[lhs2]], %[[lhs3]]
// CHECK: %[[lhs5:.*]] = arith.addf %[[lhs]], %[[lhs4]]
// CHECK: %[[lhs6:.*]] = arith.addi %[[a4]], %[[cst1]]
// CHECK: affine.store %[[lhs5]], %{{.*}}[%[[a1]], symbol(%arg0) + 1] :
// CHECK: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs6]]] :
// CHECK: %[[lhs7:.*]] = "ab.v"
// CHECK: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs7]]] :
// CHECK: affine.yield %[[lhs6]]
Loading