
[mlir][memref][transform] Add new alloca_to_global op. #66511

Merged
merged 4 commits into from
Sep 21, 2023

Conversation

ingomueller-net
Contributor

This PR adds a new transform op that replaces `memref.alloca` ops with `memref.get_global` ops referring to newly inserted `memref.global` ops. This is useful, for example, for allocations that should reside in the shared memory of a GPU, which have to be declared as globals.
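As a rough illustration of what the transform does, here is a toy Python model (not the MLIR API; the dict/tuple representation and the `alloca_to_global` helper are stand-ins invented for this sketch): each stack allocation is hoisted into a uniquely named module-level global, and its use site is redirected through a `get_global` lookup, mirroring the symbol-uniquing loop in the C++ below.

```python
def alloca_to_global(module_globals, func_ops):
    """Rewrite 'alloca' ops in func_ops into 'get_global' references.

    module_globals: dict mapping symbol name -> memref type string,
                    standing in for the module's symbol table.
    func_ops: list of (op_name, payload) tuples standing in for the
              ops inside a function body.
    Returns the rewritten op list; module_globals is updated in place.
    """
    rewritten = []
    for op_name, payload in func_ops:
        if op_name != "alloca":
            rewritten.append((op_name, payload))
            continue
        # Pick a symbol name not already taken in the module,
        # like getUniqueSymbol in the C++ below.
        sym, n = "alloca", 0
        while sym in module_globals:
            sym = f"alloca_{n}"
            n += 1
        module_globals[sym] = payload          # memref.global
        rewritten.append(("get_global", sym))  # memref.get_global
    return rewritten
```

Two allocas of the same type then become two distinct globals (`alloca`, `alloca_0`) with matching `get_global` references, which is the shape the FileCheck test below verifies.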

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:python MLIR Python bindings mlir mlir:memref labels Sep 15, 2023
@llvmbot
Member

llvmbot commented Sep 15, 2023

@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Changes

This PR adds a new transform op that replaces `memref.alloca` ops with `memref.get_global` ops referring to newly inserted `memref.global` ops. This is useful, for example, for allocations that should reside in the shared memory of a GPU, which have to be declared as globals.

Full diff: https://github.com//pull/66511.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td (+65)
  • (modified) mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp (+90)
  • (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+3-2)
  • (modified) mlir/python/mlir/dialects/_memref_transform_ops_ext.py (+58)
  • (modified) mlir/test/Dialect/MemRef/transform-ops.mlir (+39)
  • (modified) mlir/test/Dialect/Transform/test-interpreter.mlir (+12)
  • (modified) mlir/test/python/dialects/transform_memref_ext.py (+48)
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index 681759f970cb910..6a78784d74dd53c 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -144,6 +144,71 @@ def ApplyResolveRankedShapedTypeResultDimsPatternsOp : Op<Transform_Dialect,
 }
 
 def Transform_MemRefAllocOp : Transform_ConcreteOpType<"memref.alloc">;
+def Transform_MemRefAllocaOp : Transform_ConcreteOpType<"memref.alloca">;
+
+def MemRefAllocaToGlobalOp :
+  Op<Transform_Dialect, "memref.alloca_to_global",
+     [TransformOpInterface,
+      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+      DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Inserts a new `memref.global` for each provided `memref.alloca` into the
+    provided module and replaces it with a `memref.get_global`. This is useful,
+    for example, for allocations that should reside in the shared memory of
+    a GPU, which have to be declared as globals.
+
+    #### Example
+
+    Consider the following transform op:
+
+    ```mlir
+    %get_global, %global =
+        transform.memref.alloca_to_global %alloca in %module
+          : (!transform.op<"builtin.module">, !transform.op<"memref.alloca">)
+            -> (!transform.any_op, !transform.any_op)
+    ```
+
+    and the following input payload:
+
+    ```mlir
+    module {
+      func.func @func() {
+        %alloca = memref.alloca() : memref<2x32xf32>
+        // usages of %alloca...
+      }
+    }
+    ```
+
+    then applying the transform op to the payload would result in the following
+    output IR:
+
+    ```mlir
+    module {
+      memref.global "private" @alloc : memref<2x32xf32>
+      func.func @func() {
+        %alloca = memref.get_global @alloc : memref<2x32xf32>
+        // usages of %alloca...
+      }
+    }
+    ```
+
+    #### Return modes
+
+    Emits a definite failure if not exactly one `module` payload op was provided
+    or any of the `alloca` payload ops is not inside that module, and succeeds
+    otherwise. The returned handles refer to the `memref.get_global` and
+    `memref.global` ops that were inserted by the transformation.
+  }];
+
+  let arguments = (ins Transform_ConcreteOpType<"builtin.module">:$module,
+                   Transform_MemRefAllocaOp:$alloca);
+  let results = (outs TransformHandleTypeInterface:$get_global,
+                  TransformHandleTypeInterface:$global);
+
+  let assemblyFormat = [{
+    $alloca `in` $module attr-dict `:` functional-type(operands, results)
+  }];
+}
 
 def MemRefMultiBufferOp : Op<Transform_Dialect, "memref.multibuffer",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index 58f4d8d8f6d21fe..7467359da83c37f 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -126,6 +126,96 @@ void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
 }
 
+//===----------------------------------------------------------------------===//
+// AllocaToGlobalOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+static llvm::SmallString<64> getUniqueSymbol(llvm::StringRef prefix,
+                                             ModuleOp module) {
+  llvm::SmallString<64> candidateNameStorage;
+  StringRef candidateName(prefix);
+  int uniqueNumber = 0;
+  while (true) {
+    if (!module.lookupSymbol(candidateName)) {
+      break;
+    }
+    candidateNameStorage.clear();
+    candidateName = (prefix + Twine("_") + Twine(uniqueNumber))
+                        .toStringRef(candidateNameStorage);
+    uniqueNumber++;
+  }
+  return candidateName;
+}
+} // namespace
+
+DiagnosedSilenceableFailure
+transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
+                                         transform::TransformResults &results,
+                                         transform::TransformState &state) {
+  auto allocaOps = state.getPayloadOps(getAlloca());
+
+  SmallVector<memref::GlobalOp> globalOps;
+  SmallVector<memref::GetGlobalOp> getGlobalOps;
+
+  // Get `builtin.module`.
+  auto moduleOps = state.getPayloadOps(getModule());
+  if (!llvm::hasSingleElement(moduleOps)) {
+    return emitDefiniteFailure()
+           << Twine("expected exactly one 'module' payload, but found ") +
+                  std::to_string(llvm::range_size(moduleOps));
+  }
+  ModuleOp module = cast<ModuleOp>(*moduleOps.begin());
+
+  // Transform `memref.alloca`s.
+  for (auto *op : allocaOps) {
+    auto alloca = cast<memref::AllocaOp>(op);
+    MLIRContext *ctx = rewriter.getContext();
+    Location loc = alloca->getLoc();
+
+    memref::GlobalOp globalOp;
+    {
+      // Insert a `memref.global` at the beginning of the module.
+      if (module != alloca->getParentOfType<ModuleOp>()) {
+        return emitDefiniteFailure()
+               << "expected 'alloca' payload to be inside 'module' payload";
+      }
+      IRRewriter::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(&module.getBodyRegion().front());
+      Type resultType = alloca.getResult().getType();
+      llvm::SmallString<64> symName = getUniqueSymbol("alloca", module);
+      // XXX: Add a better builder for this.
+      globalOp = rewriter.create<memref::GlobalOp>(
+          loc, StringAttr::get(ctx, symName), StringAttr::get(ctx, "private"),
+          TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
+    }
+
+    // Replace the `memref.alloca` with a `memref.get_global` accessing the
+    // global symbol inserted above.
+    rewriter.setInsertionPoint(alloca);
+    auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(
+        alloca, globalOp.getType(), globalOp.getName());
+
+    globalOps.push_back(globalOp);
+    getGlobalOps.push_back(getGlobalOp);
+  }
+
+  // Assemble results.
+  results.set(getGlobal().cast<OpResult>(), globalOps);
+  results.set(getGetGlobal().cast<OpResult>(), getGlobalOps);
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::MemRefAllocaToGlobalOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getModule(), effects);
+  producesHandle(getGlobal(), effects);
+  producesHandle(getGetGlobal(), effects);
+  consumesHandle(getAlloca(), effects);
+  modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefMultiBufferOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index de3cd1b28e435bc..f1d07b85adb7576 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1233,7 +1233,7 @@ transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
   DenseSet<Operation *> resultSet;
   for (Operation *target : state.getPayloadOps(getTarget())) {
     Operation *parent = target->getParentOp();
-    do {
+    while (parent) {
       bool checkIsolatedFromAbove =
           !getIsolatedFromAbove() ||
           parent->hasTrait<OpTrait::IsIsolatedFromAbove>();
@@ -1241,7 +1241,8 @@ transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
                          parent->getName().getStringRef() == *getOpName();
       if (checkIsolatedFromAbove && checkOpName)
         break;
-    } while ((parent = parent->getParentOp()));
+      parent = parent->getParentOp();
+    }
     if (!parent) {
       DiagnosedSilenceableFailure diag =
           emitSilenceableError()
diff --git a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
index 4afe8e7b887f68e..56dcfbe5655e9b6 100644
--- a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
@@ -11,6 +11,64 @@
 from typing import Optional, overload, Union
 
 
+class MemRefAllocaToGlobalOp:
+    """Specialization for MemRefAllocaToGlobalOp class."""
+
+    @overload
+    def __init__(
+        self,
+        get_global_type: Type,
+        global_type: Type,
+        module: Union[Operation, OpView, Value],
+        alloca: Union[Operation, OpView, Value],
+        *,
+        loc=None,
+        ip=None
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self,
+        module: Union[Operation, OpView, Value],
+        alloca: Union[Operation, OpView, Value],
+        *,
+        loc=None,
+        ip=None
+    ):
+        ...
+
+    def __init__(
+        self,
+        get_global_type_or_module: Union[Operation, OpView, Type, Value],
+        global_type_or_alloca: Union[Operation, OpView, Type, Value],
+        module_or_none: Optional[Union[Operation, OpView, Value]] = None,
+        alloca_or_none: Optional[Union[Operation, OpView, Value]] = None,
+        *,
+        loc=None,
+        ip=None
+    ):
+        if isinstance(get_global_type_or_module, Type):
+            get_global_type = get_global_type_or_module
+            global_type = global_type_or_alloca
+            module = module_or_none
+            alloca = alloca_or_none
+        else:
+            get_global_type = transform.AnyOpType.get()
+            global_type = transform.AnyOpType.get()
+            module = get_global_type_or_module
+            alloca = global_type_or_alloca
+
+        super().__init__(
+            get_global_type,
+            global_type,
+            module,
+            alloca,
+            loc=loc,
+            ip=ip,
+        )
+
+
 class MemRefMultiBufferOp:
     """Specialization for MemRefMultiBufferOp class."""
 
diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir
index b19db447af1c28a..aeeb2a6b0abedc5 100644
--- a/mlir/test/Dialect/MemRef/transform-ops.mlir
+++ b/mlir/test/Dialect/MemRef/transform-ops.mlir
@@ -1,5 +1,44 @@
 // RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect -split-input-file | FileCheck %s
 
+// CHECK-DAG: memref.global "private" @[[ALLOC0:alloc.*]] : memref<2x32xf32>
+// CHECK-DAG: memref.global "private" @[[ALLOC1:alloc.*]] : memref<2x32xf32>
+
+// CHECK: func.func @func(
+func.func @func(%arg0: f32) {
+  %c3 = arith.constant 3 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: scf.forall
+  scf.forall (%arg1, %arg2) in (%c3, %c1) {
+    // CHECK-DAG: %[[MR0:.*]] = memref.get_global @[[ALLOC0]] : memref<2x32xf32>
+    // CHECK-DAG: %[[MR1:.*]] = memref.get_global @[[ALLOC1]] : memref<2x32xf32>
+    // CHECK-DAG: memref.store %{{.*}}, %[[MR0]][%{{.*}}, %{{.*}}] : memref<2x32xf32>
+    // CHECK-DAG: memref.store %{{.*}}, %[[MR1]][%{{.*}}, %{{.*}}] : memref<2x32xf32>
+    %alloca = memref.alloca() : memref<2x32xf32>
+    %alloca_0 = memref.alloca() : memref<2x32xf32>
+    memref.store %arg0, %alloca[%arg1, %arg2] : memref<2x32xf32>
+    memref.store %arg0, %alloca_0[%arg1, %arg2] : memref<2x32xf32>
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %alloca = transform.structured.match ops{["memref.alloca"]} in %arg0
+      : (!transform.any_op) -> !transform.any_op
+  %module = transform.structured.match ops{["builtin.module"]} in %arg0
+      : (!transform.any_op) -> !transform.any_op
+  %alloca_typed = transform.cast %alloca
+      : !transform.any_op to !transform.op<"memref.alloca">
+  %module_typed = transform.cast %module
+      : !transform.any_op to !transform.op<"builtin.module">
+  %get_global, %global =
+      transform.memref.alloca_to_global %alloca_typed in %module_typed
+        : (!transform.op<"builtin.module">, !transform.op<"memref.alloca">)
+          -> (!transform.any_op, !transform.any_op)
+}
+
+// -----
+
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> ((d0 floordiv 4) mod 2)>
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
 
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 68e3a4851539690..91a283c799941bb 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1891,6 +1891,18 @@ transform.sequence failures(propagate) {
   test_print_number_of_associated_payload_ir_ops %4 : !transform.any_op
 }
 
+
+// -----
+
+// expected-note @below {{target op}}
+module {
+  transform.sequence  failures(propagate) {
+  ^bb0(%arg0: !pdl.operation):
+    // expected-error @below{{could not find a parent op that matches all requirements}}
+    %3 = get_parent_op %arg0 {op_name = "builtin.module"} : (!pdl.operation) -> !transform.any_op
+  }
+}
+
 // -----
 
 func.func @cast(%arg0: f32) -> f64 {
diff --git a/mlir/test/python/dialects/transform_memref_ext.py b/mlir/test/python/dialects/transform_memref_ext.py
index f89005cb2f86d1b..8278019bbab3b89 100644
--- a/mlir/test/python/dialects/transform_memref_ext.py
+++ b/mlir/test/python/dialects/transform_memref_ext.py
@@ -16,6 +16,54 @@ def run(f):
     return f
 
 
+@run
+def testMemRefAllocaToAllocOpCompact():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("memref.alloc"),
+    )
+    with InsertionPoint(sequence.body):
+        module = transform.CastOp(
+            transform.OperationType.get("builtin.module"), sequence.bodyTarget
+        )
+        alloca = transform.CastOp(
+            transform.OperationType.get("memref.alloca"), sequence.bodyTarget
+        )
+        memref.MemRefAllocaToGlobalOp(module, alloca)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpCompact
+    # CHECK: = transform.memref.alloca_to_global
+    # CHECK-SAME: (!transform.op<"builtin.module">, !transform.op<"memref.alloca">)
+    # CHECK-SAME: -> (!transform.any_op, !transform.any_op)
+
+
+@run
+def testMemRefAllocaToAllocOpTyped():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("memref.alloc"),
+    )
+    with InsertionPoint(sequence.body):
+        module = transform.CastOp(
+            transform.OperationType.get("builtin.module"), sequence.bodyTarget
+        )
+        alloca = transform.CastOp(
+            transform.OperationType.get("memref.alloca"), sequence.bodyTarget
+        )
+        memref.MemRefAllocaToGlobalOp(
+            transform.OperationType.get("memref.get_global"),
+            transform.OperationType.get("memref.global"),
+            module,
+            alloca,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpTyped
+    # CHECK: = transform.memref.alloca_to_global
+    # CHECK-SAME: -> (!transform.op<"memref.get_global">, !transform.op<"memref.global">)
+
+
 @run
 def testMemRefMultiBufferOpCompact():
     sequence = transform.SequenceOp(

@ingomueller-net
Contributor Author

I am not sure whether the module argument is really necessary. I do need access to the surrounding module to insert the memref.globals, but I can obtain it in code (via getParentOfType). Also, I leave the module itself (and all other ops except for the alloca inputs) intact and only add new ones.

I had an original version without that argument, and it ran except for some crashes that I retrospectively attribute to #66357. It may thus work without the argument, but is it legal to do so?

@ingomueller-net ingomueller-net force-pushed the transform-alloc-to-global branch 2 times, most recently from eca2474 to a9cbcf5 Compare September 15, 2023 15:01
@ingomueller-net
Contributor Author

> I am not sure whether the module argument is really necessary. I do need access to the surrounding module to insert the memref.globals, but I can obtain it in code (via getParentOfType). Also, I leave the module itself (and all other ops except for the alloca inputs) intact and only add new ones.
>
> I had an original version without that argument, and it ran except for some crashes that I retrospectively attribute to #66357. It may thus work without the argument, but is it legal to do so?

@ftynse: Can you help us with this? @nicolasvasilache and @matthias-springer seemed to say that this might be safe but weren't very confident in their assessment...

This PR adds a new transform op that replaces `memref.alloca`s with
`memref.get_global`s to newly inserted `memref.global`s. This is useful,
for example, for allocations that should reside in the shared memory of
a GPU, which have to be declared as globals.
In particular:
* Accept any op type with `SymbolTable` trait as containing op rather
  than only `builtin.module` and rename op argument accordingly.
* Use `SymbolTable::insert` to unique the name of the globals rather
  than some hand-rolled function.
* Use more sane semantics in Python mix-in test.
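The first bullet above (accept any op with the `SymbolTable` trait as the containing op, not just `builtin.module`) amounts to walking the parent chain to the nearest symbol table. A minimal sketch of that lookup, using a toy `Op` class invented here rather than the MLIR C++ classes:

```python
class Op:
    """Toy stand-in for an IR operation with a parent link."""

    def __init__(self, name, parent=None, has_symbol_table=False):
        self.name = name
        self.parent = parent
        self.has_symbol_table = has_symbol_table


def nearest_symbol_table(op):
    """Return the closest ancestor carrying the SymbolTable trait, or None.

    Mirrors the getParentOfType-style walk the comments above discuss.
    """
    parent = op.parent
    while parent is not None and not parent.has_symbol_table:
        parent = parent.parent
    return parent
```

Under this model, an alloca nested in `func.func` inside `builtin.module` resolves to the module, so the transform no longer needs the module handle passed in explicitly.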
@ingomueller-net ingomueller-net force-pushed the transform-alloc-to-global branch from 8ee31eb to a91c93d Compare September 21, 2023 13:49
@ingomueller-net ingomueller-net merged commit 991cb14 into llvm:main Sep 21, 2023
@ingomueller-net ingomueller-net deleted the transform-alloc-to-global branch September 21, 2023 16:17