Fix build issues
pranavsharma committed Mar 13, 2020
1 parent dd7ed16 commit 8dec4b7
Showing 13 changed files with 120 additions and 107 deletions.
4 changes: 4 additions & 0 deletions include/onnxruntime/core/session/environment.h
@@ -30,6 +30,10 @@ class Environment {
return logging_manager_.get();
}

+  void SetLoggingManager(std::unique_ptr<onnxruntime::logging::LoggingManager> logging_manager) {
+    logging_manager_ = std::move(logging_manager);
+  }

onnxruntime::concurrency::ThreadPool* GetIntraOpThreadPool() const {
return intra_op_thread_pool_.get();
}
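The new setter hands ownership of the logging manager to the Environment. A minimal sketch of how a host might use it, assuming the CErrSink include path and the LoggingManager constructor arguments used elsewhere in this commit:

```cpp
#include <memory>
#include <string>

#include "core/common/logging/logging.h"
#include "core/common/logging/sinks/cerr_sink.h"  // path assumed from this commit's usage
#include "core/session/environment.h"

// Sketch only: replace the Environment's logging manager with one we construct.
void InstallAppLogger(onnxruntime::Environment& env) {
  using namespace onnxruntime::logging;
  static std::string logger_id{"my_app"};  // hypothetical id; must outlive the manager
  env.SetLoggingManager(std::make_unique<LoggingManager>(
      std::unique_ptr<ISink>{new CErrSink{}},
      Severity::kINFO, /*filter_user_data=*/false,
      LoggingManager::InstanceType::Default, &logger_id));
}
```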
4 changes: 2 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -214,10 +214,10 @@ typedef enum OrtMemType {

typedef struct ThreadingOptions {
// number of threads used to parallelize execution of an op
-int intra_op_num_threads = 0; // default value
+int intra_op_num_threads; // use 0 if you want onnxruntime to choose a value for you

// number of threads used to parallelize execution across ops
-int inter_op_num_threads = 0; // default value
+int inter_op_num_threads; // use 0 if you want onnxruntime to choose a value for you
} ThreadingOptions;

struct OrtApi;
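Dropping the in-struct default initializers is likely the build fix here: this header is also consumed by C compilers, and C++ default member initializers are invalid in C. Callers now zero the fields themselves; a minimal sketch:

```cpp
// 0 in either field asks onnxruntime to choose the thread count itself,
// matching the updated comments above.
ThreadingOptions tp_options;
tp_options.intra_op_num_threads = 0;  // let onnxruntime decide
tp_options.inter_op_num_threads = 0;  // let onnxruntime decide
```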
16 changes: 8 additions & 8 deletions onnxruntime/core/session/inference_session.cc
@@ -221,8 +221,8 @@ InferenceSession::InferenceSession(const SessionOptions& session_options,
}

InferenceSession::InferenceSession(const SessionOptions& session_options,
-const std::string& model_uri,
-const Environment& session_env)
+const Environment& session_env,
+const std::string& model_uri)
: insert_cast_transformer_("CastFloat16Transformer") {
model_location_ = ToWideString(model_uri);
model_proto_ = onnxruntime::make_unique<ONNX_NAMESPACE::ModelProto>();
@@ -236,8 +236,8 @@ InferenceSession::InferenceSession(const SessionOptions& session_options,

#ifdef _WIN32
InferenceSession::InferenceSession(const SessionOptions& session_options,
-const std::wstring& model_uri,
-const Environment& session_env)
+const Environment& session_env,
+const std::wstring& model_uri)
: insert_cast_transformer_("CastFloat16Transformer") {
model_location_ = ToWideString(model_uri);
model_proto_ = onnxruntime::make_unique<ONNX_NAMESPACE::ModelProto>();
@@ -251,8 +251,8 @@ InferenceSession::InferenceSession(const SessionOptions& session_options,
#endif

InferenceSession::InferenceSession(const SessionOptions& session_options,
-std::istream& model_istream,
-const Environment& session_env)
+const Environment& session_env,
+std::istream& model_istream)
: insert_cast_transformer_("CastFloat16Transformer") {
google::protobuf::io::IstreamInputStream zero_copy_input(&model_istream);
model_proto_ = onnxruntime::make_unique<ONNX_NAMESPACE::ModelProto>();
@@ -264,9 +264,9 @@ InferenceSession::InferenceSession(const SessionOptions& session_options,
}

InferenceSession::InferenceSession(const SessionOptions& session_options,
+const Environment& session_env,
const void* model_data,
-int model_data_len,
-const Environment& session_env)
+int model_data_len)
: insert_cast_transformer_("CastFloat16Transformer") {
model_proto_ = onnxruntime::make_unique<ONNX_NAMESPACE::ModelProto>();
const bool result = model_proto_->ParseFromArray(model_data, model_data_len);
16 changes: 8 additions & 8 deletions onnxruntime/core/session/inference_session.h
@@ -121,12 +121,12 @@ class InferenceSession {
This ctor will throw on encountering model parsing issues.
*/
InferenceSession(const SessionOptions& session_options,
-const std::string& model_uri,
-const Environment& session_env);
+const Environment& session_env,
+const std::string& model_uri);
#ifdef _WIN32
InferenceSession(const SessionOptions& session_options,
-const std::wstring& model_uri,
-const Environment& session_env);
+const Environment& session_env,
+const std::wstring& model_uri);
#endif

/**
@@ -142,8 +142,8 @@ class InferenceSession {
This ctor will throw on encountering model parsing issues.
*/
InferenceSession(const SessionOptions& session_options,
-std::istream& model_istream,
-const Environment& session_env);
+const Environment& session_env,
+std::istream& model_istream);

/**
Create a new InferenceSession
@@ -159,9 +159,9 @@ class InferenceSession {
This ctor will throw on encountering model parsing issues.
*/
InferenceSession(const SessionOptions& session_options,
+const Environment& session_env,
const void* model_data,
-int model_data_len,
-const Environment& session_env);
+int model_data_len);

virtual ~InferenceSession();

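After the reorder, every overload takes the Environment immediately after the SessionOptions, regardless of how the model is supplied. A sketch of the updated call sites, assuming an Environment obtained from Environment::Create and a hypothetical model path:

```cpp
#include <fstream>
#include <string>

#include "core/session/inference_session.h"

void BuildSessions(const onnxruntime::Environment& env) {
  onnxruntime::SessionOptions so;

  // From a file path (std::wstring overload on Windows).
  onnxruntime::InferenceSession from_uri(so, env, std::string("model.onnx"));

  // From a stream over the same (hypothetical) file.
  std::ifstream model_stream("model.onnx", std::ios::binary);
  onnxruntime::InferenceSession from_stream(so, env, model_stream);
}
```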
4 changes: 2 additions & 2 deletions onnxruntime/core/session/onnxruntime_c_api.cc
@@ -424,7 +424,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const O
try {
sess = onnxruntime::make_unique<onnxruntime::InferenceSession>(
options == nullptr ? onnxruntime::SessionOptions() : options->value,
-model_path, env->GetEnvironment());
+env->GetEnvironment(), model_path);
} catch (const std::exception& e) {
return OrtApis::CreateStatus(ORT_FAIL, e.what());
}
@@ -439,7 +439,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In
try {
sess = onnxruntime::make_unique<onnxruntime::InferenceSession>(
options == nullptr ? onnxruntime::SessionOptions() : options->value,
-model_data, static_cast<int>(model_data_length), env->GetEnvironment());
+env->GetEnvironment(), model_data, static_cast<int>(model_data_length));
} catch (const std::exception& e) {
return OrtApis::CreateStatus(ORT_FAIL, e.what());
}
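Only the internal construction order changes; the public C API is untouched. A typical call through OrtApi still looks like this sketch (hypothetical model path):

```cpp
#include "onnxruntime_c_api.h"

// Sketch: unchanged public usage; only InferenceSession's ctor order moved.
OrtStatus* MakeSession(OrtEnv* env, OrtSessionOptions* opts, OrtSession** session) {
  const OrtApi* ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
  return ort->CreateSession(env, ORT_TSTR("model.onnx"), opts, session);
}
```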
8 changes: 8 additions & 0 deletions onnxruntime/core/session/ort_env.cc
@@ -82,4 +82,12 @@ void OrtEnv::Release(OrtEnv* env_ptr) {
delete p_instance_;
p_instance_ = nullptr;
}
}

+onnxruntime::logging::LoggingManager* OrtEnv::GetLoggingManager() const {
+  return value_->GetLoggingManager();
+}
+
+void OrtEnv::SetLoggingManager(std::unique_ptr<onnxruntime::logging::LoggingManager> logging_manager) {
+  value_->SetLoggingManager(std::move(logging_manager));
+}
5 changes: 2 additions & 3 deletions onnxruntime/core/session/ort_env.h
@@ -52,9 +52,8 @@ struct OrtEnv {
return *(value_.get());
}

-// onnxruntime::logging::LoggingManager* GetLoggingManager() const;
-
-// void SetLoggingManager(std::unique_ptr<onnxruntime::logging::LoggingManager> logging_manager);
+onnxruntime::logging::LoggingManager* GetLoggingManager() const;
+void SetLoggingManager(std::unique_ptr<onnxruntime::logging::LoggingManager> logging_manager);

private:
static OrtEnv* p_instance_;
135 changes: 69 additions & 66 deletions onnxruntime/python/onnxruntime_pybind_state.cc
@@ -234,28 +234,29 @@ void AddTensorAsPyObj(OrtValue& val, std::vector<py::object>& pyobjs) {
GetPyObjFromTensor(rtensor, obj);
pyobjs.push_back(obj);
}

class SessionObjectInitializer {
public:
typedef const SessionOptions& Arg1;
-typedef logging::LoggingManager* Arg2;
+// typedef logging::LoggingManager* Arg2;
+static std::string default_logger_id;
operator Arg1() {
return GetDefaultCPUSessionOptions();
}

-operator Arg2() {
-  static std::string default_logger_id{"Default"};
-  static LoggingManager default_logging_manager{std::unique_ptr<ISink>{new CErrSink{}},
-                                                Severity::kWARNING, false, LoggingManager::InstanceType::Default,
-                                                &default_logger_id};
-  return &default_logging_manager;
-}
+// operator Arg2() {
+//   static LoggingManager default_logging_manager{std::unique_ptr<ISink>{new CErrSink{}},
+//                                                 Severity::kWARNING, false, LoggingManager::InstanceType::Default,
+//                                                 &default_logger_id};
+//   return &default_logging_manager;
+// }

static SessionObjectInitializer Get() {
return SessionObjectInitializer();
}
};

+std::string SessionObjectInitializer::default_logger_id = "Default";

inline void RegisterExecutionProvider(InferenceSession* sess, onnxruntime::IExecutionProviderFactory& f) {
auto p = f.CreateProvider();
OrtPybindThrowIfError(sess->RegisterExecutionProvider(std::move(p)));
@@ -350,17 +351,17 @@ void InitializeSession(InferenceSession* sess, const std::vector<std::string>& p
OrtPybindThrowIfError(sess->Initialize());
}

-void addGlobalMethods(py::module& m) {
+void addGlobalMethods(py::module& m, const Environment& env) {
m.def("get_default_session_options", &GetDefaultCPUSessionOptions, "Return a default session_options instance.");
m.def("get_session_initializer", &SessionObjectInitializer::Get, "Return a default session object initializer.");
m.def(
"get_device", []() -> std::string { return BACKEND_DEVICE; },
"Return the device used to compute the prediction (CPU, MKL, ...)");
m.def(
-"set_default_logger_severity", [](int severity) {
+"set_default_logger_severity", [&env](int severity) {
ORT_ENFORCE(severity >= 0 && severity <= 4,
"Invalid logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal");
-logging::LoggingManager* default_logging_manager = SessionObjectInitializer::Get();
+logging::LoggingManager* default_logging_manager = env.GetLoggingManager();
default_logging_manager->SetDefaultLoggerSeverity(static_cast<logging::Severity>(severity));
},
"Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal");
@@ -546,7 +547,7 @@ void addOpSchemaSubmodule(py::module& m) {

#endif //onnxruntime_PYBIND_EXPORT_OPSCHEMA

-void addObjectMethods(py::module& m) {
+void addObjectMethods(py::module& m, Environment& env) {
py::enum_<GraphOptimizationLevel>(m, "GraphOptimizationLevel")
.value("ORT_DISABLE_ALL", GraphOptimizationLevel::ORT_DISABLE_ALL)
.value("ORT_ENABLE_BASIC", GraphOptimizationLevel::ORT_ENABLE_BASIC)
@@ -658,72 +659,70 @@ including arg name, arg type (contains both type and shape).)pbdoc
return *(na.Type());
},
"node type")
[Formatting-only hunk: the "__str__" and "shape" bindings are re-indented, with .def( and .def_property_readonly( merged onto one line with their first argument; old and new bodies are otherwise identical. The new version:]
.def("__str__", [](const onnxruntime::NodeArg& na) -> std::string {
  std::ostringstream res;
  res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape=";
  auto shape = na.Shape();
  std::vector<py::object> arr;
  if (shape == nullptr || shape->dim_size() == 0) {
    res << "[]";
  } else {
    res << "[";
    for (int i = 0; i < shape->dim_size(); ++i) {
      if (utils::HasDimValue(shape->dim(i))) {
        res << shape->dim(i).dim_value();
      } else if (utils::HasDimParam(shape->dim(i))) {
        res << "'" << shape->dim(i).dim_param() << "'";
      } else {
        res << "None";
      }

      if (i < shape->dim_size() - 1) {
        res << ", ";
      }
    }
    res << "]";
  }
  res << ")";

  return std::string(res.str());
},
"converts the node into a readable string")
.def_property_readonly("shape", [](const onnxruntime::NodeArg& na) -> std::vector<py::object> {
  auto shape = na.Shape();
  std::vector<py::object> arr;
  if (shape == nullptr || shape->dim_size() == 0) {
    return arr;
  }

  arr.resize(shape->dim_size());
  for (int i = 0; i < shape->dim_size(); ++i) {
    if (utils::HasDimValue(shape->dim(i))) {
      arr[i] = py::cast(shape->dim(i).dim_value());
    } else if (utils::HasDimParam(shape->dim(i))) {
      arr[i] = py::cast(shape->dim(i).dim_param());
    } else {
      arr[i] = py::none();
    }
  }
  return arr;
},
"node shape (assuming the node holds a tensor)");

py::class_<SessionObjectInitializer>(m, "SessionObjectInitializer");
py::class_<InferenceSession>(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc")
// In Python3, a Python bytes object will be passed to C++ functions that accept std::string or char*
// without any conversion. So this init method can be used for model file path (string)
// and model content (bytes)
-.def(py::init([](const SessionOptions& so, const std::string& arg, bool is_arg_file_name) {
+.def(py::init([&env](const SessionOptions& so, const std::string& arg, bool is_arg_file_name) {
// Given arg is the file path. Invoke the corresponding ctor().
if (is_arg_file_name) {
-return onnxruntime::make_unique<InferenceSession>(so, arg, SessionObjectInitializer::Get());
+return onnxruntime::make_unique<InferenceSession>(so, env, arg);
}

// Given arg is the model content as bytes. Invoke the corresponding ctor().
std::istringstream buffer(arg);
-return onnxruntime::make_unique<InferenceSession>(so, buffer, SessionObjectInitializer::Get());
+return onnxruntime::make_unique<InferenceSession>(so, env, buffer);
}))
.def(
"load_model", [](InferenceSession* sess, std::vector<std::string>& provider_types) {
@@ -867,15 +866,19 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {

#endif

+static std::unique_ptr<Environment> env;
auto initialize = [&]() {
// Initialization of the module
([]() -> void {
// import_array1() forces a void return value.
import_array1();
})();

-static std::unique_ptr<Environment> env;
-OrtPybindThrowIfError(Environment::Create(env));
+OrtPybindThrowIfError(Environment::Create(std::make_unique<LoggingManager>(
+                          std::unique_ptr<ISink>{new CErrSink{}},
+                          Severity::kWARNING, false, LoggingManager::InstanceType::Default,
+                          &SessionObjectInitializer::default_logger_id),
+                      env));

static bool initialized = false;
if (initialized) {
Expand All @@ -885,8 +888,8 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
};
initialize();

-addGlobalMethods(m);
-addObjectMethods(m);
+addGlobalMethods(m, *env);
+addObjectMethods(m, *env);

#ifdef onnxruntime_PYBIND_EXPORT_OPSCHEMA
addOpSchemaSubmodule(m);
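The module-level Environment is now created together with an explicitly constructed default LoggingManager, replacing the conversion-operator trick that SessionObjectInitializer used to provide. That call shape implies an Environment::Create overload roughly like the sketch below; the overload itself is not part of this diff, so treat the signature as an assumption:

```cpp
// Assumed declaration, inferred from the call
// Environment::Create(std::make_unique<LoggingManager>(...), env) above.
static onnxruntime::common::Status Create(
    std::unique_ptr<onnxruntime::logging::LoggingManager> logging_manager,
    std::unique_ptr<onnxruntime::Environment>& environment);
```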
4 changes: 2 additions & 2 deletions onnxruntime/test/framework/cuda/fence_cuda_test.cc
@@ -37,7 +37,7 @@ typedef std::vector<onnxruntime::NodeArg*> ArgMap;

class FenceCudaTestInferenceSession : public InferenceSession {
public:
-FenceCudaTestInferenceSession(const SessionOptions& so) : InferenceSession(so) {}
+FenceCudaTestInferenceSession(const SessionOptions& so, const Environment& env) : InferenceSession(so, env) {}
Status LoadModel(onnxruntime::Model& model) {
auto model_proto = model.ToProto();
auto st = Load(model_proto);
@@ -117,7 +117,7 @@ TEST(CUDAFenceTests, DISABLED_PartOnCPU) {
DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());

SessionOptions so;
-FenceCudaTestInferenceSession session(so);
+FenceCudaTestInferenceSession session(so, GetEnvironment());
LoadInferenceSessionFromModel(session, *model);
CUDAExecutionProviderInfo xp_info;
session.RegisterExecutionProvider(onnxruntime::make_unique<CUDAExecutionProvider>(xp_info));
[Diff for the remaining 4 changed files was not loaded on this page.]
