Skip to content

Commit

Permalink
ExecutionSession constructor without entry function name (llvm#1163)
Browse files Browse the repository at this point in the history
* ExecutionSession constructor without entry function name

Signed-off-by: Tung D. Le <[email protected]>

* clang-format

Signed-off-by: Tung D. Le <[email protected]>

* Do not use run_main_graph in ExecutionSession

Signed-off-by: Tung D. Le <[email protected]>

* Revise

Signed-off-by: Tung D. Le <[email protected]>

* Add omEntryPointName to pass Windows CI

Signed-off-by: Tung D. Le <[email protected]>

* Address review comments

Signed-off-by: Tung D. Le <[email protected]>

* Use run_main_graph as the default entry point

Signed-off-by: Tung D. Le <[email protected]>

* undo numerical.def

Signed-off-by: Tung D. Le <[email protected]>

* Change messages

Signed-off-by: Tung D. Le <[email protected]>

Co-authored-by: Alexandre Eichenberger <[email protected]>
  • Loading branch information
tungld and AlexandreEichenberger authored Feb 11, 2022
1 parent be32233 commit 52fb808
Show file tree
Hide file tree
Showing 17 changed files with 37 additions and 31 deletions.
6 changes: 2 additions & 4 deletions docs/UsingPyRuntime.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,10 @@ The complete interface to ExecutionSession can be seen in the sources mentioned
using the constructor and run method is enough to perform inferences.

```python
def __init__(self, path: str, entry_point: str):
def __init__(self, path: str):
"""
Args:
path: relative or absolute path to your .so model.
entry_point: function generated by onnx-mlir to call inferences.
Use 'run_main_graph'.
"""

def run(self, input: List[ndarray]) -> List[ndarray]:
Expand Down Expand Up @@ -72,7 +70,7 @@ import numpy as np
from PyRuntime import ExecutionSession

model = 'model.so' # LeNet from ONNX Zoo compiled with onnx-mlir
session = ExecutionSession(model, "run_main_graph")
session = ExecutionSession(model)
print("input signature in json", session.input_signature())
print("output signature in json",session.output_signature())
input = np.full((1, 1, 28, 28), 1, np.dtype(np.float32))
Expand Down
2 changes: 1 addition & 1 deletion docs/mnist_example/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ The runtime uses an `ExecutionSession` object to hold a specific model and entry
``` Python
# Load the model mnist.so compiled with onnx-mlir.
model = 'mnist.so'
session = ExecutionSession(model, "run_main_graph")
session = ExecutionSession(model)
# Print the model's input/output signature, for display.
# If there are problems with the signature functions, they can simply be commented out.
print("input signature in json", session.input_signature())
Expand Down
2 changes: 1 addition & 1 deletion docs/mnist_example/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

# Load the model mnist.so compiled with onnx-mlir.
model = './mnist.so'
session = ExecutionSession(model, "run_main_graph")
session = ExecutionSession(model)
# Print the model's input/output signature, for display.
# The signature functions are for info only; comment them out if they cause problems.
print("input signature in json", session.input_signature())
Expand Down
13 changes: 9 additions & 4 deletions src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1194,14 +1194,19 @@ class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
KrnlEntryPointOp::getEntryPointFuncAttrName())
.getLeafReference()
.getValue();
auto dynEntryPointName = "run_" + staticEntryPointFuncName;
assert(module.lookupSymbol(dynEntryPointName.str()) == nullptr &&
"dynamic entry point name is not unique");

// When there is only a single entry point function in a model, use
// "run_main_graph" as the default name.
// TODO(tung): support multiple entry point functions.
std::string entryPointName = "run_main_graph";
assert(module.lookupSymbol(entryPointName) == nullptr &&
"Only support a single entry point function.");

