Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
oowekyala committed Nov 29, 2024
1 parent 6127ad9 commit 6eb2458
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 100 deletions.
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Affine/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def RaiseMemrefDialect : Pass<"affine-raise-from-memref", "func::FuncOp"> {
They can be turned back to memref dialect ops with --lower-affine.
}];
let constructor = "mlir::affine::createRaiseMemrefToAffine()";
let dependentDialects = ["memref::MemRefDialect"];
let dependentDialects = ["affine::AffineDialect"];
}

def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> {
Expand Down
11 changes: 0 additions & 11 deletions mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,9 @@

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iterator>

using namespace mlir;
using namespace mlir::affine;
Expand Down
92 changes: 48 additions & 44 deletions mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
Original file line number Diff line number Diff line change
@@ -1,29 +1,27 @@

//===- RaiseMemrefDialect.cpp - raise memref.store and load to affine ops -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functionality to convert memref load and store ops to
// the corresponding affine ops, inferring the affine map as needed.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iterator>
#include <memory>
#include <optional>

namespace mlir {
namespace affine {
Expand All @@ -39,81 +37,87 @@ using namespace mlir::affine;

namespace {

/// Find the index of the given value in the `dims` list,
/// and append it if it was not already in the list. The
/// dims list is a list of symbols or dimensions of the
/// affine map. Within the results of an affine map, they
/// are identified by their index, which is why we need
/// this function.
static std::optional<size_t>
findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims,
const std::function<bool(Value)> &isValidElement) {
function_ref<bool(Value)> isValidElement) {

Value *loopIV = std::find(dims.begin(), dims.end(), value);
if (loopIV != dims.end()) {
// found an IV that already has an index
// We found an IV that already has an index, return that index.
return {std::distance(dims.begin(), loopIV)};
}
if (isValidElement(value)) {
// push this IV in the parameters
// This is a valid element for the dim/symbol list, push this as a
// parameter.
size_t idx = dims.size();
dims.push_back(value);
return idx;
}
return std::nullopt;
}

static LogicalResult toAffineExpr(Value value, AffineExpr &result,
llvm::SmallVectorImpl<Value> &affineDims,
llvm::SmallVectorImpl<Value> &affineSymbols) {
/// Convert a value to an affine expr if possible. Adds dims and symbols
/// if needed.
static AffineExpr toAffineExpr(Value value,
llvm::SmallVectorImpl<Value> &affineDims,
llvm::SmallVectorImpl<Value> &affineSymbols) {
using namespace matchers;
IntegerAttr::ValueType cst;
if (matchPattern(value, m_ConstantInt(&cst))) {
result = getAffineConstantExpr(cst.getSExtValue(), value.getContext());
return success();
return getAffineConstantExpr(cst.getSExtValue(), value.getContext());
}
Value lhs;
Value rhs;
if (matchPattern(value, m_Op<arith::AddIOp>(m_Any(&lhs), m_Any(&rhs))) ||
matchPattern(value, m_Op<arith::MulIOp>(m_Any(&lhs), m_Any(&rhs)))) {
AffineExpr lhsE;
AffineExpr rhsE;
if (succeeded(toAffineExpr(lhs, lhsE, affineDims, affineSymbols)) &&
succeeded(toAffineExpr(rhs, rhsE, affineDims, affineSymbols))) {
if ((lhsE = toAffineExpr(lhs, affineDims, affineSymbols)) &&
(rhsE = toAffineExpr(rhs, affineDims, affineSymbols))) {
AffineExprKind kind;
if (isa<arith::AddIOp>(value.getDefiningOp())) {
kind = mlir::AffineExprKind::Add;
} else {
kind = mlir::AffineExprKind::Mul;
}
result = getAffineBinaryOpExpr(kind, lhsE, rhsE);
return success();
return getAffineBinaryOpExpr(kind, lhsE, rhsE);
}
}

if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) {
return affine::isValidSymbol(v);
})) {
result = getAffineSymbolExpr(*dimIx, value.getContext());
return success();
return getAffineSymbolExpr(*dimIx, value.getContext());
}

if (auto dimIx = findInListOrAdd(
value, affineDims, [](Value v) { return affine::isValidDim(v); })) {

result = getAffineDimExpr(*dimIx, value.getContext());
return success();
return getAffineDimExpr(*dimIx, value.getContext());
}

return failure();
return {};
}

static LogicalResult
computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map,
llvm::SmallVectorImpl<Value> &mapArgs) {
llvm::SmallVector<AffineExpr> results;
llvm::SmallVector<Value, 2> symbols;
llvm::SmallVector<Value, 8> dims;
SmallVector<AffineExpr> results;
SmallVector<Value> symbols;
SmallVector<Value> dims;

for (auto indexExpr : indices) {
if (failed(
toAffineExpr(indexExpr, results.emplace_back(), dims, symbols))) {
for (Value indexExpr : indices) {
AffineExpr res = toAffineExpr(indexExpr, dims, symbols);
if (!res) {
return failure();
}
results.push_back(res);
}

map = AffineMap::get(dims.size(), symbols.size(), results, ctx);
Expand All @@ -140,21 +144,21 @@ struct RaiseMemrefDialect
mapArgs))) {
rewriter.replaceOpWithNewOp<AffineStoreOp>(
op, store.getValueToStore(), store.getMemRef(), map, mapArgs);
} else {
LLVM_DEBUG(llvm::dbgs()
<< "[affine] Cannot raise memref op: " << op << "\n");
return;
}

} else if (auto load = llvm::dyn_cast_or_null<memref::LoadOp>(op)) {
LLVM_DEBUG(llvm::dbgs()
<< "[affine] Cannot raise memref op: " << op << "\n");

} else if (auto load = llvm::dyn_cast_or_null<memref::LoadOp>(op)) {
if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map,
mapArgs))) {
rewriter.replaceOpWithNewOp<AffineLoadOp>(op, load.getMemRef(), map,
mapArgs);
} else {
LLVM_DEBUG(llvm::dbgs()
<< "[affine] Cannot raise memref op: " << op << "\n");
return;
}
LLVM_DEBUG(llvm::dbgs()
<< "[affine] Cannot raise memref op: " << op << "\n");
}
});
}
Expand Down
76 changes: 32 additions & 44 deletions mlir/test/Dialect/Affine/raise-memref.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -affine-raise-from-memref --canonicalize | FileCheck %s

