Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][python] Expose AsmState python side. #66819

Merged
merged 3 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3425,19 +3425,35 @@ void mlir::python::populateIRCore(py::module &m) {
kValueDunderStrDocstring)
.def(
"get_name",
[](PyValue &self, bool useLocalScope) {
[](PyValue &self, std::optional<bool> useLocalScope,
std::optional<std::reference_wrapper<PyAsmState>> state) {
PyPrintAccumulator printAccum;
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
if (useLocalScope)
mlirOpPrintingFlagsUseLocalScope(flags);
MlirAsmState state = mlirAsmStateCreateForValue(self.get(), flags);
mlirValuePrintAsOperand(self.get(), state, printAccum.getCallback(),
MlirOpPrintingFlags flags;
MlirAsmState valueState;
// Use state if provided, else create a new state.
if (state) {
valueState = state.value().get().get();
// Don't allow setting using local scope and state at same time.
if (useLocalScope)
jpienaar marked this conversation as resolved.
Show resolved Hide resolved
throw py::value_error(
"setting AsmState and local scope together not supported");
} else {
flags = mlirOpPrintingFlagsCreate();
if (useLocalScope.value_or(false))
mlirOpPrintingFlagsUseLocalScope(flags);
valueState = mlirAsmStateCreateForValue(self.get(), flags);
}
mlirValuePrintAsOperand(self.get(), valueState, printAccum.getCallback(),
printAccum.getUserData());
mlirOpPrintingFlagsDestroy(flags);
mlirAsmStateDestroy(state);
// Release state if allocated locally.
if (!state) {
mlirOpPrintingFlagsDestroy(flags);
mlirAsmStateDestroy(valueState);
}
return printAccum.join();
},
py::arg("use_local_scope") = false, kGetNameAsOperand)
py::arg("use_local_scope") = std::nullopt,
py::arg("state") = std::nullopt, kGetNameAsOperand)
jpienaar marked this conversation as resolved.
Show resolved Hide resolved
.def_property_readonly(
"type", [](PyValue &self) { return mlirValueGetType(self.get()); })
.def(
Expand All @@ -3456,6 +3472,10 @@ void mlir::python::populateIRCore(py::module &m) {
PyOpResult::bind(m);
PyOpOperand::bind(m);

py::class_<PyAsmState>(m, "AsmState", py::module_local())
.def(py::init<PyValue &, bool>(), py::arg("value"),
py::arg("use_local_scope") = false);

//----------------------------------------------------------------------------
// Mapping of SymbolTable.
//----------------------------------------------------------------------------
Expand Down
25 changes: 25 additions & 0 deletions mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,31 @@ class PyRegion {
MlirRegion region;
};

/// Wrapper around an MlirAsmState.
class PyAsmState {
public:
PyAsmState(MlirValue value, bool useLocalScope) {
flags = mlirOpPrintingFlagsCreate();
// The OpPrintingFlags are not exposed Python side, create locally and
// associate lifetime with the state.
if (useLocalScope)
mlirOpPrintingFlagsUseLocalScope(flags);
state = mlirAsmStateCreateForValue(value, flags);
}
~PyAsmState() {
mlirOpPrintingFlagsDestroy(flags);
}
// Delete copy constructors.
PyAsmState(PyAsmState &other) = delete;
PyAsmState(const PyAsmState &other) = delete;
jpienaar marked this conversation as resolved.
Show resolved Hide resolved

MlirAsmState get() { return state; }

private:
MlirAsmState state;
MlirOpPrintingFlags flags;
};

/// Wrapper around an MlirBlock.
/// Blocks are managed completely by their containing operation. Unlike the
/// C++ API, the python API does not support detached blocks.
Expand Down
12 changes: 11 additions & 1 deletion mlir/test/python/ir/value.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# RUN: %PYTHON %s | FileCheck %s
# RUN: %PYTHON %s | FileCheck %s --enable-var-scope=false

import gc
from mlir.ir import *
Expand Down Expand Up @@ -199,6 +199,16 @@ def testValuePrintAsOperand():
# CHECK: %[[VAL4]]
print(value4.get_name())

print("With AsmState")
# CHECK-LABEL: With AsmState
state = AsmState(value3, use_local_scope=True)
# CHECK: %0
print(value3.get_name(state=state))
# CHECK: %1
print(value4.get_name(state=state))

print("With use_local_scope")
# CHECK-LABEL: With use_local_scope
# CHECK: %0
print(value3.get_name(use_local_scope=True))
# CHECK: %1
Expand Down