diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
index e152101236dc7a..c1b9c30d302dd0 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -22,6 +22,9 @@ namespace mlir {
 namespace func {
 class FuncOp;
 } // namespace func
+namespace memref {
+class MemRefDialect;
+} // namespace memref
 
 namespace affine {
 class AffineForOp;
@@ -48,6 +51,9 @@ createAffineLoopInvariantCodeMotionPass();
 /// ops.
 std::unique_ptr<OperationPass<func::FuncOp>> createAffineParallelizePass();
 
+/// Creates a pass that converts some memref operators to affine operators.
+std::unique_ptr<OperationPass<func::FuncOp>> createRaiseMemrefToAffine();
+
 /// Apply normalization transformations to affine loop-like ops. If
 /// `promoteSingleIter` is true, single iteration loops are promoted (i.e., the
 /// loop is replaced by its loop body).
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index 77073aa29da73e..a77bcac5ed407f 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -397,6 +397,18 @@ def LoopCoalescing : Pass<"affine-loop-coalescing", "func::FuncOp"> {
   let dependentDialects = ["affine::AffineDialect","arith::ArithDialect"];
 }
 
+def RaiseMemrefDialect : Pass<"affine-raise-from-memref", "func::FuncOp"> {
+  let summary = "Turn some memref operators to affine operators where supported";
+  let description = [{
+    Raise memref.load and memref.store to affine.load and affine.store, inferring
+    the affine map of those operators if needed. This allows passes like --affine-scalrep
+    to optimize those loads and stores (forwarding them or eliminating them).
+    They can be turned back to memref dialect ops with --lower-affine.
+  }];
+  let constructor = "mlir::affine::createRaiseMemrefToAffine()";
+  let dependentDialects = ["affine::AffineDialect"];
+}
+
 def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> {
   let summary = "Simplify affine expressions in maps/sets and normalize "
                 "memrefs";
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index dceebbfec586c8..06204188e14e2e 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -284,10 +284,12 @@ bool mlir::affine::isValidDim(Value value) {
     return isValidDim(value, getAffineScope(defOp));
 
   // This value has to be a block argument for an op that has the
-  // `AffineScope` trait or for an affine.for or affine.parallel.
+  // `AffineScope` trait or an induction var of an affine.for or
+  // affine.parallel.
+  if (isAffineInductionVar(value))
+    return true;
   auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
-  return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
-                      isa<AffineForOp, AffineParallelOp>(parentOp));
+  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
 }
 
 // Value can be used as a dimension id iff it meets one of the following
@@ -306,10 +308,9 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
 
   auto *op = value.getDefiningOp();
   if (!op) {
-    // This value has to be a block argument for an affine.for or an
+    // This value has to be an induction var for an affine.for or an
     // affine.parallel.
-    auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
-    return isa<AffineForOp, AffineParallelOp>(parentOp);
+    return isAffineInductionVar(value);
   }
 
   // Affine apply operation is ok if all of its operands are ok.
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index c42789b01bc9fa..1c82822b2bd7f9 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
   LoopUnroll.cpp
   LoopUnrollAndJam.cpp
   PipelineDataTransfer.cpp
+  RaiseMemrefDialect.cpp
   ReifyValueBounds.cpp
   SuperVectorize.cpp
   SimplifyAffineStructures.cpp
diff --git a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
new file mode 100644
index 00000000000000..491d2e03c36bca
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
@@ -0,0 +1,192 @@
+//===- RaiseMemrefDialect.cpp - raise memref.store and load to affine ops -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements functionality to convert memref load and store ops to
+// the corresponding affine ops, inferring the affine map as needed.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/Analysis/Utils.h"
+#include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace affine {
+#define GEN_PASS_DEF_RAISEMEMREFDIALECT
+#include "mlir/Dialect/Affine/Passes.h.inc"
+} // namespace affine
+} // namespace mlir
+
+#define DEBUG_TYPE "raise-memref-to-affine"
+
+using namespace mlir;
+using namespace mlir::affine;
+
+namespace {
+
+/// Find the index of the given value in the `dims` list,
+/// and append it if it was not already in the list. The
+/// dims list is a list of symbols or dimensions of the
+/// affine map. Within the results of an affine map, they
+/// are identified by their index, which is why we need
+/// this function.
+static std::optional<size_t>
+findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims,
+                function_ref<bool(Value)> isValidElement) {
+
+  Value *loopIV = std::find(dims.begin(), dims.end(), value);
+  if (loopIV != dims.end()) {
+    // We found an IV that already has an index, return that index.
+    return {std::distance(dims.begin(), loopIV)};
+  }
+  if (isValidElement(value)) {
+    // This is a valid element for the dim/symbol list, push this as a
+    // parameter.
+    size_t idx = dims.size();
+    dims.push_back(value);
+    return idx;
+  }
+  return std::nullopt;
+}
+
+/// Convert a value to an affine expr if possible. Adds dims and symbols
+/// if needed.
+static AffineExpr toAffineExpr(Value value,
+                               llvm::SmallVectorImpl<Value> &affineDims,
+                               llvm::SmallVectorImpl<Value> &affineSymbols) {
+  using namespace matchers;
+  IntegerAttr::ValueType cst;
+  if (matchPattern(value, m_ConstantInt(&cst))) {
+    return getAffineConstantExpr(cst.getSExtValue(), value.getContext());
+  }
+
+  Operation *definingOp = value.getDefiningOp();
+  if (llvm::isa_and_nonnull<arith::AddIOp>(definingOp) ||
+      llvm::isa_and_nonnull<arith::MulIOp>(definingOp)) {
+    // TODO: replace recursion with explicit stack.
+    // For the moment this can be tolerated as we only recurse on
+    // arith.addi and arith.muli, so there cannot be any infinite
+    // recursion. The depth of these expressions should be in most
+    // cases very manageable, as affine expressions should be as
+    // simple as `a + b * c`.
+    AffineExpr lhsE =
+        toAffineExpr(definingOp->getOperand(0), affineDims, affineSymbols);
+    AffineExpr rhsE =
+        toAffineExpr(definingOp->getOperand(1), affineDims, affineSymbols);
+
+    if (lhsE && rhsE) {
+      AffineExprKind kind;
+      if (isa<arith::AddIOp>(definingOp)) {
+        kind = mlir::AffineExprKind::Add;
+      } else {
+        kind = mlir::AffineExprKind::Mul;
+
+        if (!lhsE.isSymbolicOrConstant() && !rhsE.isSymbolicOrConstant()) {
+          // This is not an affine expression, give up.
+          return {};
+        }
+      }
+      return getAffineBinaryOpExpr(kind, lhsE, rhsE);
+    }
+    return {};
+  }
+
+  if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) {
+        return affine::isValidSymbol(v);
+      })) {
+    return getAffineSymbolExpr(*dimIx, value.getContext());
+  }
+
+  if (auto dimIx = findInListOrAdd(
+          value, affineDims, [](Value v) { return affine::isValidDim(v); })) {
+
+    return getAffineDimExpr(*dimIx, value.getContext());
+  }
+
+  return {};
+}
+
+/// Try to express every index in `indices` as an affine expression. On
+/// success, `map` holds one result per index and `mapArgs` holds the map
+/// operands: all dims first, followed by all symbols.
+static LogicalResult
+computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map,
+                        llvm::SmallVectorImpl<Value> &mapArgs) {
+  SmallVector<AffineExpr> results;
+  SmallVector<Value> symbols;
+  SmallVector<Value> dims;
+
+  for (Value indexExpr : indices) {
+    AffineExpr res = toAffineExpr(indexExpr, dims, symbols);
+    if (!res) {
+      return failure();
+    }
+    results.push_back(res);
+  }
+
+  map = AffineMap::get(dims.size(), symbols.size(), results, ctx);
+
+  dims.append(symbols);
+  mapArgs.swap(dims);
+  return success();
+}
+
+/// Pass that rewrites memref.load/memref.store into affine.load/affine.store
+/// whenever all of their indices can be expressed as affine expressions.
+struct RaiseMemrefDialect
+    : public affine::impl::RaiseMemrefDialectBase<RaiseMemrefDialect> {
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    Operation *op = getOperation();
+    IRRewriter rewriter(ctx);
+    AffineMap map;
+    SmallVector<Value> mapArgs;
+    op->walk([&](Operation *op) {
+      rewriter.setInsertionPoint(op);
+      if (auto store = llvm::dyn_cast<memref::StoreOp>(op)) {
+
+        if (succeeded(computeAffineMapAndArgs(ctx, store.getIndices(), map,
+                                              mapArgs))) {
+          rewriter.replaceOpWithNewOp<AffineStoreOp>(
+              op, store.getValueToStore(), store.getMemRef(), map, mapArgs);
+          return;
+        }
+
+        LLVM_DEBUG(llvm::dbgs()
+                   << "[affine] Cannot raise memref op: " << *op << "\n");
+
+      } else if (auto load = llvm::dyn_cast<memref::LoadOp>(op)) {
+        if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map,
+                                              mapArgs))) {
+          rewriter.replaceOpWithNewOp<AffineLoadOp>(op, load.getMemRef(), map,
+                                                    mapArgs);
+          return;
+        }
+        LLVM_DEBUG(llvm::dbgs()
+                   << "[affine] Cannot raise memref op: " << *op << "\n");
+      }
+    });
+  }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::affine::createRaiseMemrefToAffine() {
+  return std::make_unique<RaiseMemrefDialect>();
+}
diff --git a/mlir/test/Dialect/Affine/raise-memref.mlir b/mlir/test/Dialect/Affine/raise-memref.mlir
new file mode 100644
index 00000000000000..00cc98de1f40fb
--- /dev/null
+++ b/mlir/test/Dialect/Affine/raise-memref.mlir
@@ -0,0 +1,138 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -affine-raise-from-memref --canonicalize | FileCheck %s
+
+// CHECK-LABEL: func @reduce_window_max(
+func.func @reduce_window_max() {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = memref.alloc() : memref<1x8x8x64xf32>
+  %1 = memref.alloc() : memref<1x18x18x64xf32>
+  affine.for %arg0 = 0 to 1 {
+    affine.for %arg1 = 0 to 8 {
+      affine.for %arg2 = 0 to 8 {
+        affine.for %arg3 = 0 to 64 {
+          memref.store %cst, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
+        }
+      }
+    }
+  }
+  affine.for %arg0 = 0 to 1 {
+    affine.for %arg1 = 0 to 8 {
+      affine.for %arg2 = 0 to 8 {
+        affine.for %arg3 = 0 to 64 {
+          affine.for %arg4 = 0 to 1 {
+            affine.for %arg5 = 0 to 3 {
+              affine.for %arg6 = 0 to 3 {
+                affine.for %arg7 = 0 to 1 {
+                  %2 = memref.load %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
+                  %21 = arith.addi %arg0, %arg4 : index
+                  %22 = arith.constant 2 : index
+                  %23 = arith.muli %arg1, %22 : index
+                  %24 = arith.addi %23, %arg5 : index
+                  %25 = arith.muli %arg2, %22 : index
+                  %26 = arith.addi %25, %arg6 : index
+                  %27 = arith.addi %arg3, %arg7 : index
+                  %3 = memref.load %1[%21, %24, %26, %27] : memref<1x18x18x64xf32>
+                  %4 = arith.cmpf ogt, %2, %3 : f32
+                  %5 = arith.select %4, %2, %3 : f32
+                  memref.store %5, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return
+}
+
+// CHECK: %[[cst:.*]] = arith.constant 0
+// CHECK: %[[v0:.*]] = memref.alloc() : memref<1x8x8x64xf32>
+// CHECK: %[[v1:.*]] = memref.alloc() : memref<1x18x18x64xf32>
+// CHECK: affine.for %[[arg0:.*]] =
+// CHECK: affine.for %[[arg1:.*]] =
+// CHECK: affine.for %[[arg2:.*]] =
+// CHECK: affine.for %[[arg3:.*]] =
+// CHECK: affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] :
+// CHECK: affine.for %[[a0:.*]] =
+// CHECK: affine.for %[[a1:.*]] =
+// CHECK: affine.for %[[a2:.*]] =
+// CHECK: affine.for %[[a3:.*]] =
+// CHECK: affine.for %[[a4:.*]] =
+// CHECK: affine.for %[[a5:.*]] =
+// CHECK: affine.for %[[a6:.*]] =
+// CHECK: affine.for %[[a7:.*]] =
+// CHECK: %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] :
+// CHECK: %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] :
+// CHECK: %[[res:.*]] = arith.cmpf ogt, %[[lhs]], %[[rhs]] : f32
+// CHECK: %[[sel:.*]] = arith.select %[[res]], %[[lhs]], %[[rhs]] : f32
+// CHECK: affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] :
+
+// CHECK-LABEL: func @symbols(
+func.func @symbols(%N : index) {
+  %0 = memref.alloc() : memref<1024x1024xf32>
+  %1 = memref.alloc() : memref<1024x1024xf32>
+  %2 = memref.alloc() : memref<1024x1024xf32>
+  %cst1 = arith.constant 1 : index
+  %cst2 = arith.constant 2 : index
+  affine.for %i = 0 to %N {
+    affine.for %j = 0 to %N {
+      %7 = memref.load %2[%i, %j] : memref<1024x1024xf32>
+      %10 = affine.for %k = 0 to %N iter_args(%ax = %cst1) -> index {
+        %12 = arith.muli %N, %cst2 : index
+        %13 = arith.addi %12, %cst1 : index
+        %14 = arith.addi %13, %j : index
+        %5 = memref.load %0[%i, %12] : memref<1024x1024xf32>
+        %6 = memref.load %1[%14, %j] : memref<1024x1024xf32>
+        %8 = arith.mulf %5, %6 : f32
+        %9 = arith.addf %7, %8 : f32
+        %4 = arith.addi %N, %cst1 : index
+        %11 = arith.addi %ax, %cst1 : index
+        memref.store %9, %2[%i, %4] : memref<1024x1024xf32> // this uses an expression of the symbol
+        memref.store %9, %2[%i, %11] : memref<1024x1024xf32> // this uses an iter_args and cannot be raised
+        %something = "ab.v"() : () -> index
+        memref.store %9, %2[%i, %something] : memref<1024x1024xf32> // this cannot be raised
+        affine.yield %11 : index
+      }
+    }
+  }
+  return
+}
+
+// CHECK: %[[cst1:.*]] = arith.constant 1 : index
+// CHECK: %[[v0:.*]] = memref.alloc() : memref<
+// CHECK: %[[v1:.*]] = memref.alloc() : memref<
+// CHECK: %[[v2:.*]] = memref.alloc() : memref<
+// CHECK: affine.for %[[a1:.*]] = 0 to %arg0 {
+// CHECK: affine.for %[[a2:.*]] = 0 to %arg0 {
+// CHECK: %[[lhs:.*]] = affine.load %{{.*}}[%[[a1]], %[[a2]]] : memref<1024x1024xf32>
+// CHECK: affine.for %[[a3:.*]] = 0 to %arg0 iter_args(%[[a4:.*]] = %[[cst1]]) -> (index) {
+// CHECK: %[[lhs2:.*]] = affine.load %{{.*}}[%[[a1]], symbol(%arg0) * 2] :
+// CHECK: %[[lhs3:.*]] = affine.load %{{.*}}[%[[a2]] + symbol(%arg0) * 2 + 1, %[[a2]]] :
+// CHECK: %[[lhs4:.*]] = arith.mulf %[[lhs2]], %[[lhs3]]
+// CHECK: %[[lhs5:.*]] = arith.addf %[[lhs]], %[[lhs4]]
+// CHECK: %[[lhs6:.*]] = arith.addi %[[a4]], %[[cst1]]
+// CHECK: affine.store %[[lhs5]], %{{.*}}[%[[a1]], symbol(%arg0) + 1] :
+// CHECK: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs6]]] :
+// CHECK: %[[lhs7:.*]] = "ab.v"
+// CHECK: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs7]]] :
+// CHECK: affine.yield %[[lhs6]]
+
+
+// CHECK-LABEL: func @non_affine(
+func.func @non_affine(%N : index) {
+  %2 = memref.alloc() : memref<1024x1024xf32>
+  affine.for %i = 0 to %N {
+    affine.for %j = 0 to %N {
+      %ij = arith.muli %i, %j : index
+      %7 = memref.load %2[%i, %ij] : memref<1024x1024xf32>
+      memref.store %7, %2[%ij, %ij] : memref<1024x1024xf32>
+    }
+  }
+  return
+}
+
+// CHECK: affine.for %[[i:.*]] =
+// CHECK: affine.for %[[j:.*]] =
+// CHECK: %[[ij:.*]] = arith.muli %[[i]], %[[j]]
+// CHECK: %[[v:.*]] = memref.load %{{.*}}[%[[i]], %[[ij]]]
+// CHECK: memref.store %[[v]], %{{.*}}[%[[ij]], %[[ij]]]