From 4ca124fbf17a4930154fd8385ccf92ee84ecf887 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= Date: Mon, 28 Oct 2024 16:51:26 +0100 Subject: [PATCH 1/4] Add --affine-raise-from-memref Restrict isValidDim to induction vars, and not iter_args --- mlir/include/mlir/Dialect/Affine/Passes.h | 6 + mlir/include/mlir/Dialect/Affine/Passes.td | 12 ++ mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 13 +- .../Dialect/Affine/Transforms/CMakeLists.txt | 1 + .../Affine/Transforms/DecomposeAffineOps.cpp | 11 ++ .../Affine/Transforms/RaiseMemrefDialect.cpp | 168 ++++++++++++++++++ mlir/test/Dialect/Affine/raise-memref.mlir | 130 ++++++++++++++ 7 files changed, 335 insertions(+), 6 deletions(-) create mode 100644 mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp create mode 100644 mlir/test/Dialect/Affine/raise-memref.mlir diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h index e152101236dc7a..c1b9c30d302dd0 100644 --- a/mlir/include/mlir/Dialect/Affine/Passes.h +++ b/mlir/include/mlir/Dialect/Affine/Passes.h @@ -22,6 +22,9 @@ namespace mlir { namespace func { class FuncOp; } // namespace func +namespace memref { +class MemRefDialect; +} // namespace memref namespace affine { class AffineForOp; @@ -48,6 +51,9 @@ createAffineLoopInvariantCodeMotionPass(); /// ops. std::unique_ptr> createAffineParallelizePass(); +/// Creates a pass that converts some memref operators to affine operators. +std::unique_ptr> createRaiseMemrefToAffine(); + /// Apply normalization transformations to affine loop-like ops. If /// `promoteSingleIter` is true, single iteration loops are promoted (i.e., the /// loop is replaced by its loop body). diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td index 77073aa29da73e..43ce9dae93441a 100644 --- a/mlir/include/mlir/Dialect/Affine/Passes.td +++ b/mlir/include/mlir/Dialect/Affine/Passes.td @@ -397,6 +397,18 @@ def LoopCoalescing : Pass<"affine-loop-coalescing", "func::FuncOp"> { let dependentDialects = ["affine::AffineDialect","arith::ArithDialect"]; } +def RaiseMemrefDialect : Pass<"affine-raise-from-memref", "func::FuncOp"> { + let summary = "Turn some memref operators to affine operators where supported"; + let description = [{ + Raise memref.load and memref.store to affine.store and affine.load, inferring + the affine map of those operators if needed. This allows passes like --affine-scalrep + to optimize those loads and stores (forwarding them or eliminating them). + They can be turned back to memref dialect ops with --lower-affine. + }]; + let constructor = "mlir::affine::createRaiseMemrefToAffine()"; + let dependentDialects = ["memref::MemRefDialect"]; +} + def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> { let summary = "Simplify affine expressions in maps/sets and normalize " "memrefs"; diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index dceebbfec586c8..06204188e14e2e 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -284,10 +284,12 @@ bool mlir::affine::isValidDim(Value value) { return isValidDim(value, getAffineScope(defOp)); // This value has to be a block argument for an op that has the - // `AffineScope` trait or for an affine.for or affine.parallel. + // `AffineScope` trait or an induction var of an affine.for or + // affine.parallel. + if (isAffineInductionVar(value)) + return true; auto *parentOp = llvm::cast(value).getOwner()->getParentOp(); - return parentOp && (parentOp->hasTrait() || - isa(parentOp)); + return parentOp && parentOp->hasTrait(); } // Value can be used as a dimension id iff it meets one of the following @@ -306,10 +308,9 @@ bool mlir::affine::isValidDim(Value value, Region *region) { auto *op = value.getDefiningOp(); if (!op) { - // This value has to be a block argument for an affine.for or an + // This value has to be an induction var for an affine.for or an // affine.parallel. - auto *parentOp = llvm::cast(value).getOwner()->getParentOp(); - return isa(parentOp); + return isAffineInductionVar(value); } // Affine apply operation is ok if all of its operands are ok. diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt index c42789b01bc9fa..1c82822b2bd7f9 100644 --- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRAffineTransforms LoopUnroll.cpp LoopUnrollAndJam.cpp PipelineDataTransfer.cpp + RaiseMemrefDialect.cpp ReifyValueBounds.cpp SuperVectorize.cpp SimplifyAffineStructures.cpp diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp index f28fb3acb7db7f..4d5ff5765ccc96 100644 --- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp @@ -13,9 +13,20 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" +#include +#include +#include +#include using namespace mlir; using namespace mlir::affine; diff --git a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp new file mode 100644 index 00000000000000..2fd47549000001 --- /dev/null +++ b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp @@ -0,0 +1,168 @@ + + +#include "mlir/Dialect/Affine/Analysis/Utils.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Affine/Transforms/Transforms.h" +#include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" +#include +#include +#include +#include +#include +#include + +namespace mlir { +namespace affine { +#define GEN_PASS_DEF_RAISEMEMREFDIALECT +#include "mlir/Dialect/Affine/Passes.h.inc" +} // namespace affine +} // namespace mlir + +#define DEBUG_TYPE "raise-memref-to-affine" + +using namespace mlir; +using namespace mlir::affine; + +namespace { + +static std::optional +findInListOrAdd(Value value, llvm::SmallVectorImpl &dims, + const std::function &isValidElement) { + + Value *loopIV = std::find(dims.begin(), dims.end(), value); + if (loopIV != dims.end()) { + // found an IV that already has an index + return {std::distance(dims.begin(), loopIV)}; + } + if (isValidElement(value)) { + // push this IV in the parameters + size_t idx = dims.size(); + dims.push_back(value); + return idx; + } + return std::nullopt; +} + +static LogicalResult toAffineExpr(Value value, AffineExpr &result, + llvm::SmallVectorImpl &affineDims, + llvm::SmallVectorImpl &affineSymbols) { + using namespace matchers; + IntegerAttr::ValueType cst; + if (matchPattern(value, m_ConstantInt(&cst))) { + result = getAffineConstantExpr(cst.getSExtValue(), value.getContext()); + return success(); + } + Value lhs; + Value rhs; + if (matchPattern(value, m_Op(m_Any(&lhs), m_Any(&rhs))) || + matchPattern(value, m_Op(m_Any(&lhs), m_Any(&rhs)))) { + AffineExpr lhsE; + AffineExpr rhsE; + if (succeeded(toAffineExpr(lhs, lhsE, affineDims, affineSymbols)) && + succeeded(toAffineExpr(rhs, rhsE, affineDims, affineSymbols))) { + AffineExprKind kind; + if (isa(value.getDefiningOp())) { + kind = mlir::AffineExprKind::Add; + } else { + kind = mlir::AffineExprKind::Mul; + } + result = getAffineBinaryOpExpr(kind, lhsE, rhsE); + return success(); + } + } + + if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) { + return affine::isValidSymbol(v); + })) { + result = getAffineSymbolExpr(*dimIx, value.getContext()); + return success(); + } + + if (auto dimIx = findInListOrAdd( + value, affineDims, [](Value v) { return affine::isValidDim(v); })) { + + result = getAffineDimExpr(*dimIx, value.getContext()); + return success(); + } + + return failure(); +} + +static LogicalResult +computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map, + llvm::SmallVectorImpl &mapArgs) { + llvm::SmallVector results; + llvm::SmallVector symbols; + llvm::SmallVector dims; + + for (auto indexExpr : indices) { + if (failed( + toAffineExpr(indexExpr, results.emplace_back(), dims, symbols))) { + return failure(); + } + } + + map = AffineMap::get(dims.size(), symbols.size(), results, ctx); + + dims.append(symbols); + mapArgs.swap(dims); + return success(); +} + +struct RaiseMemrefDialect + : public affine::impl::RaiseMemrefDialectBase { + + void runOnOperation() override { + auto *ctx = &getContext(); + Operation *op = getOperation(); + IRRewriter rewriter(ctx); + AffineMap map; + SmallVector mapArgs; + op->walk([&](Operation *op) { + rewriter.setInsertionPoint(op); + if (auto store = llvm::dyn_cast_or_null(op)) { + + if (succeeded(computeAffineMapAndArgs(ctx, store.getIndices(), map, + mapArgs))) { + rewriter.replaceOpWithNewOp( + op, store.getValueToStore(), store.getMemRef(), map, mapArgs); + } else { + LLVM_DEBUG(llvm::dbgs() + << "[affine] Cannot raise memref op: " << op << "\n"); + } + + } else if (auto load = llvm::dyn_cast_or_null(op)) { + + if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map, + mapArgs))) { + rewriter.replaceOpWithNewOp(op, load.getMemRef(), map, + mapArgs); + } else { + LLVM_DEBUG(llvm::dbgs() + << "[affine] Cannot raise memref op: " << op << "\n"); + } + } + }); + } +}; + +} // namespace + +std::unique_ptr> +mlir::affine::createRaiseMemrefToAffine() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/Affine/raise-memref.mlir b/mlir/test/Dialect/Affine/raise-memref.mlir new file mode 100644 index 00000000000000..d529e2c0c907a6 --- /dev/null +++ b/mlir/test/Dialect/Affine/raise-memref.mlir @@ -0,0 +1,130 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect -affine-raise-from-memref --canonicalize | FileCheck %s + +// CHECK-LABEL: func @reduce_window_max() { +func.func @reduce_window_max() { + %cst = arith.constant 0.000000e+00 : f32 + %0 = memref.alloc() : memref<1x8x8x64xf32> + %1 = memref.alloc() : memref<1x18x18x64xf32> + affine.for %arg0 = 0 to 1 { + affine.for %arg1 = 0 to 8 { + affine.for %arg2 = 0 to 8 { + affine.for %arg3 = 0 to 64 { + memref.store %cst, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32> + } + } + } + } + affine.for %arg0 = 0 to 1 { + affine.for %arg1 = 0 to 8 { + affine.for %arg2 = 0 to 8 { + affine.for %arg3 = 0 to 64 { + affine.for %arg4 = 0 to 1 { + affine.for %arg5 = 0 to 3 { + affine.for %arg6 = 0 to 3 { + affine.for %arg7 = 0 to 1 { + %2 = memref.load %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32> + %21 = arith.addi %arg0, %arg4 : index + %22 = arith.constant 2 : index + %23 = arith.muli %arg1, %22 : index + %24 = arith.addi %23, %arg5 : index + %25 = arith.muli %arg2, %22 : index + %26 = arith.addi %25, %arg6 : index + %27 = arith.addi %arg3, %arg7 : index + %3 = memref.load %1[%21, %24, %26, %27] : memref<1x18x18x64xf32> + %4 = arith.cmpf ogt, %2, %3 : f32 + %5 = arith.select %4, %2, %3 : f32 + memref.store %5, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32> + } + } + } + } + } + } + } + } + return +} + +// CHECK: %[[cst:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[v0:.*]] = memref.alloc() : memref<1x8x8x64xf32> +// CHECK: %[[v1:.*]] = memref.alloc() : memref<1x18x18x64xf32> +// CHECK: affine.for %[[arg0:.*]] = 0 to 1 { +// CHECK: affine.for %[[arg1:.*]] = 0 to 8 { +// CHECK: affine.for %[[arg2:.*]] = 0 to 8 { +// CHECK: affine.for %[[arg3:.*]] = 0 to 64 { +// CHECK: affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] : memref<1x8x8x64xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: affine.for %[[a0:.*]] = 0 to 1 { +// CHECK: affine.for %[[a1:.*]] = 0 to 8 { +// CHECK: affine.for %[[a2:.*]] = 0 to 8 { +// CHECK: affine.for %[[a3:.*]] = 0 to 64 { +// CHECK: affine.for %[[a4:.*]] = 0 to 1 { +// CHECK: affine.for %[[a5:.*]] = 0 to 3 { +// CHECK: affine.for %[[a6:.*]] = 0 to 3 { +// CHECK: affine.for %[[a7:.*]] = 0 to 1 { +// CHECK: %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32> +// CHECK: %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] : memref<1x18x18x64xf32> +// CHECK: %[[res:.*]] = arith.cmpf ogt, %[[lhs]], %[[rhs]] : f32 +// CHECK: %[[sel:.*]] = arith.select %[[res]], %[[lhs]], %[[rhs]] : f32 +// CHECK: affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } + +func.func @symbols(%N : index) { + %0 = memref.alloc() : memref<1024x1024xf32> + %1 = memref.alloc() : memref<1024x1024xf32> + %2 = memref.alloc() : memref<1024x1024xf32> + %cst1 = arith.constant 1 : index + %cst2 = arith.constant 2 : index + affine.for %i = 0 to %N { + affine.for %j = 0 to %N { + %7 = memref.load %2[%i, %j] : memref<1024x1024xf32> + %10 = affine.for %k = 0 to %N iter_args(%ax = %cst1) -> index { + %12 = arith.muli %N, %cst2 : index + %13 = arith.addi %12, %cst1 : index + %14 = arith.addi %13, %j : index + %5 = memref.load %0[%i, %12] : memref<1024x1024xf32> + %6 = memref.load %1[%14, %j] : memref<1024x1024xf32> + %8 = arith.mulf %5, %6 : f32 + %9 = arith.addf %7, %8 : f32 + %4 = arith.addi %N, %cst1 : index + %11 = arith.addi %ax, %cst1 : index + memref.store %9, %2[%i, %4] : memref<1024x1024xf32> // this uses an expression of the symbol + memref.store %9, %2[%i, %11] : memref<1024x1024xf32> // this uses an iter_args and cannot be lowered + %something = "ab.v"() : () -> index + memref.store %9, %2[%i, %something] : memref<1024x1024xf32> // this cannot be lowered + affine.yield %11 : index + } + } + } + return +} + +// CHECK: %[[cst1:.*]] = arith.constant 1 : index +// CHECK: %[[v0:.*]] = memref.alloc() : memref< +// CHECK: %[[v1:.*]] = memref.alloc() : memref< +// CHECK: %[[v2:.*]] = memref.alloc() : memref< +// CHECK: affine.for %[[a1:.*]] = 0 to %arg0 { +// CHECK-NEXT: affine.for %[[a2:.*]] = 0 to %arg0 { +// CHECK-NEXT: %[[lhs:.*]] = affine.load %{{.*}}[%[[a1]], %[[a2]]] : memref<1024x1024xf32> +// CHECK-NEXT: affine.for %[[a3:.*]] = 0 to %arg0 iter_args(%[[a4:.*]] = %[[cst1]]) -> (index) { +// CHECK-NEXT: %[[lhs2:.*]] = affine.load %{{.*}}[%[[a1]], symbol(%arg0) * 2] : memref<1024x1024xf32> +// CHECK-NEXT: %[[lhs3:.*]] = affine.load %{{.*}}[%[[a2]] + symbol(%arg0) * 2 + 1, %[[a2]]] : memref<1024x1024xf32> +// CHECK-NEXT: %[[lhs4:.*]] = arith.mulf %[[lhs2]], %[[lhs3]] +// CHECK-NEXT: %[[lhs5:.*]] = arith.addf %[[lhs]], %[[lhs4]] +// CHECK-NEXT: %[[lhs6:.*]] = arith.addi %[[a4]], %[[cst1]] +// CHECK-NEXT: affine.store %[[lhs5]], %{{.*}}[%[[a1]], symbol(%arg0) + 1] : memref<1024x1024xf32> +// CHECK-NEXT: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs6]]] : memref<1024x1024xf32> +// CHECK-NEXT: %[[lhs7:.*]] = "ab.v" +// CHECK-NEXT: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs7]]] : memref<1024x1024xf32> +// CHECK-NEXT: affine.yield %[[lhs6]] From 8bdd3ac0177e6b0995f0ab985444ba0ecbb2a194 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= Date: Fri, 29 Nov 2024 13:50:04 +0100 Subject: [PATCH 2/4] Address review comments --- mlir/include/mlir/Dialect/Affine/Passes.td | 2 +- .../Affine/Transforms/DecomposeAffineOps.cpp | 11 --- .../Affine/Transforms/RaiseMemrefDialect.cpp | 92 ++++++++++--------- mlir/test/Dialect/Affine/raise-memref.mlir | 80 +++++++--------- 4 files changed, 83 insertions(+), 102 deletions(-) diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td index 43ce9dae93441a..a77bcac5ed407f 100644 --- a/mlir/include/mlir/Dialect/Affine/Passes.td +++ b/mlir/include/mlir/Dialect/Affine/Passes.td @@ -406,7 +406,7 @@ def RaiseMemrefDialect : Pass<"affine-raise-from-memref", "func::FuncOp"> { They can be turned back to memref dialect ops with --lower-affine. }]; let constructor = "mlir::affine::createRaiseMemrefToAffine()"; - let dependentDialects = ["memref::MemRefDialect"]; + let dependentDialects = ["affine::AffineDialect"]; } def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> { diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp index 4d5ff5765ccc96..f28fb3acb7db7f 100644 --- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp @@ -13,20 +13,9 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Transforms/Transforms.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/LogicalResult.h" -#include -#include -#include -#include using namespace mlir; using namespace mlir::affine; diff --git a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp index 2fd47549000001..a6e961a6d64390 100644 --- a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp @@ -1,29 +1,27 @@ - +//===- RaiseMemrefDialect.cpp - raise memref.store and load to affine ops -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements functionality to convert memref load and store ops to +// the corresponding affine ops, inferring the affine map as needed. +// +//===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/Analysis/Utils.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/Affine/Transforms/Transforms.h" #include "mlir/Dialect/Affine/Utils.h" -#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/AffineExpr.h" -#include "mlir/IR/MLIRContext.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/LogicalResult.h" -#include -#include -#include -#include -#include -#include namespace mlir { namespace affine { @@ -39,17 +37,24 @@ using namespace mlir::affine; namespace { +/// Find the index of the given value in the `dims` list, +/// and append it if it was not already in the list. The +/// dims list is a list of symbols or dimensions of the +/// affine map. Within the results of an affine map, they +/// are identified by their index, which is why we need +/// this function. static std::optional findInListOrAdd(Value value, llvm::SmallVectorImpl &dims, - const std::function &isValidElement) { + function_ref isValidElement) { Value *loopIV = std::find(dims.begin(), dims.end(), value); if (loopIV != dims.end()) { - // found an IV that already has an index + // We found an IV that already has an index, return that index. return {std::distance(dims.begin(), loopIV)}; } if (isValidElement(value)) { - // push this IV in the parameters + // This is a valid element for the dim/symbol list, push this as a + // parameter. size_t idx = dims.size(); dims.push_back(value); return idx; @@ -57,14 +62,15 @@ findInListOrAdd(Value value, llvm::SmallVectorImpl &dims, return std::nullopt; } -static LogicalResult toAffineExpr(Value value, AffineExpr &result, - llvm::SmallVectorImpl &affineDims, - llvm::SmallVectorImpl &affineSymbols) { +/// Convert a value to an affine expr if possible. Adds dims and symbols +/// if needed. +static AffineExpr toAffineExpr(Value value, + llvm::SmallVectorImpl &affineDims, + llvm::SmallVectorImpl &affineSymbols) { using namespace matchers; IntegerAttr::ValueType cst; if (matchPattern(value, m_ConstantInt(&cst))) { - result = getAffineConstantExpr(cst.getSExtValue(), value.getContext()); - return success(); + return getAffineConstantExpr(cst.getSExtValue(), value.getContext()); } Value lhs; Value rhs; @@ -72,48 +78,46 @@ static LogicalResult toAffineExpr(Value value, AffineExpr &result, matchPattern(value, m_Op(m_Any(&lhs), m_Any(&rhs)))) { AffineExpr lhsE; AffineExpr rhsE; - if (succeeded(toAffineExpr(lhs, lhsE, affineDims, affineSymbols)) && - succeeded(toAffineExpr(rhs, rhsE, affineDims, affineSymbols))) { + if ((lhsE = toAffineExpr(lhs, affineDims, affineSymbols)) && + (rhsE = toAffineExpr(rhs, affineDims, affineSymbols))) { AffineExprKind kind; if (isa(value.getDefiningOp())) { kind = mlir::AffineExprKind::Add; } else { kind = mlir::AffineExprKind::Mul; } - result = getAffineBinaryOpExpr(kind, lhsE, rhsE); - return success(); + return getAffineBinaryOpExpr(kind, lhsE, rhsE); } } if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) { return affine::isValidSymbol(v); })) { - result = getAffineSymbolExpr(*dimIx, value.getContext()); - return success(); + return getAffineSymbolExpr(*dimIx, value.getContext()); } if (auto dimIx = findInListOrAdd( value, affineDims, [](Value v) { return affine::isValidDim(v); })) { - result = getAffineDimExpr(*dimIx, value.getContext()); - return success(); + return getAffineDimExpr(*dimIx, value.getContext()); } - return failure(); + return {}; } static LogicalResult computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map, llvm::SmallVectorImpl &mapArgs) { - llvm::SmallVector results; - llvm::SmallVector symbols; - llvm::SmallVector dims; + SmallVector results; + SmallVector symbols; + SmallVector dims; - for (auto indexExpr : indices) { - if (failed( - toAffineExpr(indexExpr, results.emplace_back(), dims, symbols))) { + for (Value indexExpr : indices) { + AffineExpr res = toAffineExpr(indexExpr, dims, symbols); + if (!res) { return failure(); } + results.push_back(res); } map = AffineMap::get(dims.size(), symbols.size(), results, ctx); @@ -140,21 +144,21 @@ struct RaiseMemrefDialect mapArgs))) { rewriter.replaceOpWithNewOp( op, store.getValueToStore(), store.getMemRef(), map, mapArgs); - } else { - LLVM_DEBUG(llvm::dbgs() - << "[affine] Cannot raise memref op: " << op << "\n"); + return; } - } else if (auto load = llvm::dyn_cast_or_null(op)) { + LLVM_DEBUG(llvm::dbgs() + << "[affine] Cannot raise memref op: " << op << "\n"); + } else if (auto load = llvm::dyn_cast_or_null(op)) { if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map, mapArgs))) { rewriter.replaceOpWithNewOp(op, load.getMemRef(), map, mapArgs); - } else { - LLVM_DEBUG(llvm::dbgs() - << "[affine] Cannot raise memref op: " << op << "\n"); + return; } + LLVM_DEBUG(llvm::dbgs() + << "[affine] Cannot raise memref op: " << op << "\n"); } }); } diff --git a/mlir/test/Dialect/Affine/raise-memref.mlir b/mlir/test/Dialect/Affine/raise-memref.mlir index d529e2c0c907a6..d8562c13488d89 100644 --- a/mlir/test/Dialect/Affine/raise-memref.mlir +++ b/mlir/test/Dialect/Affine/raise-memref.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt %s -allow-unregistered-dialect -affine-raise-from-memref --canonicalize | FileCheck %s -// CHECK-LABEL: func @reduce_window_max() { +// CHECK-LABEL: func @reduce_window_max( func.func @reduce_window_max() { %cst = arith.constant 0.000000e+00 : f32 %0 = memref.alloc() : memref<1x8x8x64xf32> @@ -45,41 +45,29 @@ func.func @reduce_window_max() { return } -// CHECK: %[[cst:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[cst:.*]] = arith.constant 0 // CHECK: %[[v0:.*]] = memref.alloc() : memref<1x8x8x64xf32> // CHECK: %[[v1:.*]] = memref.alloc() : memref<1x18x18x64xf32> -// CHECK: affine.for %[[arg0:.*]] = 0 to 1 { -// CHECK: affine.for %[[arg1:.*]] = 0 to 8 { -// CHECK: affine.for %[[arg2:.*]] = 0 to 8 { -// CHECK: affine.for %[[arg3:.*]] = 0 to 64 { -// CHECK: affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] : memref<1x8x8x64xf32> -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: affine.for %[[a0:.*]] = 0 to 1 { -// CHECK: affine.for %[[a1:.*]] = 0 to 8 { -// CHECK: affine.for %[[a2:.*]] = 0 to 8 { -// CHECK: affine.for %[[a3:.*]] = 0 to 64 { -// CHECK: affine.for %[[a4:.*]] = 0 to 1 { -// CHECK: affine.for %[[a5:.*]] = 0 to 3 { -// CHECK: affine.for %[[a6:.*]] = 0 to 3 { -// CHECK: affine.for %[[a7:.*]] = 0 to 1 { -// CHECK: %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32> -// CHECK: %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] : memref<1x18x18x64xf32> +// CHECK: affine.for %[[arg0:.*]] = +// CHECK: affine.for %[[arg1:.*]] = +// CHECK: affine.for %[[arg2:.*]] = +// CHECK: affine.for %[[arg3:.*]] = +// CHECK: affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] : +// CHECK: affine.for %[[a0:.*]] = +// CHECK: affine.for %[[a1:.*]] = +// CHECK: affine.for %[[a2:.*]] = +// CHECK: affine.for %[[a3:.*]] = +// CHECK: affine.for %[[a4:.*]] = +// CHECK: affine.for %[[a5:.*]] = +// CHECK: affine.for %[[a6:.*]] = +// CHECK: affine.for %[[a7:.*]] = +// CHECK: %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : +// CHECK: %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] : // CHECK: %[[res:.*]] = arith.cmpf ogt, %[[lhs]], %[[rhs]] : f32 // CHECK: %[[sel:.*]] = arith.select %[[res]], %[[lhs]], %[[rhs]] : f32 -// CHECK: affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32> -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } +// CHECK: affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : +// CHECK-LABEL: func @symbols( func.func @symbols(%N : index) { %0 = memref.alloc() : memref<1024x1024xf32> %1 = memref.alloc() : memref<1024x1024xf32> @@ -100,9 +88,9 @@ func.func @symbols(%N : index) { %4 = arith.addi %N, %cst1 : index %11 = arith.addi %ax, %cst1 : index memref.store %9, %2[%i, %4] : memref<1024x1024xf32> // this uses an expression of the symbol - memref.store %9, %2[%i, %11] : memref<1024x1024xf32> // this uses an iter_args and cannot be lowered + memref.store %9, %2[%i, %11] : memref<1024x1024xf32> // this uses an iter_args and cannot be raised %something = "ab.v"() : () -> index - memref.store %9, %2[%i, %something] : memref<1024x1024xf32> // this cannot be lowered + memref.store %9, %2[%i, %something] : memref<1024x1024xf32> // this cannot be raised affine.yield %11 : index } } @@ -115,16 +103,16 @@ func.func @symbols(%N : index) { // CHECK: %[[v1:.*]] = memref.alloc() : memref< // CHECK: %[[v2:.*]] = memref.alloc() : memref< // CHECK: affine.for %[[a1:.*]] = 0 to %arg0 { -// CHECK-NEXT: affine.for %[[a2:.*]] = 0 to %arg0 { -// CHECK-NEXT: %[[lhs:.*]] = affine.load %{{.*}}[%[[a1]], %[[a2]]] : memref<1024x1024xf32> -// CHECK-NEXT: affine.for %[[a3:.*]] = 0 to %arg0 iter_args(%[[a4:.*]] = %[[cst1]]) -> (index) { -// CHECK-NEXT: %[[lhs2:.*]] = affine.load %{{.*}}[%[[a1]], symbol(%arg0) * 2] : memref<1024x1024xf32> -// CHECK-NEXT: %[[lhs3:.*]] = affine.load %{{.*}}[%[[a2]] + symbol(%arg0) * 2 + 1, %[[a2]]] : memref<1024x1024xf32> -// CHECK-NEXT: %[[lhs4:.*]] = arith.mulf %[[lhs2]], %[[lhs3]] -// CHECK-NEXT: %[[lhs5:.*]] = arith.addf %[[lhs]], %[[lhs4]] -// CHECK-NEXT: %[[lhs6:.*]] = arith.addi %[[a4]], %[[cst1]] -// CHECK-NEXT: affine.store %[[lhs5]], %{{.*}}[%[[a1]], symbol(%arg0) + 1] : memref<1024x1024xf32> -// CHECK-NEXT: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs6]]] : memref<1024x1024xf32> -// CHECK-NEXT: %[[lhs7:.*]] = "ab.v" -// CHECK-NEXT: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs7]]] : memref<1024x1024xf32> -// CHECK-NEXT: affine.yield %[[lhs6]] +// CHECK: affine.for %[[a2:.*]] = 0 to %arg0 { +// CHECK: %[[lhs:.*]] = affine.load %{{.*}}[%[[a1]], %[[a2]]] : memref<1024x1024xf32> +// CHECK: affine.for %[[a3:.*]] = 0 to %arg0 iter_args(%[[a4:.*]] = %[[cst1]]) -> (index) { +// CHECK: %[[lhs2:.*]] = affine.load %{{.*}}[%[[a1]], symbol(%arg0) * 2] : +// CHECK: %[[lhs3:.*]] = affine.load %{{.*}}[%[[a2]] + symbol(%arg0) * 2 + 1, %[[a2]]] : +// CHECK: %[[lhs4:.*]] = arith.mulf %[[lhs2]], %[[lhs3]] +// CHECK: %[[lhs5:.*]] = arith.addf %[[lhs]], %[[lhs4]] +// CHECK: %[[lhs6:.*]] = arith.addi %[[a4]], %[[cst1]] +// CHECK: affine.store %[[lhs5]], %{{.*}}[%[[a1]], symbol(%arg0) + 1] : +// CHECK: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs6]]] : +// CHECK: %[[lhs7:.*]] = "ab.v" +// CHECK: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs7]]] : +// CHECK: affine.yield %[[lhs6]] From a75c04eeade104671742adbff797f5a8bd8915b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= Date: Thu, 5 Dec 2024 10:32:21 +0100 Subject: [PATCH 3/4] Add todo comment --- mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp index a6e961a6d64390..f11646775903fe 100644 --- a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp @@ -78,6 +78,12 @@ static AffineExpr toAffineExpr(Value value, matchPattern(value, m_Op(m_Any(&lhs), m_Any(&rhs)))) { AffineExpr lhsE; AffineExpr rhsE; + // TODO: replace recursion with explicit stack. + // For the moment this can be tolerated as we only recurse on + // arith.addi and arith.muli, so there cannot be any infinite + // recursion. The depth of these expressions should be in most + // cases very manageable, as affine expressions should be as + // simple as `a + b * c`. if ((lhsE = toAffineExpr(lhs, affineDims, affineSymbols)) && (rhsE = toAffineExpr(rhs, affineDims, affineSymbols))) { AffineExprKind kind; From 09e3ef510355da1b1a25537a464011e8a1f3cfe9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= Date: Thu, 5 Dec 2024 11:29:14 +0100 Subject: [PATCH 4/4] Add test for when the accesses are not affine expressions --- .../Affine/Transforms/RaiseMemrefDialect.cpp | 27 ++++++++++++------- mlir/test/Dialect/Affine/raise-memref.mlir | 20 ++++++++++++++ 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp index f11646775903fe..491d2e03c36bca 100644 --- a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" namespace mlir { @@ -72,28 +73,36 @@ static AffineExpr toAffineExpr(Value value, if (matchPattern(value, m_ConstantInt(&cst))) { return getAffineConstantExpr(cst.getSExtValue(), value.getContext()); } - Value lhs; - Value rhs; - if (matchPattern(value, m_Op(m_Any(&lhs), m_Any(&rhs))) || - matchPattern(value, m_Op(m_Any(&lhs), m_Any(&rhs)))) { - AffineExpr lhsE; - AffineExpr rhsE; + + Operation *definingOp = value.getDefiningOp(); + if (llvm::isa_and_nonnull(definingOp) || + llvm::isa_and_nonnull(definingOp)) { // TODO: replace recursion with explicit stack. // For the moment this can be tolerated as we only recurse on // arith.addi and arith.muli, so there cannot be any infinite // recursion. The depth of these expressions should be in most // cases very manageable, as affine expressions should be as // simple as `a + b * c`. - if ((lhsE = toAffineExpr(lhs, affineDims, affineSymbols)) && - (rhsE = toAffineExpr(rhs, affineDims, affineSymbols))) { + AffineExpr lhsE = + toAffineExpr(definingOp->getOperand(0), affineDims, affineSymbols); + AffineExpr rhsE = + toAffineExpr(definingOp->getOperand(1), affineDims, affineSymbols); + + if (lhsE && rhsE) { AffineExprKind kind; - if (isa(value.getDefiningOp())) { + if (isa(definingOp)) { kind = mlir::AffineExprKind::Add; } else { kind = mlir::AffineExprKind::Mul; + + if (!lhsE.isSymbolicOrConstant() && !rhsE.isSymbolicOrConstant()) { + // This is not an affine expression, give up. + return {}; + } } return getAffineBinaryOpExpr(kind, lhsE, rhsE); } + return {}; } if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) { diff --git a/mlir/test/Dialect/Affine/raise-memref.mlir b/mlir/test/Dialect/Affine/raise-memref.mlir index d8562c13488d89..00cc98de1f40fb 100644 --- a/mlir/test/Dialect/Affine/raise-memref.mlir +++ b/mlir/test/Dialect/Affine/raise-memref.mlir @@ -116,3 +116,23 @@ func.func @symbols(%N : index) { // CHECK: %[[lhs7:.*]] = "ab.v" // CHECK: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs7]]] : // CHECK: affine.yield %[[lhs6]] + + +// CHECK-LABEL: func @non_affine( +func.func @non_affine(%N : index) { + %2 = memref.alloc() : memref<1024x1024xf32> + affine.for %i = 0 to %N { + affine.for %j = 0 to %N { + %ij = arith.muli %i, %j : index + %7 = memref.load %2[%i, %ij] : memref<1024x1024xf32> + memref.store %7, %2[%ij, %ij] : memref<1024x1024xf32> + } + } + return +} + +// CHECK: affine.for %[[i:.*]] = +// CHECK: affine.for %[[j:.*]] = +// CHECK: %[[ij:.*]] = arith.muli %[[i]], %[[j]] +// CHECK: %[[v:.*]] = memref.load %{{.*}}[%[[i]], %[[ij]]] +// CHECK: memref.store %[[v]], %{{.*}}[%[[ij]], %[[ij]]] \ No newline at end of file