diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h index 61f24255f305f7..207a0098e55870 100644 --- a/mlir/include/mlir/Dialect/Affine/Passes.h +++ b/mlir/include/mlir/Dialect/Affine/Passes.h @@ -22,6 +22,9 @@ namespace mlir { namespace func { class FuncOp; } // namespace func +namespace memref { +class MemRefDialect; +} // namespace memref namespace affine { class AffineForOp; @@ -48,6 +51,9 @@ createAffineLoopInvariantCodeMotionPass(); /// ops. std::unique_ptr> createAffineParallelizePass(); +/// Creates a pass that converts some memref operators to affine operators. +std::unique_ptr> createRaiseMemrefToAffine(); + /// Apply normalization transformations to affine loop-like ops. If /// `promoteSingleIter` is true, single iteration loops are promoted (i.e., the /// loop is replaced by its loop body). diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td index b08e803345f76e..137d0d5bb595d0 100644 --- a/mlir/include/mlir/Dialect/Affine/Passes.td +++ b/mlir/include/mlir/Dialect/Affine/Passes.td @@ -397,6 +397,18 @@ def LoopCoalescing : Pass<"affine-loop-coalescing", "func::FuncOp"> { let dependentDialects = ["affine::AffineDialect","arith::ArithDialect"]; } +def RaiseMemrefDialect : Pass<"affine-raise-from-memref", "func::FuncOp"> { + let summary = "Turn some memref operators to affine operators where supported"; + let description = [{ + Raise memref.load and memref.store to affine.store and affine.load, inferring + the affine map of those operators if needed. This allows passes like --affine-scalrep + to optimize those loads and stores (forwarding them or eliminating them). + They can be turned back to memref dialect ops with --lower-affine. + }]; + let constructor = "mlir::affine::createRaiseMemrefToAffine()"; + let dependentDialects = ["memref::MemRefDialect"]; +} + def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> { let summary = "Simplify affine expressions in maps/sets and normalize " "memrefs"; diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 5e7a6b6ca883c3..ab8e404631cdfe 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -282,10 +282,12 @@ bool mlir::affine::isValidDim(Value value) { return isValidDim(value, getAffineScope(defOp)); // This value has to be a block argument for an op that has the - // `AffineScope` trait or for an affine.for or affine.parallel. + // `AffineScope` trait or an induction var of an affine.for or + // affine.parallel. + if (isAffineInductionVar(value)) + return true; auto *parentOp = llvm::cast(value).getOwner()->getParentOp(); - return parentOp && (parentOp->hasTrait() || - isa(parentOp)); + return parentOp && parentOp->hasTrait(); } // Value can be used as a dimension id iff it meets one of the following @@ -304,10 +306,9 @@ bool mlir::affine::isValidDim(Value value, Region *region) { auto *op = value.getDefiningOp(); if (!op) { - // This value has to be a block argument for an affine.for or an + // This value has to be an induction var for an affine.for or an // affine.parallel. - auto *parentOp = llvm::cast(value).getOwner()->getParentOp(); - return isa(parentOp); + return isAffineInductionVar(value); } // Affine apply operation is ok if all of its operands are ok. diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt index 772f15335d907f..d6fa7af4cad7d4 100644 --- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt @@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRAffineTransforms LoopUnroll.cpp LoopUnrollAndJam.cpp PipelineDataTransfer.cpp + RaiseMemrefDialect.cpp ReifyValueBounds.cpp SuperVectorize.cpp SimplifyAffineStructures.cpp diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp index f28fb3acb7db7f..4d5ff5765ccc96 100644 --- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp @@ -13,9 +13,20 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" +#include +#include +#include +#include using namespace mlir; using namespace mlir::affine; diff --git a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp new file mode 100644 index 00000000000000..2fd47549000001 --- /dev/null +++ b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp @@ -0,0 +1,168 @@ + + +#include "mlir/Dialect/Affine/Analysis/Utils.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Affine/Transforms/Transforms.h" +#include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" +#include +#include +#include +#include +#include +#include + +namespace mlir { +namespace affine { +#define GEN_PASS_DEF_RAISEMEMREFDIALECT +#include "mlir/Dialect/Affine/Passes.h.inc" +} // namespace affine +} // namespace mlir + +#define DEBUG_TYPE "raise-memref-to-affine" + +using namespace mlir; +using namespace mlir::affine; + +namespace { + +static std::optional +findInListOrAdd(Value value, llvm::SmallVectorImpl &dims, + const std::function &isValidElement) { + + Value *loopIV = std::find(dims.begin(), dims.end(), value); + if (loopIV != dims.end()) { + // found an IV that already has an index + return {std::distance(dims.begin(), loopIV)}; + } + if (isValidElement(value)) { + // push this IV in the parameters + size_t idx = dims.size(); + dims.push_back(value); + return idx; + } + return std::nullopt; +} + +static LogicalResult toAffineExpr(Value value, AffineExpr &result, + llvm::SmallVectorImpl &affineDims, + llvm::SmallVectorImpl &affineSymbols) { + using namespace matchers; + IntegerAttr::ValueType cst; + if (matchPattern(value, m_ConstantInt(&cst))) { + result = getAffineConstantExpr(cst.getSExtValue(), value.getContext()); + return success(); + } + Value lhs; + Value rhs; + if (matchPattern(value, m_Op(m_Any(&lhs), m_Any(&rhs))) || + matchPattern(value, m_Op(m_Any(&lhs), m_Any(&rhs)))) { + AffineExpr lhsE; + AffineExpr rhsE; + if (succeeded(toAffineExpr(lhs, lhsE, affineDims, affineSymbols)) && + succeeded(toAffineExpr(rhs, rhsE, affineDims, affineSymbols))) { + AffineExprKind kind; + if (isa(value.getDefiningOp())) { + kind = mlir::AffineExprKind::Add; + } else { + kind = mlir::AffineExprKind::Mul; + } + result = getAffineBinaryOpExpr(kind, lhsE, rhsE); + return success(); + } + } + + if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) { + return affine::isValidSymbol(v); + })) { + result = getAffineSymbolExpr(*dimIx, value.getContext()); + return success(); + } + + if (auto dimIx = findInListOrAdd( + value, affineDims, [](Value v) { return affine::isValidDim(v); })) { + + result = getAffineDimExpr(*dimIx, value.getContext()); + return success(); + } + + return failure(); +} + +static LogicalResult +computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map, + llvm::SmallVectorImpl &mapArgs) { + llvm::SmallVector results; + llvm::SmallVector symbols; + llvm::SmallVector dims; + + for (auto indexExpr : indices) { + if (failed( + toAffineExpr(indexExpr, results.emplace_back(), dims, symbols))) { + return failure(); + } + } + + map = AffineMap::get(dims.size(), symbols.size(), results, ctx); + + dims.append(symbols); + mapArgs.swap(dims); + return success(); +} + +struct RaiseMemrefDialect + : public affine::impl::RaiseMemrefDialectBase { + + void runOnOperation() override { + auto *ctx = &getContext(); + Operation *op = getOperation(); + IRRewriter rewriter(ctx); + AffineMap map; + SmallVector mapArgs; + op->walk([&](Operation *op) { + rewriter.setInsertionPoint(op); + if (auto store = llvm::dyn_cast_or_null(op)) { + + if (succeeded(computeAffineMapAndArgs(ctx, store.getIndices(), map, + mapArgs))) { + rewriter.replaceOpWithNewOp( + op, store.getValueToStore(), store.getMemRef(), map, mapArgs); + } else { + LLVM_DEBUG(llvm::dbgs() + << "[affine] Cannot raise memref op: " << op << "\n"); + } + + } else if (auto load = llvm::dyn_cast_or_null(op)) { + + if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map, + mapArgs))) { + rewriter.replaceOpWithNewOp(op, load.getMemRef(), map, + mapArgs); + } else { + LLVM_DEBUG(llvm::dbgs() + << "[affine] Cannot raise memref op: " << op << "\n"); + } + } + }); + } +}; + +} // namespace + +std::unique_ptr> +mlir::affine::createRaiseMemrefToAffine() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/Affine/raise-memref.mlir b/mlir/test/Dialect/Affine/raise-memref.mlir new file mode 100644 index 00000000000000..d529e2c0c907a6 --- /dev/null +++ b/mlir/test/Dialect/Affine/raise-memref.mlir @@ -0,0 +1,130 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect -affine-raise-from-memref --canonicalize | FileCheck %s + +// CHECK-LABEL: func @reduce_window_max() { +func.func @reduce_window_max() { + %cst = arith.constant 0.000000e+00 : f32 + %0 = memref.alloc() : memref<1x8x8x64xf32> + %1 = memref.alloc() : memref<1x18x18x64xf32> + affine.for %arg0 = 0 to 1 { + affine.for %arg1 = 0 to 8 { + affine.for %arg2 = 0 to 8 { + affine.for %arg3 = 0 to 64 { + memref.store %cst, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32> + } + } + } + } + affine.for %arg0 = 0 to 1 { + affine.for %arg1 = 0 to 8 { + affine.for %arg2 = 0 to 8 { + affine.for %arg3 = 0 to 64 { + affine.for %arg4 = 0 to 1 { + affine.for %arg5 = 0 to 3 { + affine.for %arg6 = 0 to 3 { + affine.for %arg7 = 0 to 1 { + %2 = memref.load %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32> + %21 = arith.addi %arg0, %arg4 : index + %22 = arith.constant 2 : index + %23 = arith.muli %arg1, %22 : index + %24 = arith.addi %23, %arg5 : index + %25 = arith.muli %arg2, %22 : index + %26 = arith.addi %25, %arg6 : index + %27 = arith.addi %arg3, %arg7 : index + %3 = memref.load %1[%21, %24, %26, %27] : memref<1x18x18x64xf32> + %4 = arith.cmpf ogt, %2, %3 : f32 + %5 = arith.select %4, %2, %3 : f32 + memref.store %5, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32> + } + } + } + } + } + } + } + } + return +} + +// CHECK: %[[cst:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[v0:.*]] = memref.alloc() : memref<1x8x8x64xf32> +// CHECK: %[[v1:.*]] = memref.alloc() : memref<1x18x18x64xf32> +// CHECK: affine.for %[[arg0:.*]] = 0 to 1 { +// CHECK: affine.for %[[arg1:.*]] = 0 to 8 { +// CHECK: affine.for %[[arg2:.*]] = 0 to 8 { +// CHECK: affine.for %[[arg3:.*]] = 0 to 64 { +// CHECK: affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] : memref<1x8x8x64xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: affine.for %[[a0:.*]] = 0 to 1 { +// CHECK: affine.for %[[a1:.*]] = 0 to 8 { +// CHECK: affine.for %[[a2:.*]] = 0 to 8 { +// CHECK: affine.for %[[a3:.*]] = 0 to 64 { +// CHECK: affine.for %[[a4:.*]] = 0 to 1 { +// CHECK: affine.for %[[a5:.*]] = 0 to 3 { +// CHECK: affine.for %[[a6:.*]] = 0 to 3 { +// CHECK: affine.for %[[a7:.*]] = 0 to 1 { +// CHECK: %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32> +// CHECK: %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] : memref<1x18x18x64xf32> +// CHECK: %[[res:.*]] = arith.cmpf ogt, %[[lhs]], %[[rhs]] : f32 +// CHECK: %[[sel:.*]] = arith.select %[[res]], %[[lhs]], %[[rhs]] : f32 +// CHECK: affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } + +func.func @symbols(%N : index) { + %0 = memref.alloc() : memref<1024x1024xf32> + %1 = memref.alloc() : memref<1024x1024xf32> + %2 = memref.alloc() : memref<1024x1024xf32> + %cst1 = arith.constant 1 : index + %cst2 = arith.constant 2 : index + affine.for %i = 0 to %N { + affine.for %j = 0 to %N { + %7 = memref.load %2[%i, %j] : memref<1024x1024xf32> + %10 = affine.for %k = 0 to %N iter_args(%ax = %cst1) -> index { + %12 = arith.muli %N, %cst2 : index + %13 = arith.addi %12, %cst1 : index + %14 = arith.addi %13, %j : index + %5 = memref.load %0[%i, %12] : memref<1024x1024xf32> + %6 = memref.load %1[%14, %j] : memref<1024x1024xf32> + %8 = arith.mulf %5, %6 : f32 + %9 = arith.addf %7, %8 : f32 + %4 = arith.addi %N, %cst1 : index + %11 = arith.addi %ax, %cst1 : index + memref.store %9, %2[%i, %4] : memref<1024x1024xf32> // this uses an expression of the symbol + memref.store %9, %2[%i, %11] : memref<1024x1024xf32> // this uses an iter_args and cannot be lowered + %something = "ab.v"() : () -> index + memref.store %9, %2[%i, %something] : memref<1024x1024xf32> // this cannot be lowered + affine.yield %11 : index + } + } + } + return +} + +// CHECK: %[[cst1:.*]] = arith.constant 1 : index +// CHECK: %[[v0:.*]] = memref.alloc() : memref< +// CHECK: %[[v1:.*]] = memref.alloc() : memref< +// CHECK: %[[v2:.*]] = memref.alloc() : memref< +// CHECK: affine.for %[[a1:.*]] = 0 to %arg0 { +// CHECK-NEXT: affine.for %[[a2:.*]] = 0 to %arg0 { +// CHECK-NEXT: %[[lhs:.*]] = affine.load %{{.*}}[%[[a1]], %[[a2]]] : memref<1024x1024xf32> +// CHECK-NEXT: affine.for %[[a3:.*]] = 0 to %arg0 iter_args(%[[a4:.*]] = %[[cst1]]) -> (index) { +// CHECK-NEXT: %[[lhs2:.*]] = affine.load %{{.*}}[%[[a1]], symbol(%arg0) * 2] : memref<1024x1024xf32> +// CHECK-NEXT: %[[lhs3:.*]] = affine.load %{{.*}}[%[[a2]] + symbol(%arg0) * 2 + 1, %[[a2]]] : memref<1024x1024xf32> +// CHECK-NEXT: %[[lhs4:.*]] = arith.mulf %[[lhs2]], %[[lhs3]] +// CHECK-NEXT: %[[lhs5:.*]] = arith.addf %[[lhs]], %[[lhs4]] +// CHECK-NEXT: %[[lhs6:.*]] = arith.addi %[[a4]], %[[cst1]] +// CHECK-NEXT: affine.store %[[lhs5]], %{{.*}}[%[[a1]], symbol(%arg0) + 1] : memref<1024x1024xf32> +// CHECK-NEXT: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs6]]] : memref<1024x1024xf32> +// CHECK-NEXT: %[[lhs7:.*]] = "ab.v" +// CHECK-NEXT: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs7]]] : memref<1024x1024xf32> +// CHECK-NEXT: affine.yield %[[lhs6]]