Skip to content

Commit

Permalink
Add --affine-raise-from-memref
Browse files Browse the repository at this point in the history
Restrict isValidDim to induction vars, and not iter_args
  • Loading branch information
oowekyala committed Nov 29, 2024
1 parent 98204a2 commit 6127ad9
Show file tree
Hide file tree
Showing 7 changed files with 335 additions and 6 deletions.
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Affine/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ namespace mlir {
namespace func {
class FuncOp;
} // namespace func
namespace memref {
class MemRefDialect;
} // namespace memref

namespace affine {
class AffineForOp;
Expand All @@ -48,6 +51,9 @@ createAffineLoopInvariantCodeMotionPass();
/// ops.
std::unique_ptr<OperationPass<func::FuncOp>> createAffineParallelizePass();

/// Creates a pass that converts some memref operators to affine operators.
std::unique_ptr<OperationPass<func::FuncOp>> createRaiseMemrefToAffine();

/// Apply normalization transformations to affine loop-like ops. If
/// `promoteSingleIter` is true, single iteration loops are promoted (i.e., the
/// loop is replaced by its loop body).
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/Affine/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,18 @@ def LoopCoalescing : Pass<"affine-loop-coalescing", "func::FuncOp"> {
let dependentDialects = ["affine::AffineDialect","arith::ArithDialect"];
}

def RaiseMemrefDialect : Pass<"affine-raise-from-memref", "func::FuncOp"> {
let summary = "Turn some memref operators to affine operators where supported";
let description = [{
Raise memref.load and memref.store to affine.store and affine.load, inferring
the affine map of those operators if needed. This allows passes like --affine-scalrep
to optimize those loads and stores (forwarding them or eliminating them).
They can be turned back to memref dialect ops with --lower-affine.
}];
let constructor = "mlir::affine::createRaiseMemrefToAffine()";
let dependentDialects = ["memref::MemRefDialect"];
}

def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> {
let summary = "Simplify affine expressions in maps/sets and normalize "
"memrefs";
Expand Down
13 changes: 7 additions & 6 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,10 +284,12 @@ bool mlir::affine::isValidDim(Value value) {
return isValidDim(value, getAffineScope(defOp));

// This value has to be a block argument for an op that has the
// `AffineScope` trait or for an affine.for or affine.parallel.
// `AffineScope` trait or an induction var of an affine.for or
// affine.parallel.
if (isAffineInductionVar(value))
return true;
auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
isa<AffineForOp, AffineParallelOp>(parentOp));
return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
}

// Value can be used as a dimension id iff it meets one of the following
Expand All @@ -306,10 +308,9 @@ bool mlir::affine::isValidDim(Value value, Region *region) {

auto *op = value.getDefiningOp();
if (!op) {
// This value has to be a block argument for an affine.for or an
// This value has to be an induction var for an affine.for or an
// affine.parallel.
auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
return isa<AffineForOp, AffineParallelOp>(parentOp);
return isAffineInductionVar(value);
}

// Affine apply operation is ok if all of its operands are ok.
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
LoopUnroll.cpp
LoopUnrollAndJam.cpp
PipelineDataTransfer.cpp
RaiseMemrefDialect.cpp
ReifyValueBounds.cpp
SuperVectorize.cpp
SimplifyAffineStructures.cpp
Expand Down
11 changes: 11 additions & 0 deletions mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,20 @@

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iterator>

using namespace mlir;
using namespace mlir::affine;
Expand Down
168 changes: 168 additions & 0 deletions mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@


#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iterator>
#include <memory>
#include <optional>

namespace mlir {
namespace affine {
#define GEN_PASS_DEF_RAISEMEMREFDIALECT
#include "mlir/Dialect/Affine/Passes.h.inc"
} // namespace affine
} // namespace mlir

#define DEBUG_TYPE "raise-memref-to-affine"

using namespace mlir;
using namespace mlir::affine;

namespace {

static std::optional<size_t>
findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims,
const std::function<bool(Value)> &isValidElement) {

Value *loopIV = std::find(dims.begin(), dims.end(), value);
if (loopIV != dims.end()) {
// found an IV that already has an index
return {std::distance(dims.begin(), loopIV)};
}
if (isValidElement(value)) {
// push this IV in the parameters
size_t idx = dims.size();
dims.push_back(value);
return idx;
}
return std::nullopt;
}

static LogicalResult toAffineExpr(Value value, AffineExpr &result,
llvm::SmallVectorImpl<Value> &affineDims,
llvm::SmallVectorImpl<Value> &affineSymbols) {
using namespace matchers;
IntegerAttr::ValueType cst;
if (matchPattern(value, m_ConstantInt(&cst))) {
result = getAffineConstantExpr(cst.getSExtValue(), value.getContext());
return success();
}
Value lhs;
Value rhs;
if (matchPattern(value, m_Op<arith::AddIOp>(m_Any(&lhs), m_Any(&rhs))) ||
matchPattern(value, m_Op<arith::MulIOp>(m_Any(&lhs), m_Any(&rhs)))) {
AffineExpr lhsE;
AffineExpr rhsE;
if (succeeded(toAffineExpr(lhs, lhsE, affineDims, affineSymbols)) &&
succeeded(toAffineExpr(rhs, rhsE, affineDims, affineSymbols))) {
AffineExprKind kind;
if (isa<arith::AddIOp>(value.getDefiningOp())) {
kind = mlir::AffineExprKind::Add;
} else {
kind = mlir::AffineExprKind::Mul;
}
result = getAffineBinaryOpExpr(kind, lhsE, rhsE);
return success();
}
}

if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) {
return affine::isValidSymbol(v);
})) {
result = getAffineSymbolExpr(*dimIx, value.getContext());
return success();
}

if (auto dimIx = findInListOrAdd(
value, affineDims, [](Value v) { return affine::isValidDim(v); })) {

result = getAffineDimExpr(*dimIx, value.getContext());
return success();
}

return failure();
}

static LogicalResult
computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map,
llvm::SmallVectorImpl<Value> &mapArgs) {
llvm::SmallVector<AffineExpr> results;
llvm::SmallVector<Value, 2> symbols;
llvm::SmallVector<Value, 8> dims;

for (auto indexExpr : indices) {
if (failed(
toAffineExpr(indexExpr, results.emplace_back(), dims, symbols))) {
return failure();
}
}

map = AffineMap::get(dims.size(), symbols.size(), results, ctx);

dims.append(symbols);
mapArgs.swap(dims);
return success();
}

struct RaiseMemrefDialect
: public affine::impl::RaiseMemrefDialectBase<RaiseMemrefDialect> {

void runOnOperation() override {
auto *ctx = &getContext();
Operation *op = getOperation();
IRRewriter rewriter(ctx);
AffineMap map;
SmallVector<Value> mapArgs;
op->walk([&](Operation *op) {
rewriter.setInsertionPoint(op);
if (auto store = llvm::dyn_cast_or_null<memref::StoreOp>(op)) {

if (succeeded(computeAffineMapAndArgs(ctx, store.getIndices(), map,
mapArgs))) {
rewriter.replaceOpWithNewOp<AffineStoreOp>(
op, store.getValueToStore(), store.getMemRef(), map, mapArgs);
} else {
LLVM_DEBUG(llvm::dbgs()
<< "[affine] Cannot raise memref op: " << op << "\n");
}

} else if (auto load = llvm::dyn_cast_or_null<memref::LoadOp>(op)) {

if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map,
mapArgs))) {
rewriter.replaceOpWithNewOp<AffineLoadOp>(op, load.getMemRef(), map,
mapArgs);
} else {
LLVM_DEBUG(llvm::dbgs()
<< "[affine] Cannot raise memref op: " << op << "\n");
}
}
});
}
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::affine::createRaiseMemrefToAffine() {
return std::make_unique<RaiseMemrefDialect>();
}
130 changes: 130 additions & 0 deletions mlir/test/Dialect/Affine/raise-memref.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -affine-raise-from-memref --canonicalize | FileCheck %s

// CHECK-LABEL: func @reduce_window_max() {
func.func @reduce_window_max() {
%cst = arith.constant 0.000000e+00 : f32
%0 = memref.alloc() : memref<1x8x8x64xf32>
%1 = memref.alloc() : memref<1x18x18x64xf32>
affine.for %arg0 = 0 to 1 {
affine.for %arg1 = 0 to 8 {
affine.for %arg2 = 0 to 8 {
affine.for %arg3 = 0 to 64 {
memref.store %cst, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
}
}
}
}
affine.for %arg0 = 0 to 1 {
affine.for %arg1 = 0 to 8 {
affine.for %arg2 = 0 to 8 {
affine.for %arg3 = 0 to 64 {
affine.for %arg4 = 0 to 1 {
affine.for %arg5 = 0 to 3 {
affine.for %arg6 = 0 to 3 {
affine.for %arg7 = 0 to 1 {
%2 = memref.load %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
%21 = arith.addi %arg0, %arg4 : index
%22 = arith.constant 2 : index
%23 = arith.muli %arg1, %22 : index
%24 = arith.addi %23, %arg5 : index
%25 = arith.muli %arg2, %22 : index
%26 = arith.addi %25, %arg6 : index
%27 = arith.addi %arg3, %arg7 : index
%3 = memref.load %1[%21, %24, %26, %27] : memref<1x18x18x64xf32>
%4 = arith.cmpf ogt, %2, %3 : f32
%5 = arith.select %4, %2, %3 : f32
memref.store %5, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
}
}
}
}
}
}
}
}
return
}

// CHECK: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[v0:.*]] = memref.alloc() : memref<1x8x8x64xf32>
// CHECK: %[[v1:.*]] = memref.alloc() : memref<1x18x18x64xf32>
// CHECK: affine.for %[[arg0:.*]] = 0 to 1 {
// CHECK: affine.for %[[arg1:.*]] = 0 to 8 {
// CHECK: affine.for %[[arg2:.*]] = 0 to 8 {
// CHECK: affine.for %[[arg3:.*]] = 0 to 64 {
// CHECK: affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] : memref<1x8x8x64xf32>
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: affine.for %[[a0:.*]] = 0 to 1 {
// CHECK: affine.for %[[a1:.*]] = 0 to 8 {
// CHECK: affine.for %[[a2:.*]] = 0 to 8 {
// CHECK: affine.for %[[a3:.*]] = 0 to 64 {
// CHECK: affine.for %[[a4:.*]] = 0 to 1 {
// CHECK: affine.for %[[a5:.*]] = 0 to 3 {
// CHECK: affine.for %[[a6:.*]] = 0 to 3 {
// CHECK: affine.for %[[a7:.*]] = 0 to 1 {
// CHECK: %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32>
// CHECK: %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] : memref<1x18x18x64xf32>
// CHECK: %[[res:.*]] = arith.cmpf ogt, %[[lhs]], %[[rhs]] : f32
// CHECK: %[[sel:.*]] = arith.select %[[res]], %[[lhs]], %[[rhs]] : f32
// CHECK: affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32>
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }

func.func @symbols(%N : index) {
%0 = memref.alloc() : memref<1024x1024xf32>
%1 = memref.alloc() : memref<1024x1024xf32>
%2 = memref.alloc() : memref<1024x1024xf32>
%cst1 = arith.constant 1 : index
%cst2 = arith.constant 2 : index
affine.for %i = 0 to %N {
affine.for %j = 0 to %N {
%7 = memref.load %2[%i, %j] : memref<1024x1024xf32>
%10 = affine.for %k = 0 to %N iter_args(%ax = %cst1) -> index {
%12 = arith.muli %N, %cst2 : index
%13 = arith.addi %12, %cst1 : index
%14 = arith.addi %13, %j : index
%5 = memref.load %0[%i, %12] : memref<1024x1024xf32>
%6 = memref.load %1[%14, %j] : memref<1024x1024xf32>
%8 = arith.mulf %5, %6 : f32
%9 = arith.addf %7, %8 : f32
%4 = arith.addi %N, %cst1 : index
%11 = arith.addi %ax, %cst1 : index
memref.store %9, %2[%i, %4] : memref<1024x1024xf32> // this uses an expression of the symbol
memref.store %9, %2[%i, %11] : memref<1024x1024xf32> // this uses an iter_args and cannot be lowered
%something = "ab.v"() : () -> index
memref.store %9, %2[%i, %something] : memref<1024x1024xf32> // this cannot be lowered
affine.yield %11 : index
}
}
}
return
}

// CHECK: %[[cst1:.*]] = arith.constant 1 : index
// CHECK: %[[v0:.*]] = memref.alloc() : memref<
// CHECK: %[[v1:.*]] = memref.alloc() : memref<
// CHECK: %[[v2:.*]] = memref.alloc() : memref<
// CHECK: affine.for %[[a1:.*]] = 0 to %arg0 {
// CHECK-NEXT: affine.for %[[a2:.*]] = 0 to %arg0 {
// CHECK-NEXT: %[[lhs:.*]] = affine.load %{{.*}}[%[[a1]], %[[a2]]] : memref<1024x1024xf32>
// CHECK-NEXT: affine.for %[[a3:.*]] = 0 to %arg0 iter_args(%[[a4:.*]] = %[[cst1]]) -> (index) {
// CHECK-NEXT: %[[lhs2:.*]] = affine.load %{{.*}}[%[[a1]], symbol(%arg0) * 2] : memref<1024x1024xf32>
// CHECK-NEXT: %[[lhs3:.*]] = affine.load %{{.*}}[%[[a2]] + symbol(%arg0) * 2 + 1, %[[a2]]] : memref<1024x1024xf32>
// CHECK-NEXT: %[[lhs4:.*]] = arith.mulf %[[lhs2]], %[[lhs3]]
// CHECK-NEXT: %[[lhs5:.*]] = arith.addf %[[lhs]], %[[lhs4]]
// CHECK-NEXT: %[[lhs6:.*]] = arith.addi %[[a4]], %[[cst1]]
// CHECK-NEXT: affine.store %[[lhs5]], %{{.*}}[%[[a1]], symbol(%arg0) + 1] : memref<1024x1024xf32>
// CHECK-NEXT: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs6]]] : memref<1024x1024xf32>
// CHECK-NEXT: %[[lhs7:.*]] = "ab.v"
// CHECK-NEXT: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs7]]] : memref<1024x1024xf32>
// CHECK-NEXT: affine.yield %[[lhs6]]

0 comments on commit 6127ad9

Please sign in to comment.