diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index af713547cccbb2..fc80e193b1aac7 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3425,19 +3425,35 @@ void mlir::python::populateIRCore(py::module &m) { kValueDunderStrDocstring) .def( "get_name", - [](PyValue &self, bool useLocalScope) { + [](PyValue &self, std::optional useLocalScope, + std::optional> state) { PyPrintAccumulator printAccum; - MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); - if (useLocalScope) - mlirOpPrintingFlagsUseLocalScope(flags); - MlirAsmState state = mlirAsmStateCreateForValue(self.get(), flags); - mlirValuePrintAsOperand(self.get(), state, printAccum.getCallback(), + MlirOpPrintingFlags flags; + MlirAsmState valueState; + // Use state if provided, else create a new state. + if (state) { + valueState = state.value().get().get(); + // Don't allow setting using local scope and state at same time. + if (useLocalScope.has_value()) + throw py::value_error( + "setting AsmState and local scope together not supported"); + } else { + flags = mlirOpPrintingFlagsCreate(); + if (useLocalScope.value_or(false)) + mlirOpPrintingFlagsUseLocalScope(flags); + valueState = mlirAsmStateCreateForValue(self.get(), flags); + } + mlirValuePrintAsOperand(self.get(), valueState, printAccum.getCallback(), printAccum.getUserData()); - mlirOpPrintingFlagsDestroy(flags); - mlirAsmStateDestroy(state); + // Release state if allocated locally. + if (!state) { + mlirOpPrintingFlagsDestroy(flags); + mlirAsmStateDestroy(valueState); + } return printAccum.join(); }, - py::arg("use_local_scope") = false, kGetNameAsOperand) + py::arg("use_local_scope") = std::nullopt, + py::arg("state") = std::nullopt, kGetNameAsOperand) .def_property_readonly( "type", [](PyValue &self) { return mlirValueGetType(self.get()); }) .def( @@ -3456,6 +3472,10 @@ void mlir::python::populateIRCore(py::module &m) { PyOpResult::bind(m); PyOpOperand::bind(m); + py::class_(m, "AsmState", py::module_local()) + .def(py::init(), py::arg("value"), + py::arg("use_local_scope") = false); + //---------------------------------------------------------------------------- // Mapping of SymbolTable. //---------------------------------------------------------------------------- diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index d1911730c1ede0..23338f7fdb38ad 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -748,6 +748,31 @@ class PyRegion { MlirRegion region; }; +/// Wrapper around an MlirAsmState. +class PyAsmState { + public: + PyAsmState(MlirValue value, bool useLocalScope) { + flags = mlirOpPrintingFlagsCreate(); + // The OpPrintingFlags are not exposed Python side, create locally and + // associate lifetime with the state. + if (useLocalScope) + mlirOpPrintingFlagsUseLocalScope(flags); + state = mlirAsmStateCreateForValue(value, flags); + } + ~PyAsmState() { + mlirOpPrintingFlagsDestroy(flags); + } + // Delete copy constructors. + PyAsmState(PyAsmState &other) = delete; + PyAsmState(const PyAsmState &other) = delete; + + MlirAsmState get() { return state; } + + private: + MlirAsmState state; + MlirOpPrintingFlags flags; +}; + /// Wrapper around an MlirBlock. /// Blocks are managed completely by their containing operation. Unlike the /// C++ API, the python API does not support detached blocks. diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py index 46a50ac5291e8d..2a47c8d820eaf4 100644 --- a/mlir/test/python/ir/value.py +++ b/mlir/test/python/ir/value.py @@ -1,4 +1,4 @@ -# RUN: %PYTHON %s | FileCheck %s +# RUN: %PYTHON %s | FileCheck %s --enable-var-scope=false import gc from mlir.ir import * @@ -199,6 +199,16 @@ def testValuePrintAsOperand(): # CHECK: %[[VAL4]] print(value4.get_name()) + print("With AsmState") + # CHECK-LABEL: With AsmState + state = AsmState(value3, use_local_scope=True) + # CHECK: %0 + print(value3.get_name(state=state)) + # CHECK: %1 + print(value4.get_name(state=state)) + + print("With use_local_scope") + # CHECK-LABEL: With use_local_scope # CHECK: %0 print(value3.get_name(use_local_scope=True)) # CHECK: %1