Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][python] Expose AsmState python side. #66819

Merged
merged 3 commits into from
Sep 20, 2023
Merged

Conversation

jpienaar
Copy link
Member

This does basic plumbing, ideally want a context approach to reduce needing to thread these manually, but the current is useful even in that state.

Made Value.get_name change backwards compatible, so one could either set a field or create a state to pass in.

This does basic plumbing, ideally want a context approach to reduce
needing to thread these manually, but the current is useful even in that
state.
@llvmbot
Copy link
Member

llvmbot commented Sep 19, 2023

@llvm/pr-subscribers-mlir

Changes

This does basic plumbing, ideally want a context approach to reduce needing to thread these manually, but the current is useful even in that state.

Made Value.get_name change backwards compatible, so one could either set a field or create a state to pass in.


Full diff: https://github.com/llvm/llvm-project/pull/66819.diff

3 Files Affected:

  • (modified) mlir/lib/Bindings/Python/IRCore.cpp (+29-9)
  • (modified) mlir/lib/Bindings/Python/IRModule.h (+25)
  • (modified) mlir/test/python/ir/value.py (+11-1)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index af713547cccbb27..2ab1219016006d8 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3425,19 +3425,35 @@ void mlir::python::populateIRCore(py::module &m) {
           kValueDunderStrDocstring)
       .def(
           "get_name",
-          [](PyValue &self, bool useLocalScope) {
+          [](PyValue &self, std::optional<bool> useLocalScope,
+             std::optional<std::reference_wrapper<PyAsmState>> state) {
             PyPrintAccumulator printAccum;
-            MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
-            if (useLocalScope)
-              mlirOpPrintingFlagsUseLocalScope(flags);
-            MlirAsmState state = mlirAsmStateCreateForValue(self.get(), flags);
-            mlirValuePrintAsOperand(self.get(), state, printAccum.getCallback(),
+            MlirOpPrintingFlags flags;
+            MlirAsmState valueState;
+            // Use state if provided, else create a new state.
+            if (state) {
+              valueState = state.value().get().get();
+              // Don't allow setting using local scope and state at same time.
+              if (useLocalScope)
+                throw py::value_error(
+                    "setting AsmState and local scope together not supported");
+            } else {
+              flags = mlirOpPrintingFlagsCreate();
+              if (useLocalScope.value_or(false))
+                mlirOpPrintingFlagsUseLocalScope(flags);
+              valueState = mlirAsmStateCreateForValue(self.get(), flags);
+            }
+            mlirValuePrintAsOperand(self.get(), valueState, printAccum.getCallback(),
                                     printAccum.getUserData());
-            mlirOpPrintingFlagsDestroy(flags);
-            mlirAsmStateDestroy(state);
+            // Release state if allocated locally.
+            if (!state) {
+              mlirOpPrintingFlagsDestroy(flags);
+              mlirAsmStateDestroy(valueState);
+            }
             return printAccum.join();
           },
-          py::arg("use_local_scope") = false, kGetNameAsOperand)
+          py::arg("use_local_scope") = std::nullopt,
+          py::arg("state") = std::nullopt, kGetNameAsOperand)
       .def_property_readonly(
           "type", [](PyValue &self) { return mlirValueGetType(self.get()); })
       .def(
@@ -3456,6 +3472,10 @@ void mlir::python::populateIRCore(py::module &m) {
   PyOpResult::bind(m);
   PyOpOperand::bind(m);
 
+  py::class_<PyAsmState>(m, "AsmState", py::module_local())
+      .def(py::init<PyValue &, bool>(), py::arg("value"),
+           py::arg("use_local_scope"));
+
   //----------------------------------------------------------------------------
   // Mapping of SymbolTable.
   //----------------------------------------------------------------------------
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index d1911730c1ede03..23338f7fdb38add 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -748,6 +748,31 @@ class PyRegion {
   MlirRegion region;
 };
 
+/// Wrapper around an MlirAsmState.
+class PyAsmState {
+ public:
+  PyAsmState(MlirValue value, bool useLocalScope) {
+    flags = mlirOpPrintingFlagsCreate();
+    // The OpPrintingFlags are not exposed Python side, create locally and
+    // associate lifetime with the state.
+    if (useLocalScope)
+      mlirOpPrintingFlagsUseLocalScope(flags);
+    state = mlirAsmStateCreateForValue(value, flags);
+  }
+  ~PyAsmState() {
+    mlirOpPrintingFlagsDestroy(flags);
+  }
+  // Delete copy constructors.
+  PyAsmState(PyAsmState &other) = delete;
+  PyAsmState(const PyAsmState &other) = delete;
+
+  MlirAsmState get() { return state; }
+
+ private:
+  MlirAsmState state;
+  MlirOpPrintingFlags flags;
+};
+
 /// Wrapper around an MlirBlock.
 /// Blocks are managed completely by their containing operation. Unlike the
 /// C++ API, the python API does not support detached blocks.
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 46a50ac5291e8d9..2a47c8d820eaf4f 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -1,4 +1,4 @@
-# RUN: %PYTHON %s | FileCheck %s
+# RUN: %PYTHON %s | FileCheck %s --enable-var-scope=false
 
 import gc
 from mlir.ir import *
@@ -199,6 +199,16 @@ def testValuePrintAsOperand():
         # CHECK: %[[VAL4]]
         print(value4.get_name())
 
+        print("With AsmState")
+        # CHECK-LABEL: With AsmState
+        state = AsmState(value3, use_local_scope=True)
+        # CHECK: %0
+        print(value3.get_name(state=state))
+        # CHECK: %1
+        print(value4.get_name(state=state))
+
+        print("With use_local_scope")
+        # CHECK-LABEL: With use_local_scope
         # CHECK: %0
         print(value3.get_name(use_local_scope=True))
         # CHECK: %1

mlir/lib/Bindings/Python/IRCore.cpp Show resolved Hide resolved
mlir/lib/Bindings/Python/IRCore.cpp Outdated Show resolved Hide resolved
mlir/lib/Bindings/Python/IRModule.h Show resolved Hide resolved
@jpienaar jpienaar merged commit 7545371 into llvm:main Sep 20, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants