-
Notifications
You must be signed in to change notification settings - Fork 12.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Restrict isValidDim to induction vars, and not iter_args
- Loading branch information
Showing
7 changed files
with
335 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
168 changes: 168 additions & 0 deletions
168
mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
|
||
|
||
#include "mlir/Dialect/Affine/Analysis/Utils.h" | ||
#include "mlir/Dialect/Affine/IR/AffineOps.h" | ||
#include "mlir/Dialect/Affine/Passes.h" | ||
#include "mlir/Dialect/Affine/Transforms/Transforms.h" | ||
#include "mlir/Dialect/Affine/Utils.h" | ||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
#include "mlir/IR/AffineExpr.h" | ||
#include "mlir/IR/MLIRContext.h" | ||
#include "mlir/IR/Matchers.h" | ||
#include "mlir/IR/Operation.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "llvm/ADT/SmallVector.h" | ||
#include "llvm/Support/Casting.h" | ||
#include "llvm/Support/Debug.h" | ||
#include "llvm/Support/LogicalResult.h" | ||
#include <algorithm> | ||
#include <cstddef> | ||
#include <functional> | ||
#include <iterator> | ||
#include <memory> | ||
#include <optional> | ||
|
||
namespace mlir { | ||
namespace affine { | ||
#define GEN_PASS_DEF_RAISEMEMREFDIALECT | ||
#include "mlir/Dialect/Affine/Passes.h.inc" | ||
} // namespace affine | ||
} // namespace mlir | ||
|
||
#define DEBUG_TYPE "raise-memref-to-affine" | ||
|
||
using namespace mlir; | ||
using namespace mlir::affine; | ||
|
||
namespace { | ||
|
||
static std::optional<size_t> | ||
findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims, | ||
const std::function<bool(Value)> &isValidElement) { | ||
|
||
Value *loopIV = std::find(dims.begin(), dims.end(), value); | ||
if (loopIV != dims.end()) { | ||
// found an IV that already has an index | ||
return {std::distance(dims.begin(), loopIV)}; | ||
} | ||
if (isValidElement(value)) { | ||
// push this IV in the parameters | ||
size_t idx = dims.size(); | ||
dims.push_back(value); | ||
return idx; | ||
} | ||
return std::nullopt; | ||
} | ||
|
||
static LogicalResult toAffineExpr(Value value, AffineExpr &result, | ||
llvm::SmallVectorImpl<Value> &affineDims, | ||
llvm::SmallVectorImpl<Value> &affineSymbols) { | ||
using namespace matchers; | ||
IntegerAttr::ValueType cst; | ||
if (matchPattern(value, m_ConstantInt(&cst))) { | ||
result = getAffineConstantExpr(cst.getSExtValue(), value.getContext()); | ||
return success(); | ||
} | ||
Value lhs; | ||
Value rhs; | ||
if (matchPattern(value, m_Op<arith::AddIOp>(m_Any(&lhs), m_Any(&rhs))) || | ||
matchPattern(value, m_Op<arith::MulIOp>(m_Any(&lhs), m_Any(&rhs)))) { | ||
AffineExpr lhsE; | ||
AffineExpr rhsE; | ||
if (succeeded(toAffineExpr(lhs, lhsE, affineDims, affineSymbols)) && | ||
succeeded(toAffineExpr(rhs, rhsE, affineDims, affineSymbols))) { | ||
AffineExprKind kind; | ||
if (isa<arith::AddIOp>(value.getDefiningOp())) { | ||
kind = mlir::AffineExprKind::Add; | ||
} else { | ||
kind = mlir::AffineExprKind::Mul; | ||
} | ||
result = getAffineBinaryOpExpr(kind, lhsE, rhsE); | ||
return success(); | ||
} | ||
} | ||
|
||
if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) { | ||
return affine::isValidSymbol(v); | ||
})) { | ||
result = getAffineSymbolExpr(*dimIx, value.getContext()); | ||
return success(); | ||
} | ||
|
||
if (auto dimIx = findInListOrAdd( | ||
value, affineDims, [](Value v) { return affine::isValidDim(v); })) { | ||
|
||
result = getAffineDimExpr(*dimIx, value.getContext()); | ||
return success(); | ||
} | ||
|
||
return failure(); | ||
} | ||
|
||
static LogicalResult | ||
computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map, | ||
llvm::SmallVectorImpl<Value> &mapArgs) { | ||
llvm::SmallVector<AffineExpr> results; | ||
llvm::SmallVector<Value, 2> symbols; | ||
llvm::SmallVector<Value, 8> dims; | ||
|
||
for (auto indexExpr : indices) { | ||
if (failed( | ||
toAffineExpr(indexExpr, results.emplace_back(), dims, symbols))) { | ||
return failure(); | ||
} | ||
} | ||
|
||
map = AffineMap::get(dims.size(), symbols.size(), results, ctx); | ||
|
||
dims.append(symbols); | ||
mapArgs.swap(dims); | ||
return success(); | ||
} | ||
|
||
struct RaiseMemrefDialect | ||
: public affine::impl::RaiseMemrefDialectBase<RaiseMemrefDialect> { | ||
|
||
void runOnOperation() override { | ||
auto *ctx = &getContext(); | ||
Operation *op = getOperation(); | ||
IRRewriter rewriter(ctx); | ||
AffineMap map; | ||
SmallVector<Value> mapArgs; | ||
op->walk([&](Operation *op) { | ||
rewriter.setInsertionPoint(op); | ||
if (auto store = llvm::dyn_cast_or_null<memref::StoreOp>(op)) { | ||
|
||
if (succeeded(computeAffineMapAndArgs(ctx, store.getIndices(), map, | ||
mapArgs))) { | ||
rewriter.replaceOpWithNewOp<AffineStoreOp>( | ||
op, store.getValueToStore(), store.getMemRef(), map, mapArgs); | ||
} else { | ||
LLVM_DEBUG(llvm::dbgs() | ||
<< "[affine] Cannot raise memref op: " << op << "\n"); | ||
} | ||
|
||
} else if (auto load = llvm::dyn_cast_or_null<memref::LoadOp>(op)) { | ||
|
||
if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map, | ||
mapArgs))) { | ||
rewriter.replaceOpWithNewOp<AffineLoadOp>(op, load.getMemRef(), map, | ||
mapArgs); | ||
} else { | ||
LLVM_DEBUG(llvm::dbgs() | ||
<< "[affine] Cannot raise memref op: " << op << "\n"); | ||
} | ||
} | ||
}); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
std::unique_ptr<OperationPass<func::FuncOp>> | ||
mlir::affine::createRaiseMemrefToAffine() { | ||
return std::make_unique<RaiseMemrefDialect>(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
// RUN: mlir-opt %s -allow-unregistered-dialect -affine-raise-from-memref --canonicalize | FileCheck %s | ||
|
||
// CHECK-LABEL: func @reduce_window_max() { | ||
func.func @reduce_window_max() { | ||
%cst = arith.constant 0.000000e+00 : f32 | ||
%0 = memref.alloc() : memref<1x8x8x64xf32> | ||
%1 = memref.alloc() : memref<1x18x18x64xf32> | ||
affine.for %arg0 = 0 to 1 { | ||
affine.for %arg1 = 0 to 8 { | ||
affine.for %arg2 = 0 to 8 { | ||
affine.for %arg3 = 0 to 64 { | ||
memref.store %cst, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32> | ||
} | ||
} | ||
} | ||
} | ||
affine.for %arg0 = 0 to 1 { | ||
affine.for %arg1 = 0 to 8 { | ||
affine.for %arg2 = 0 to 8 { | ||
affine.for %arg3 = 0 to 64 { | ||
affine.for %arg4 = 0 to 1 { | ||
affine.for %arg5 = 0 to 3 { | ||
affine.for %arg6 = 0 to 3 { | ||
affine.for %arg7 = 0 to 1 { | ||
%2 = memref.load %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32> | ||
%21 = arith.addi %arg0, %arg4 : index | ||
%22 = arith.constant 2 : index | ||
%23 = arith.muli %arg1, %22 : index | ||
%24 = arith.addi %23, %arg5 : index | ||
%25 = arith.muli %arg2, %22 : index | ||
%26 = arith.addi %25, %arg6 : index | ||
%27 = arith.addi %arg3, %arg7 : index | ||
%3 = memref.load %1[%21, %24, %26, %27] : memref<1x18x18x64xf32> | ||
%4 = arith.cmpf ogt, %2, %3 : f32 | ||
%5 = arith.select %4, %2, %3 : f32 | ||
memref.store %5, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32> | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
return | ||
} | ||
|
||
// CHECK: %[[cst:.*]] = arith.constant 0.000000e+00 : f32 | ||
// CHECK: %[[v0:.*]] = memref.alloc() : memref<1x8x8x64xf32> | ||
// CHECK: %[[v1:.*]] = memref.alloc() : memref<1x18x18x64xf32> | ||
// CHECK: affine.for %[[arg0:.*]] = 0 to 1 { | ||
// CHECK: affine.for %[[arg1:.*]] = 0 to 8 { | ||
// CHECK: affine.for %[[arg2:.*]] = 0 to 8 { | ||
// CHECK: affine.for %[[arg3:.*]] = 0 to 64 { | ||
// CHECK: affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] : memref<1x8x8x64xf32> | ||
// CHECK: } | ||
// CHECK: } | ||
// CHECK: } | ||
// CHECK: } | ||
// CHECK: affine.for %[[a0:.*]] = 0 to 1 { | ||
// CHECK: affine.for %[[a1:.*]] = 0 to 8 { | ||
// CHECK: affine.for %[[a2:.*]] = 0 to 8 { | ||
// CHECK: affine.for %[[a3:.*]] = 0 to 64 { | ||
// CHECK: affine.for %[[a4:.*]] = 0 to 1 { | ||
// CHECK: affine.for %[[a5:.*]] = 0 to 3 { | ||
// CHECK: affine.for %[[a6:.*]] = 0 to 3 { | ||
// CHECK: affine.for %[[a7:.*]] = 0 to 1 { | ||
// CHECK: %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32> | ||
// CHECK: %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] : memref<1x18x18x64xf32> | ||
// CHECK: %[[res:.*]] = arith.cmpf ogt, %[[lhs]], %[[rhs]] : f32 | ||
// CHECK: %[[sel:.*]] = arith.select %[[res]], %[[lhs]], %[[rhs]] : f32 | ||
// CHECK: affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32> | ||
// CHECK: } | ||
// CHECK: } | ||
// CHECK: } | ||
// CHECK: } | ||
// CHECK: } | ||
// CHECK: } | ||
// CHECK: } | ||
// CHECK: } | ||
// CHECK: } | ||
|
||
func.func @symbols(%N : index) { | ||
%0 = memref.alloc() : memref<1024x1024xf32> | ||
%1 = memref.alloc() : memref<1024x1024xf32> | ||
%2 = memref.alloc() : memref<1024x1024xf32> | ||
%cst1 = arith.constant 1 : index | ||
%cst2 = arith.constant 2 : index | ||
affine.for %i = 0 to %N { | ||
affine.for %j = 0 to %N { | ||
%7 = memref.load %2[%i, %j] : memref<1024x1024xf32> | ||
%10 = affine.for %k = 0 to %N iter_args(%ax = %cst1) -> index { | ||
%12 = arith.muli %N, %cst2 : index | ||
%13 = arith.addi %12, %cst1 : index | ||
%14 = arith.addi %13, %j : index | ||
%5 = memref.load %0[%i, %12] : memref<1024x1024xf32> | ||
%6 = memref.load %1[%14, %j] : memref<1024x1024xf32> | ||
%8 = arith.mulf %5, %6 : f32 | ||
%9 = arith.addf %7, %8 : f32 | ||
%4 = arith.addi %N, %cst1 : index | ||
%11 = arith.addi %ax, %cst1 : index | ||
memref.store %9, %2[%i, %4] : memref<1024x1024xf32> // this uses an expression of the symbol | ||
memref.store %9, %2[%i, %11] : memref<1024x1024xf32> // this uses an iter_args and cannot be lowered | ||
%something = "ab.v"() : () -> index | ||
memref.store %9, %2[%i, %something] : memref<1024x1024xf32> // this cannot be lowered | ||
affine.yield %11 : index | ||
} | ||
} | ||
} | ||
return | ||
} | ||
|
||
// CHECK: %[[cst1:.*]] = arith.constant 1 : index | ||
// CHECK: %[[v0:.*]] = memref.alloc() : memref< | ||
// CHECK: %[[v1:.*]] = memref.alloc() : memref< | ||
// CHECK: %[[v2:.*]] = memref.alloc() : memref< | ||
// CHECK: affine.for %[[a1:.*]] = 0 to %arg0 { | ||
// CHECK-NEXT: affine.for %[[a2:.*]] = 0 to %arg0 { | ||
// CHECK-NEXT: %[[lhs:.*]] = affine.load %{{.*}}[%[[a1]], %[[a2]]] : memref<1024x1024xf32> | ||
// CHECK-NEXT: affine.for %[[a3:.*]] = 0 to %arg0 iter_args(%[[a4:.*]] = %[[cst1]]) -> (index) { | ||
// CHECK-NEXT: %[[lhs2:.*]] = affine.load %{{.*}}[%[[a1]], symbol(%arg0) * 2] : memref<1024x1024xf32> | ||
// CHECK-NEXT: %[[lhs3:.*]] = affine.load %{{.*}}[%[[a2]] + symbol(%arg0) * 2 + 1, %[[a2]]] : memref<1024x1024xf32> | ||
// CHECK-NEXT: %[[lhs4:.*]] = arith.mulf %[[lhs2]], %[[lhs3]] | ||
// CHECK-NEXT: %[[lhs5:.*]] = arith.addf %[[lhs]], %[[lhs4]] | ||
// CHECK-NEXT: %[[lhs6:.*]] = arith.addi %[[a4]], %[[cst1]] | ||
// CHECK-NEXT: affine.store %[[lhs5]], %{{.*}}[%[[a1]], symbol(%arg0) + 1] : memref<1024x1024xf32> | ||
// CHECK-NEXT: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs6]]] : memref<1024x1024xf32> | ||
// CHECK-NEXT: %[[lhs7:.*]] = "ab.v" | ||
// CHECK-NEXT: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs7]]] : memref<1024x1024xf32> | ||
// CHECK-NEXT: affine.yield %[[lhs6]] |