// CHECK-LABEL: func @reduce_window_max() {
// CHECK-LABEL: func @reduce_window_max(
func.func @reduce_window_max() {
%cst = arith.constant 0.000000e+00 : f32
%0 = memref.alloc() : memref<1x8x8x64xf32>
Expand Down Expand Up @@ -48,38 +48,26 @@ func.func @reduce_window_max() {
// CHECK: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[v0:.*]] = memref.alloc() : memref<1x8x8x64xf32>
// CHECK: %[[v1:.*]] = memref.alloc() : memref<1x18x18x64xf32>
// CHECK: affine.for %[[arg0:.*]] = 0 to 1 {
// CHECK: affine.for %[[arg1:.*]] = 0 to 8 {
// CHECK: affine.for %[[arg2:.*]] = 0 to 8 {
// CHECK: affine.for %[[arg3:.*]] = 0 to 64 {
// CHECK: affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] : memref<1x8x8x64xf32>
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: affine.for %[[a0:.*]] = 0 to 1 {
// CHECK: affine.for %[[a1:.*]] = 0 to 8 {
// CHECK: affine.for %[[a2:.*]] = 0 to 8 {
// CHECK: affine.for %[[a3:.*]] = 0 to 64 {
// CHECK: affine.for %[[a4:.*]] = 0 to 1 {
// CHECK: affine.for %[[a5:.*]] = 0 to 3 {
// CHECK: affine.for %[[a6:.*]] = 0 to 3 {
// CHECK: affine.for %[[a7:.*]] = 0 to 1 {
// CHECK: %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32>
// CHECK: %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] : memref<1x18x18x64xf32>
// CHECK: affine.for %[[arg0:.*]]
// CHECK: affine.for %[[arg1:.*]]
// CHECK: affine.for %[[arg2:.*]]
// CHECK: affine.for %[[arg3:.*]]
// CHECK: affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] :
// CHECK: affine.for %[[a0:.*]]
// CHECK: affine.for %[[a1:.*]]
// CHECK: affine.for %[[a2:.*]]
// CHECK: affine.for %[[a3:.*]]
// CHECK: affine.for %[[a4:.*]]
// CHECK: affine.for %[[a5:.*]]
// CHECK: affine.for %[[a6:.*]]
// CHECK: affine.for %[[a7:.*]]
// CHECK: %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] :
// CHECK: %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] :
// CHECK: %[[res:.*]] = arith.cmpf ogt, %[[lhs]], %[[rhs]] : f32
// CHECK: %[[sel:.*]] = arith.select %[[res]], %[[lhs]], %[[rhs]] : f32
// CHECK: affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32>
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] :

// CHECK-LABEL: func @symbols(
func.func @symbols(%N : index) {
%0 = memref.alloc() : memref<1024x1024xf32>
%1 = memref.alloc() : memref<1024x1024xf32>
Expand All @@ -100,7 +88,7 @@ func.func @symbols(%N : index) {
%4 = arith.addi %N, %cst1 : index
%11 = arith.addi %ax, %cst1 : index
memref.store %9, %2[%i, %4] : memref<1024x1024xf32> // this uses an expression of the symbol
memref.store %9, %2[%i, %11] : memref<1024x1024xf32> // this uses an iter_args and cannot be lowered
memref.store %9, %2[%i, %11] : memref<1024x1024xf32> // this uses an iter_args and cannot be raised
%something = "ab.v"() : () -> index
memref.store %9, %2[%i, %something] : memref<1024x1024xf32> // this cannot be lowered
affine.yield %11 : index
Expand All @@ -115,16 +103,16 @@ func.func @symbols(%N : index) {
// CHECK: %[[v1:.*]] = memref.alloc() : memref<
// CHECK: %[[v2:.*]] = memref.alloc() : memref<
// CHECK: affine.for %[[a1:.*]] = 0 to %arg0 {
// CHECK-NEXT: affine.for %[[a2:.*]] = 0 to %arg0 {
// CHECK-NEXT: %[[lhs:.*]] = affine.load %{{.*}}[%[[a1]], %[[a2]]] : memref<1024x1024xf32>
// CHECK-NEXT: affine.for %[[a3:.*]] = 0 to %arg0 iter_args(%[[a4:.*]] = %[[cst1]]) -> (index) {
// CHECK-NEXT: %[[lhs2:.*]] = affine.load %{{.*}}[%[[a1]], symbol(%arg0) * 2] : memref<1024x1024xf32>
// CHECK-NEXT: %[[lhs3:.*]] = affine.load %{{.*}}[%[[a2]] + symbol(%arg0) * 2 + 1, %[[a2]]] : memref<1024x1024xf32>
// CHECK-NEXT: %[[lhs4:.*]] = arith.mulf %[[lhs2]], %[[lhs3]]
// CHECK-NEXT: %[[lhs5:.*]] = arith.addf %[[lhs]], %[[lhs4]]
// CHECK-NEXT: %[[lhs6:.*]] = arith.addi %[[a4]], %[[cst1]]
// CHECK-NEXT: affine.store %[[lhs5]], %{{.*}}[%[[a1]], symbol(%arg0) + 1] : memref<1024x1024xf32>
// CHECK-NEXT: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs6]]] : memref<1024x1024xf32>
// CHECK-NEXT: %[[lhs7:.*]] = "ab.v"
// CHECK-NEXT: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs7]]] : memref<1024x1024xf32>
// CHECK-NEXT: affine.yield %[[lhs6]]
// CHECK: affine.for %[[a2:.*]] = 0 to %arg0 {
// CHECK: %[[lhs:.*]] = affine.load %{{.*}}[%[[a1]], %[[a2]]] : memref<1024x1024xf32>
// CHECK: affine.for %[[a3:.*]] = 0 to %arg0 iter_args(%[[a4:.*]] = %[[cst1]]) -> (index) {
// CHECK: %[[lhs2:.*]] = affine.load %{{.*}}[%[[a1]], symbol(%arg0) * 2] : memref<1024x1024xf32>
// CHECK: %[[lhs3:.*]] = affine.load %{{.*}}[%[[a2]] + symbol(%arg0) * 2 + 1, %[[a2]]] : memref<1024x1024xf32>
// CHECK: %[[lhs4:.*]] = arith.mulf %[[lhs2]], %[[lhs3]]
// CHECK: %[[lhs5:.*]] = arith.addf %[[lhs]], %[[lhs4]]
// CHECK: %[[lhs6:.*]] = arith.addi %[[a4]], %[[cst1]]
// CHECK: affine.store %[[lhs5]], %{{.*}}[%[[a1]], symbol(%arg0) + 1] : memref<1024x1024xf32>
// CHECK: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs6]]] : memref<1024x1024xf32>
// CHECK: %[[lhs7:.*]] = "ab.v"
// CHECK: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs7]]] : memref<1024x1024xf32>
// CHECK: affine.yield %[[lhs6]]

0 comments on commit 6eb2458

Please sign in to comment.