From 99b600f729c32c566ed5a819b85de252d8c64bbb Mon Sep 17 00:00:00 2001 From: Max191 <44243577+Max191@users.noreply.github.com> Date: Fri, 13 Dec 2024 12:19:59 -0500 Subject: [PATCH] [Codegen] Allow padding of dynamic allocas (#19399) This PR adds support for padding for allocas in the PadDynamicAllocsPass. The padding works the same for alloca as for alloc. --------- Signed-off-by: Max Dawkins --- .../Codegen/Common/PadDynamicAlloc.cpp | 21 +++++++++++++------ .../Common/test/pad_dynamic_alloc.mlir | 12 ++++++++++- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp b/compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp index db1819b017e6..2a9d2d39bd75 100644 --- a/compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp @@ -65,7 +65,8 @@ static FailureOr getUpperBound(Value dim, return failure(); } -static LogicalResult padAlloc(MLIRContext *context, memref::AllocOp allocOp, +template +static LogicalResult padAlloc(MLIRContext *context, AllocLikeOp allocOp, const DataFlowSolver &solver) { IRRewriter rewriter(context); rewriter.setInsertionPoint(allocOp); @@ -94,7 +95,7 @@ static LogicalResult padAlloc(MLIRContext *context, memref::AllocOp allocOp, MemRefType allocType = MemRefType::get(shape, elType, AffineMap(), allocOp.getType().getMemorySpace()); Location loc = allocOp.getLoc(); - Value paddedAlloc = rewriter.create(loc, allocType); + Value paddedAlloc = rewriter.create(loc, allocType); SmallVector offsets(shape.size(), rewriter.getIndexAttr(0)); SmallVector strides(shape.size(), rewriter.getIndexAttr(1)); Value subview = rewriter.create(loc, paddedAlloc, offsets, @@ -111,7 +112,6 @@ struct PadDynamicAllocPass final void runOnOperation() override { auto funcOp = getOperation(); MLIRContext *context = &getContext(); - SmallVector sharedMemAllocs; DataFlowSolver solver; solver.load(); @@ -122,12 +122,21 @@ struct PadDynamicAllocPass final } // Collect all the alloc operations. - funcOp.walk( - [&](memref::AllocOp allocOp) { sharedMemAllocs.push_back(allocOp); }); - for (memref::AllocOp alloc : sharedMemAllocs) { + SmallVector allocs; + funcOp.walk([&](memref::AllocOp allocOp) { allocs.push_back(allocOp); }); + for (memref::AllocOp alloc : allocs) { if (failed(padAlloc(context, alloc, solver))) return signalPassFailure(); } + + // Collect all the alloca operations. + SmallVector allocas; + funcOp.walk( + [&](memref::AllocaOp allocaOp) { allocas.push_back(allocaOp); }); + for (memref::AllocaOp alloca : allocas) { + if (failed(padAlloc(context, alloca, solver))) + return signalPassFailure(); + } } }; } // namespace diff --git a/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir b/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir index 0b56bd2cb963..e9d4d7b82181 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir @@ -37,4 +37,14 @@ func.func @dynamic_bound_alloc(%id : index) { return } // CHECK-LABEL: func @dynamic_bound_alloc( -// CHECK: %alloc = memref.alloc() : memref<4088xf32, 3> +// CHECK: memref.alloc() : memref<4088xf32, 3> + +// ----- + +func.func @dynamic_bound_alloca(%id : index) { + %0 = util.assume.int %id : index + %1 = memref.alloca(%0) : memref + return +} +// CHECK-LABEL: func @dynamic_bound_alloca( +// CHECK: memref.alloca() : memref<4088xf32, 3>