Skip to content

Commit

Permalink
[Codegen] Allow padding of dynamic allocas (#19399)
Browse files Browse the repository at this point in the history
This PR adds support for padding for allocas in the
PadDynamicAllocsPass. The padding works the same for alloca as for
alloc.

---------

Signed-off-by: Max Dawkins <[email protected]>
  • Loading branch information
Max191 authored Dec 13, 2024
1 parent 8a7b754 commit 99b600f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
21 changes: 15 additions & 6 deletions compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ static FailureOr<int64_t> getUpperBound(Value dim,
return failure();
}

static LogicalResult padAlloc(MLIRContext *context, memref::AllocOp allocOp,
template <typename AllocLikeOp>
static LogicalResult padAlloc(MLIRContext *context, AllocLikeOp allocOp,
const DataFlowSolver &solver) {
IRRewriter rewriter(context);
rewriter.setInsertionPoint(allocOp);
Expand Down Expand Up @@ -94,7 +95,7 @@ static LogicalResult padAlloc(MLIRContext *context, memref::AllocOp allocOp,
MemRefType allocType = MemRefType::get(shape, elType, AffineMap(),
allocOp.getType().getMemorySpace());
Location loc = allocOp.getLoc();
Value paddedAlloc = rewriter.create<memref::AllocOp>(loc, allocType);
Value paddedAlloc = rewriter.create<AllocLikeOp>(loc, allocType);
SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
Value subview = rewriter.create<memref::SubViewOp>(loc, paddedAlloc, offsets,
Expand All @@ -111,7 +112,6 @@ struct PadDynamicAllocPass final
void runOnOperation() override {
auto funcOp = getOperation();
MLIRContext *context = &getContext();
SmallVector<memref::AllocOp> sharedMemAllocs;

DataFlowSolver solver;
solver.load<dataflow::DeadCodeAnalysis>();
Expand All @@ -122,12 +122,21 @@ struct PadDynamicAllocPass final
}

// Collect all the alloc operations.
funcOp.walk(
[&](memref::AllocOp allocOp) { sharedMemAllocs.push_back(allocOp); });
for (memref::AllocOp alloc : sharedMemAllocs) {
SmallVector<memref::AllocOp> allocs;
funcOp.walk([&](memref::AllocOp allocOp) { allocs.push_back(allocOp); });
for (memref::AllocOp alloc : allocs) {
if (failed(padAlloc(context, alloc, solver)))
return signalPassFailure();
}

// Collect all the alloca operations.
SmallVector<memref::AllocaOp> allocas;
funcOp.walk(
[&](memref::AllocaOp allocaOp) { allocas.push_back(allocaOp); });
for (memref::AllocaOp alloca : allocas) {
if (failed(padAlloc(context, alloca, solver)))
return signalPassFailure();
}
}
};
} // namespace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,14 @@ func.func @dynamic_bound_alloc(%id : index) {
return
}
// CHECK-LABEL: func @dynamic_bound_alloc(
// CHECK: %alloc = memref.alloc() : memref<4088xf32, 3>
// CHECK: memref.alloc() : memref<4088xf32, 3>

// -----

func.func @dynamic_bound_alloca(%id : index) {
%0 = util.assume.int %id<umin = 0, umax = 4088> : index
%1 = memref.alloca(%0) : memref<?xf32, 3>
return
}
// CHECK-LABEL: func @dynamic_bound_alloca(
// CHECK: memref.alloca() : memref<4088xf32, 3>

0 comments on commit 99b600f

Please sign in to comment.