-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][memref][transform] Add new alloca_to_global op. #66511
[mlir][memref][transform] Add new alloca_to_global op. #66511
Conversation
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir-core ChangesThis PR adds a new transform op that replaces `memref.alloca`s with `memref.get_global`s to newly inserted `memref.global`s. This is useful, for example, for allocations that should reside in the shared memory of a GPU, which have to be declared as globals. -- Full diff: https://github.com//pull/66511.diff7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td index 681759f970cb910..6a78784d74dd53c 100644 --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -144,6 +144,71 @@ def ApplyResolveRankedShapedTypeResultDimsPatternsOp : Op<Transform_Dialect, } def Transform_MemRefAllocOp : Transform_ConcreteOpType<"memref.alloc">; +def Transform_MemRefAllocaOp : Transform_ConcreteOpType<"memref.alloca">; + +def MemRefAllocaToGlobalOp : + Op<Transform_Dialect, "memref.alloca_to_global", + [TransformOpInterface, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, + DeclareOpInterfaceMethods<TransformOpInterface>]> { + let description = [{ + Inserts a new `memref.global` for each provided `memref.alloca` into the + provided module and replaces it with a `memref.get_global`. This is useful, + for example, for allocations that should reside in the shared memory of + a GPU, which have to be declared as globals. + + #### Example + + Consider the following transform op: + + ```mlir + %get_global, %global = + transform.memref.alloca_to_global %alloca in %module + : (!transform.op<"builtin.module">, !transform.op<"memref.alloca">) + -> (!transform.any_op, !transform.any_op) + ``` + + and the following input payload: + + ```mlir + module { + func.func @func() { + %alloca = memref.alloca() : memref<2x32xf32> + // usages of %alloca... + } + } + ``` + + then applying the transform op to the payload would result in the following + output IR: + + ```mlir + module { + memref.global "private" @alloc : memref<2x32xf32> + func.func @func() { + %alloca = memref.get_global @alloc : memref<2x32xf32> + // usages of %alloca... + } + } + ``` + + #### Return modes + + Emits a definite failure if not exactly one `module` payload op was provided + or any of the `alloca` payload ops is not inside that module, and succeeds + otherwise. The returned handles refer to the `memref.get_global` and + `memref.global` ops that were inserted by the transformation. + }]; + + let arguments = (ins Transform_ConcreteOpType<"builtin.module">:$module, + Transform_MemRefAllocaOp:$alloca); + let results = (outs TransformHandleTypeInterface:$get_global, + TransformHandleTypeInterface:$global); + + let assemblyFormat = [{ + $alloca `in` $module attr-dict `:` functional-type(operands, results) + }]; +} def MemRefMultiBufferOp : Op<Transform_Dialect, "memref.multibuffer", [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface, diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp index 58f4d8d8f6d21fe..7467359da83c37f 100644 --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -126,6 +126,96 @@ void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp:: memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); } +//===----------------------------------------------------------------------===// +// AllocaToGlobalOp +//===----------------------------------------------------------------------===// + +namespace { +static llvm::SmallString<64> getUniqueSymbol(llvm::StringRef prefix, + ModuleOp module) { + llvm::SmallString<64> candidateNameStorage; + StringRef candidateName(prefix); + int uniqueNumber = 0; + while (true) { + if (!module.lookupSymbol(candidateName)) { + break; + } + candidateNameStorage.clear(); + candidateName = (prefix + Twine("_") + Twine(uniqueNumber)) + .toStringRef(candidateNameStorage); + uniqueNumber++; + } + return candidateName; +} +} // namespace + +DiagnosedSilenceableFailure +transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto allocaOps = state.getPayloadOps(getAlloca()); + + SmallVector<memref::GlobalOp> globalOps; + SmallVector<memref::GetGlobalOp> getGlobalOps; + + // Get `builtin.module`. + auto moduleOps = state.getPayloadOps(getModule()); + if (!llvm::hasSingleElement(moduleOps)) { + return emitDefiniteFailure() + << Twine("expected exactly one 'module' payload, but found ") + + std::to_string(llvm::range_size(moduleOps)); + } + ModuleOp module = cast<ModuleOp>(*moduleOps.begin()); + + // Transform `memref.alloca`s. + for (auto *op : allocaOps) { + auto alloca = cast<memref::AllocaOp>(op); + MLIRContext *ctx = rewriter.getContext(); + Location loc = alloca->getLoc(); + + memref::GlobalOp globalOp; + { + // Insert a `memref.global` at the beginning of the module. + if (module != alloca->getParentOfType<ModuleOp>()) { + return emitDefiniteFailure() + << "expected 'alloca' payload to be inside 'module' payload"; + } + IRRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&module.getBodyRegion().front()); + Type resultType = alloca.getResult().getType(); + llvm::SmallString<64> symName = getUniqueSymbol("alloca", module); + // XXX: Add a better builder for this. + globalOp = rewriter.create<memref::GlobalOp>( + loc, StringAttr::get(ctx, symName), StringAttr::get(ctx, "private"), + TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{}); + } + + // Replace the `memref.alloca` with a `memref.get_global` accessing the + // global symbol inserted above. + rewriter.setInsertionPoint(alloca); + auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>( + alloca, globalOp.getType(), globalOp.getName()); + + globalOps.push_back(globalOp); + getGlobalOps.push_back(getGlobalOp); + } + + // Assemble results. + results.set(getGlobal().cast<OpResult>(), globalOps); + results.set(getGetGlobal().cast<OpResult>(), getGlobalOps); + + return DiagnosedSilenceableFailure::success(); +} + +void transform::MemRefAllocaToGlobalOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getModule(), effects); + producesHandle(getGlobal(), effects); + producesHandle(getGetGlobal(), effects); + consumesHandle(getAlloca(), effects); + modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // MemRefMultiBufferOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index de3cd1b28e435bc..f1d07b85adb7576 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -1233,7 +1233,7 @@ transform::GetParentOp::apply(transform::TransformRewriter &rewriter, DenseSet<Operation *> resultSet; for (Operation *target : state.getPayloadOps(getTarget())) { Operation *parent = target->getParentOp(); - do { + while (parent) { bool checkIsolatedFromAbove = !getIsolatedFromAbove() || parent->hasTrait<OpTrait::IsIsolatedFromAbove>(); @@ -1241,7 +1241,8 @@ transform::GetParentOp::apply(transform::TransformRewriter &rewriter, parent->getName().getStringRef() == *getOpName(); if (checkIsolatedFromAbove && checkOpName) break; - } while ((parent = parent->getParentOp())); + parent = parent->getParentOp(); + } if (!parent) { DiagnosedSilenceableFailure diag = emitSilenceableError() diff --git a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py index 4afe8e7b887f68e..56dcfbe5655e9b6 100644 --- a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py @@ -11,6 +11,64 @@ from typing import Optional, overload, Union +class MemRefAllocaToGlobalOp: + """Specialization for MemRefAllocaToGlobalOp class.""" + + @overload + def __init__( + self, + get_global_type: Type, + global_type: Type, + module: Union[Operation, OpView, Value], + alloca: Union[Operation, OpView, Value], + *, + loc=None, + ip=None + ): + ... + + @overload + def __init__( + self, + module: Union[Operation, OpView, Value], + alloca: Union[Operation, OpView, Value], + *, + loc=None, + ip=None + ): + ... + + def __init__( + self, + get_global_type_or_module: Union[Operation, OpView, Type, Value], + global_type_or_alloca: Union[Operation, OpView, Type, Value], + module_or_none: Optional[Union[Operation, OpView, Value]] = None, + alloca_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + loc=None, + ip=None + ): + if isinstance(get_global_type_or_module, Type): + get_global_type = get_global_type_or_module + global_type = global_type_or_alloca + module = module_or_none + alloca = alloca_or_none + else: + get_global_type = transform.AnyOpType.get() + global_type = transform.AnyOpType.get() + module = get_global_type_or_module + alloca = global_type_or_alloca + + super().__init__( + get_global_type, + global_type, + module, + alloca, + loc=loc, + ip=ip, + ) + + class MemRefMultiBufferOp: """Specialization for MemRefMultiBufferOp class.""" diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir index b19db447af1c28a..aeeb2a6b0abedc5 100644 --- a/mlir/test/Dialect/MemRef/transform-ops.mlir +++ b/mlir/test/Dialect/MemRef/transform-ops.mlir @@ -1,5 +1,44 @@ // RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect -split-input-file | FileCheck %s +// CHECK-DAG: memref.global "private" @[[ALLOC0:alloc.*]] : memref<2x32xf32> +// CHECK-DAG: memref.global "private" @[[ALLOC1:alloc.*]] : memref<2x32xf32> + +// CHECK: func.func @func( +func.func @func(%arg0: f32) { + %c3 = arith.constant 3 : index + %c1 = arith.constant 1 : index + // CHECK: scf.forall + scf.forall (%arg1, %arg2) in (%c3, %c1) { + // CHECK-DAG: %[[MR0:.*]] = memref.get_global @[[ALLOC0]] : memref<2x32xf32> + // CHECK-DAG: %[[MR1:.*]] = memref.get_global @[[ALLOC1]] : memref<2x32xf32> + // CHECK-DAG: memref.store %{{.*}}, %[[MR0]][%{{.*}}, %{{.*}}] : memref<2x32xf32> + // CHECK-DAG: memref.store %{{.*}}, %[[MR1]][%{{.*}}, %{{.*}}] : memref<2x32xf32> + %alloca = memref.alloca() : memref<2x32xf32> + %alloca_0 = memref.alloca() : memref<2x32xf32> + memref.store %arg0, %alloca[%arg1, %arg2] : memref<2x32xf32> + memref.store %arg0, %alloca_0[%arg1, %arg2] : memref<2x32xf32> + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %alloca = transform.structured.match ops{["memref.alloca"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %module = transform.structured.match ops{["builtin.module"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %alloca_typed = transform.cast %alloca + : !transform.any_op to !transform.op<"memref.alloca"> + %module_typed = transform.cast %module + : !transform.any_op to !transform.op<"builtin.module"> + %get_global, %global = + transform.memref.alloca_to_global %alloca_typed in %module_typed + : (!transform.op<"builtin.module">, !transform.op<"memref.alloca">) + -> (!transform.any_op, !transform.any_op) +} + +// ----- + // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> ((d0 floordiv 4) mod 2)> // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir index 68e3a4851539690..91a283c799941bb 100644 --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1891,6 +1891,18 @@ transform.sequence failures(propagate) { test_print_number_of_associated_payload_ir_ops %4 : !transform.any_op } + +// ----- + +// expected-note @below {{target op}} +module { + transform.sequence failures(propagate) { + ^bb0(%arg0: !pdl.operation): + // expected-error @below{{could not find a parent op that matches all requirements}} + %3 = get_parent_op %arg0 {op_name = "builtin.module"} : (!pdl.operation) -> !transform.any_op + } +} + // ----- func.func @cast(%arg0: f32) -> f64 { diff --git a/mlir/test/python/dialects/transform_memref_ext.py b/mlir/test/python/dialects/transform_memref_ext.py index f89005cb2f86d1b..8278019bbab3b89 100644 --- a/mlir/test/python/dialects/transform_memref_ext.py +++ b/mlir/test/python/dialects/transform_memref_ext.py @@ -16,6 +16,54 @@ def run(f): return f +@run +def testMemRefAllocaToAllocOpCompact(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("memref.alloc"), + ) + with InsertionPoint(sequence.body): + module = transform.CastOp( + transform.OperationType.get("builtin.module"), sequence.bodyTarget + ) + alloca = transform.CastOp( + transform.OperationType.get("memref.alloca"), sequence.bodyTarget + ) + memref.MemRefAllocaToGlobalOp(module, alloca) + transform.YieldOp() + # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpCompact + # CHECK: = transform.memref.alloca_to_global + # CHECK-SAME: (!transform.op<"builtin.module">, !transform.op<"memref.alloca">) + # CHECK-SAME: -> (!transform.any_op, !transform.any_op) + + +@run +def testMemRefAllocaToAllocOpTyped(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("memref.alloc"), + ) + with InsertionPoint(sequence.body): + module = transform.CastOp( + transform.OperationType.get("builtin.module"), sequence.bodyTarget + ) + alloca = transform.CastOp( + transform.OperationType.get("memref.alloca"), sequence.bodyTarget + ) + memref.MemRefAllocaToGlobalOp( + transform.OperationType.get("memref.get_global"), + transform.OperationType.get("memref.global"), + module, + alloca, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpTyped + # CHECK: = transform.memref.alloca_to_global + # CHECK-SAME: -> (!transform.op<"memref.get_global">, !transform.op<"memref.global">) + + @run def testMemRefMultiBufferOpCompact(): sequence = transform.SequenceOp( |
I am not sure whether the I had an original version with that argument and that ran, except for some crashes that I retrospectively relate to #66357. It may, thus, work without the argument but is it legal to do so? |
eca2474
to
a9cbcf5
Compare
@ftynse: Can you help us with this? @nicolasvasilache and @matthias-springer seemed to say that this might be safe but weren't very confident in their assessment... |
mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
Outdated
Show resolved
Hide resolved
a9cbcf5
to
8ee31eb
Compare
mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
Outdated
Show resolved
Hide resolved
This PR adds a new transform op that replaces `memref.alloca`s with `memref.get_global`s to newly inserted `memref.global`s. This is useful, for example, for allocations that should reside in the shared memory of a GPU, which have to be declared as globals.
In particular: * Accept any op type with `SymbolTable` trait as containing op rather than only `builtin.module` and rename op argument accordingly. * Use `SymbolTable::insert` to unique the name of the globals rather than some hand-rolled function. * Use more sane semantics in Python mix-in test.
8ee31eb
to
a91c93d
Compare
This PR adds a new transform op that replaces
memref.alloca
s withmemref.get_global
s to newly insertedmemref.global
s. This is useful, for example, for allocations that should reside in the shared memory of a GPU, which have to be declared as globals.