rewriter.eraseOp(op);
auto dynEntryPointFuncTy =
LLVM::LLVMFunctionType::get(opaquePtrTy, {opaquePtrTy}, false);
auto dynamicEntryPointFunc = rewriter.create<LLVM::LLVMFuncOp>(
loc, dynEntryPointName.str(), dynEntryPointFuncTy);
loc, entryPointName, dynEntryPointFuncTy);
auto &entryPointEntryBlock =
createEntryBlock(dynEntryPointFuncTy, dynamicEntryPointFunc, loc);
rewriter.setInsertionPointToStart(&entryPointEntryBlock);
Expand Down
8 changes: 8 additions & 0 deletions src/Runtime/ExecutionSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ namespace onnx_mlir {
const std::string ExecutionSession::_inputSignatureName = "omInputSignature";
const std::string ExecutionSession::_outputSignatureName = "omOutputSignature";

ExecutionSession::ExecutionSession(std::string sharedLibPath)
: ExecutionSession::ExecutionSession(sharedLibPath, "") {}

ExecutionSession::ExecutionSession(
std::string sharedLibPath, std::string entryPointName) {

Expand All @@ -36,6 +39,11 @@ ExecutionSession::ExecutionSession(
throw std::runtime_error(errStr.str());
}

// When the entry point name is not given, use the default "run_main_graph".
// TODO(tung): support multiple entry point functions.
if (entryPointName.empty())
entryPointName = "run_main_graph";

_entryPointFunc = reinterpret_cast<entryPointFuncType>(
_sharedLibraryHandle.getAddressOfSymbol(entryPointName.c_str()));
if (!_entryPointFunc) {
Expand Down
1 change: 1 addition & 0 deletions src/Runtime/ExecutionSession.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ using OMTensorUniquePtr = std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>;

class ExecutionSession {
public:
ExecutionSession(std::string sharedLibPath);
ExecutionSession(std::string sharedLibPath, std::string entryPointName);

// Use custom deleter since forward declared OMTensor hides destructor
Expand Down
5 changes: 4 additions & 1 deletion src/Runtime/PyExecutionSession.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ namespace onnx_mlir {

class PyExecutionSession : public onnx_mlir::ExecutionSession {
public:
PyExecutionSession(std::string sharedLibPath)
: onnx_mlir::ExecutionSession(sharedLibPath) {}
PyExecutionSession(std::string sharedLibPath, std::string entryPointName)
: onnx_mlir::ExecutionSession(sharedLibPath, entryPointName){};
: onnx_mlir::ExecutionSession(sharedLibPath, entryPointName) {}

std::vector<py::array> pyRun(const std::vector<py::array> &inputsPyArray);

Expand All @@ -37,6 +39,7 @@ class PyExecutionSession : public onnx_mlir::ExecutionSession {

PYBIND11_MODULE(PyRuntime, m) {
py::class_<onnx_mlir::PyExecutionSession>(m, "ExecutionSession")
.def(py::init<const std::string &>())
.def(py::init<const std::string &, const std::string &>())
.def("run", &onnx_mlir::PyExecutionSession::pyRun)
.def("input_signature", &onnx_mlir::PyExecutionSession::pyInputSignature)
Expand Down
5 changes: 2 additions & 3 deletions test/backend/inference_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,6 @@ def JniExecutionSession(jar_name, inputs):
class EndiannessAwareExecutionSession(object):
def __init__(self, model):
self.model = model
self.entry_point = "run_main_graph"
self.exec_name = None
# Compiling the model in advance if not testing constants, so that
# the model is compiled once and used multiple times.
Expand Down Expand Up @@ -1060,7 +1059,7 @@ def run(self, inputs, **kwargs):
inputs = self.turn_model_input_to_constant(inputs)
self.exec_name = compile_model(self.model, args.emit)
if args.emit == "lib":
session = ExecutionSession(self.exec_name, self.entry_point)
session = ExecutionSession(self.exec_name)
outputs = session.run(inputs)
# print('input='+str(inputs), file=sys.stderr)
# print('output='+str(outputs), file=sys.stderr)
Expand All @@ -1079,7 +1078,7 @@ def run(self, inputs, **kwargs):
"Cannot deduce desired output endianness, using native endianness by default."
)
if args.emit == "lib":
session = ExecutionSession(self.exec_name, self.entry_point)
session = ExecutionSession(self.exec_name)
outputs = session.run(inputs)
elif args.emit == "jni":
outputs = JniExecutionSession(self.exec_name, inputs)
Expand Down
3 changes: 1 addition & 2 deletions test/backend/signature_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,13 @@ def run(test_self, device): # type: (Any, Text) -> None
class SignatureExecutionSession(object):
def __init__(self, model):
self.model = model
self.entry_point = "run_main_graph"
self.exec_name = compile_model(self.model, args.emit)

def run(self, **kwargs):
sys.path.append(RUNTIME_DIR)
from PyRuntime import ExecutionSession

session = ExecutionSession(self.exec_name, self.entry_point)
session = ExecutionSession(self.exec_name)
output = session.input_signature()
return output

Expand Down
3 changes: 1 addition & 2 deletions test/numerical/TestConv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,7 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
/*output conv param*/ NOut, COut, HOut, WOut))
return false;

onnx_mlir::ExecutionSession sess(
getSharedLibName(SHARED_LIB_BASE.str()), "run_main_graph");
onnx_mlir::ExecutionSession sess(getSharedLibName(SHARED_LIB_BASE.str()));

std::vector<OMTensorUniquePtr> inputs;
auto xOmt = OMTensorUniquePtr(
Expand Down
3 changes: 1 addition & 2 deletions test/numerical/TestGRU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ bool isOMGRUTheSameAsNaiveImplFor(const int direction, const int S, const int B,
/* GRU param out*/
D, xShape, hShape, wOmt, rOmt, bOmt))
return false;
onnx_mlir::ExecutionSession sess(
getSharedLibName(SHARED_LIB_BASE.str()), "run_main_graph");
onnx_mlir::ExecutionSession sess(getSharedLibName(SHARED_LIB_BASE.str()));

std::vector<OMTensorUniquePtr> inputs;
auto xOmt = OMTensorUniquePtr(
Expand Down
3 changes: 1 addition & 2 deletions test/numerical/TestGemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ bool isOMGemmTheSameAsNaiveImplFor(const int I, const int J, const int K,
aShape, bShape, cShape))
return false;

onnx_mlir::ExecutionSession sess(
getSharedLibName(SHARED_LIB_BASE.str()), "run_main_graph");
onnx_mlir::ExecutionSession sess(getSharedLibName(SHARED_LIB_BASE.str()));

std::vector<OMTensorUniquePtr> inputs;
auto aOmt = OMTensorUniquePtr(
Expand Down
3 changes: 1 addition & 2 deletions test/numerical/TestLSTM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ bool isOMLSTMTheSameAsNaiveImplFor(const int direction, const int S,
D, xShape, hShape, cShape, wOmt, rOmt, bOmt, pOmt))
return false;

onnx_mlir::ExecutionSession sess(
getSharedLibName(SHARED_LIB_BASE.str()), "run_main_graph");
onnx_mlir::ExecutionSession sess(getSharedLibName(SHARED_LIB_BASE.str()));

std::vector<OMTensorUniquePtr> inputs;
auto xOmt = OMTensorUniquePtr(
Expand Down
3 changes: 1 addition & 2 deletions test/numerical/TestLoop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ bool isOMLoopTheSameAsNaiveImplFor(std::string moduleIR,
auto module = mlir::parseSourceString(moduleIR, &ctx);
OwningModuleRef moduleRef(std::move(module));
compileModule(moduleRef, ctx, SHARED_LIB_BASE.str(), onnx_mlir::EmitLib);
onnx_mlir::ExecutionSession sess(
getSharedLibName(SHARED_LIB_BASE.str()), "run_main_graph");
onnx_mlir::ExecutionSession sess(getSharedLibName(SHARED_LIB_BASE.str()));

std::vector<OMTensorUniquePtr> inputs;
auto tripCountTensor = OMTensorUniquePtr(
Expand Down
3 changes: 1 addition & 2 deletions test/numerical/TestMatMul2D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ bool isOMMatmulTheSameAsNaiveImplFor(const int I, const int J, const int K) {
I, J, K))
return false;

onnx_mlir::ExecutionSession sess(
getSharedLibName(SHARED_LIB_BASE.str()), "run_main_graph");
onnx_mlir::ExecutionSession sess(getSharedLibName(SHARED_LIB_BASE.str()));

std::vector<OMTensorUniquePtr> inputs;
auto aOmt = OMTensorUniquePtr(
Expand Down
3 changes: 1 addition & 2 deletions test/numerical/TestRNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ bool isOMRNNTheSameAsNaiveImplFor(const int direction, const int S, const int B,
D, xShape, hShape, wOmt, rOmt, bOmt))
return false;

onnx_mlir::ExecutionSession sess(
getSharedLibName(SHARED_LIB_BASE.str()), "run_main_graph");
onnx_mlir::ExecutionSession sess(getSharedLibName(SHARED_LIB_BASE.str()));

std::vector<OMTensorUniquePtr> inputs;
auto xOmt = OMTensorUniquePtr(
Expand Down
2 changes: 1 addition & 1 deletion utils/RunONNXModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def main():
# Use the generated shared library to create an execution session.
print("Loading the compiled model ...")
start = time.perf_counter()
sess = ExecutionSession(shared_lib_path, "run_main_graph")
sess = ExecutionSession(shared_lib_path)
end = time.perf_counter()
print(" took ", end - start, " seconds.\n")

Expand Down

0 comments on commit 52fb808

Please sign in to comment.