Add ability to set CoreML EP flags from python #21434

Closed
70 changes: 38 additions & 32 deletions onnxruntime/python/onnxruntime_pybind_state.cc
@@ -35,6 +35,10 @@
#include "contrib_ops/cpu/aten_ops/aten_op_executor.h"
#endif

#if defined(USE_COREML)
#include "core/providers/coreml/coreml_provider_factory.h"
#endif

#include <pybind11/functional.h>

// Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct,
@@ -1156,7 +1160,30 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
#if !defined(__APPLE__)
LOGS_DEFAULT(WARNING) << "CoreML execution provider can only be used to generate ORT format model in this build.";
#endif
return onnxruntime::CoreMLProviderFactoryCreator::Create(0)->CreateProvider();
uint32_t coreml_flags = 0;

const auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
const ProviderOptions& options = it->second;
auto flags = options.find("flags");
if (flags != options.end()) {
const auto& flags_str = flags->second;

if (flags_str.find("COREML_FLAG_USE_CPU_ONLY") != std::string::npos) {
coreml_flags |= COREMLFlags::COREML_FLAG_USE_CPU_ONLY;
}

if (flags_str.find("COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES") != std::string::npos) {
coreml_flags |= COREMLFlags::COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES;
}

if (flags_str.find("COREML_FLAG_CREATE_MLPROGRAM") != std::string::npos) {
coreml_flags |= COREMLFlags::COREML_FLAG_CREATE_MLPROGRAM;
}
}
}

return onnxruntime::CoreMLProviderFactoryCreator::Create(coreml_flags)->CreateProvider();
#endif
} else if (type == kXnnpackExecutionProvider) {
#if defined(USE_XNNPACK)
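With the parsing added above, CoreML EP flags can be supplied from Python as a single string under the "flags" key of the provider options; the EP ORs in each COREML_FLAG_* name it finds in that string. A minimal usage sketch (the model path and flag choice are illustrative, not part of this PR):

```python
import onnxruntime as ort

# Illustrative only: "model.onnx" is a placeholder path. The "flags" value is a
# plain string; the C++ change above matches each COREML_FLAG_* name by substring,
# so several flags can be combined in one string (e.g. separated by spaces).
sess = ort.InferenceSession(
    "model.onnx",
    providers=["CoreMLExecutionProvider"],
    provider_options=[{"flags": "COREML_FLAG_USE_CPU_ONLY"}],
)
```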
@@ -1887,9 +1914,7 @@ including arg name, arg type (contains both type and shape).)pbdoc")
}
res << ")";

return std::string(res.str());
},
"converts the node into a readable string")
return std::string(res.str()); }, "converts the node into a readable string")
.def_property_readonly(
"shape", [](const onnxruntime::NodeArg& na) -> std::vector<py::object> {
auto shape = na.Shape();
@@ -1908,9 +1933,7 @@
arr[i] = py::none();
}
}
return arr;
},
"node shape (assuming the node holds a tensor)");
return arr; }, "node shape (assuming the node holds a tensor)");

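For reference, the `shape` binding being reflowed above maps each tensor dimension to an int, a symbolic name, or None. A quick way to observe this from Python (placeholder model path; not part of this PR's changes):

```python
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx")  # placeholder path
for inp in sess.get_inputs():
    # Fixed dims come back as ints, named dynamic dims as strings,
    # and unnamed dynamic dims as None, per the binding above.
    print(inp.name, inp.shape, inp.type)
```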
py::class_<SessionObjectInitializer> sessionObjectInitializer(m, "SessionObjectInitializer");
py::class_<PyInferenceSession>(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc")
@@ -2102,50 +2125,34 @@ including arg name, arg type (contains both type and shape).)pbdoc")
return sess->GetSessionHandle()->GetProfiling().GetStartTimeNs();
})
.def(
"get_providers", [](const PyInferenceSession* sess) -> const std::vector<std::string>& {
return sess->GetSessionHandle()->GetRegisteredProviderTypes();
},
py::return_value_policy::reference_internal)
"get_providers", [](const PyInferenceSession* sess) -> const std::vector<std::string>& { return sess->GetSessionHandle()->GetRegisteredProviderTypes(); }, py::return_value_policy::reference_internal)
.def(
"get_provider_options", [](const PyInferenceSession* sess) -> const ProviderOptionsMap& {
return sess->GetSessionHandle()->GetAllProviderOptions();
},
py::return_value_policy::reference_internal)
"get_provider_options", [](const PyInferenceSession* sess) -> const ProviderOptionsMap& { return sess->GetSessionHandle()->GetAllProviderOptions(); }, py::return_value_policy::reference_internal)
.def_property_readonly(
"session_options", [](const PyInferenceSession* sess) -> PySessionOptions* {
auto session_options = std::make_unique<PySessionOptions>();
session_options->value = sess->GetSessionHandle()->GetSessionOptions();
return session_options.release();
},
py::return_value_policy::take_ownership)
return session_options.release(); }, py::return_value_policy::take_ownership)
.def_property_readonly(
"inputs_meta", [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
auto res = sess->GetSessionHandle()->GetModelInputs();
OrtPybindThrowIfError(res.first);
return *(res.second);
},
py::return_value_policy::reference_internal)
return *(res.second); }, py::return_value_policy::reference_internal)
.def_property_readonly(
"outputs_meta", [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
auto res = sess->GetSessionHandle()->GetModelOutputs();
OrtPybindThrowIfError(res.first);
return *(res.second);
},
py::return_value_policy::reference_internal)
return *(res.second); }, py::return_value_policy::reference_internal)
.def_property_readonly(
"overridable_initializers", [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
auto res = sess->GetSessionHandle()->GetOverridableInitializers();
OrtPybindThrowIfError(res.first);
return *(res.second);
},
py::return_value_policy::reference_internal)
return *(res.second); }, py::return_value_policy::reference_internal)
.def_property_readonly(
"model_meta", [](const PyInferenceSession* sess) -> const onnxruntime::ModelMetadata& {
auto res = sess->GetSessionHandle()->GetModelMetadata();
OrtPybindThrowIfError(res.first);
return *(res.second);
},
py::return_value_policy::reference_internal)
return *(res.second); }, py::return_value_policy::reference_internal)
.def("run_with_iobinding", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions* run_options = nullptr) -> void {
Status status;
// release GIL to allow multiple python threads to invoke Run() in parallel.
@@ -2155,8 +2162,7 @@ including arg name, arg type (contains both type and shape).)pbdoc")
else
status = sess->GetSessionHandle()->Run(*run_options, *io_binding.Get());
if (!status.IsOK())
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
})
throw std::runtime_error("Error in execution: " + status.ErrorMessage()); })
.def("get_tuning_results", [](PyInferenceSession* sess) -> py::list {
#if !defined(ORT_MINIMAL_BUILD)
auto results = sess->GetSessionHandle()->GetTuningResults();
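After creating a session with the new "flags" option, the `get_providers` and `get_provider_options` bindings shown in this diff can be used to confirm what was actually applied. A small verification sketch (assumes `sess` was created as in the earlier example; the expected output is illustrative):

```python
# Which providers were registered, in priority order.
print(sess.get_providers())        # e.g. ['CoreMLExecutionProvider', 'CPUExecutionProvider']

# The options each provider ended up with.
print(sess.get_provider_options())
```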