From 8dec4b7c01992f23913be856c75a303a726a10c3 Mon Sep 17 00:00:00 2001
From: Pranav Sharma
Date: Thu, 12 Mar 2020 18:50:51 -0700
Subject: [PATCH] Fix build issues

---
 .../onnxruntime/core/session/environment.h    |   4 +
 .../core/session/onnxruntime_c_api.h          |   4 +-
 onnxruntime/core/session/inference_session.cc |  16 +--
 onnxruntime/core/session/inference_session.h  |  16 +--
 onnxruntime/core/session/onnxruntime_c_api.cc |   4 +-
 onnxruntime/core/session/ort_env.cc           |   8 ++
 onnxruntime/core/session/ort_env.h            |   5 +-
 .../python/onnxruntime_pybind_state.cc        | 135 +++++++++---------
 .../test/framework/cuda/fence_cuda_test.cc    |   4 +-
 .../test/framework/inference_session_test.cc  |  16 +--
 .../providers/tensorrt/tensorrt_basic_test.cc |   7 +-
 winml/adapter/winml_adapter_environment.cpp   |   2 +-
 winml/adapter/winml_adapter_session.cpp       |   6 +-
 13 files changed, 120 insertions(+), 107 deletions(-)

diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h
index 632c03ac9b053..12d4e931188f2 100644
--- a/include/onnxruntime/core/session/environment.h
+++ b/include/onnxruntime/core/session/environment.h
@@ -30,6 +30,10 @@ class Environment {
     return logging_manager_.get();
   }
 
+  void SetLoggingManager(std::unique_ptr<onnxruntime::logging::LoggingManager> logging_manager) {
+    logging_manager_ = std::move(logging_manager);
+  }
+
   onnxruntime::concurrency::ThreadPool* GetIntraOpThreadPool() const {
     return intra_op_thread_pool_.get();
   }
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 6a186f05ff3eb..61860a0c68c57 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -214,10 +214,10 @@ typedef enum OrtMemType {
 
 typedef struct ThreadingOptions {
   // number of threads used to parallelize execution of an op
-  int intra_op_num_threads = 0;  // default value
+  int intra_op_num_threads;  // use 0 if you want onnxruntime to choose a value for you
 
   // number of threads used to parallelize execution across ops
-  int inter_op_num_threads = 0;  // default value
+  int inter_op_num_threads;  // use 0 if you want onnxruntime to choose a value for you
 } ThreadingOptions;
 
 struct OrtApi;
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index be98c23100f5d..bed40d404f272 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -221,8 +221,8 @@ InferenceSession::InferenceSession(const SessionOptions& session_options,
 }
 
 InferenceSession::InferenceSession(const SessionOptions& session_options,
-                                   const std::string& model_uri,
-                                   const Environment& session_env)
+                                   const Environment& session_env,
+                                   const std::string& model_uri)
     : insert_cast_transformer_("CastFloat16Transformer") {
   model_location_ = ToWideString(model_uri);
   model_proto_ = onnxruntime::make_unique<ONNX_NAMESPACE::ModelProto>();
@@ -236,8 +236,8 @@ InferenceSession::InferenceSession(const SessionOptions& session_options,
 
 #ifdef _WIN32
 InferenceSession::InferenceSession(const SessionOptions& session_options,
-                                   const std::wstring& model_uri,
-                                   const Environment& session_env)
+                                   const Environment& session_env,
+                                   const std::wstring& model_uri)
     : insert_cast_transformer_("CastFloat16Transformer") {
   model_location_ = ToWideString(model_uri);
   model_proto_ = onnxruntime::make_unique<ONNX_NAMESPACE::ModelProto>();
@@ -251,8 +251,8 @@ InferenceSession::InferenceSession(const SessionOptions& session_options,
 #endif
 
 InferenceSession::InferenceSession(const SessionOptions& session_options,
-                                   std::istream& model_istream,
-                                   const Environment& session_env)
+                                   const Environment& session_env,
+                                   std::istream& model_istream)
     : insert_cast_transformer_("CastFloat16Transformer") {
   google::protobuf::io::IstreamInputStream zero_copy_input(&model_istream);
   model_proto_ = onnxruntime::make_unique<ONNX_NAMESPACE::ModelProto>();
@@ -264,9 +264,9 @@ InferenceSession::InferenceSession(const SessionOptions& session_options,
 }
 
 InferenceSession::InferenceSession(const SessionOptions& session_options,
+                                   const Environment& session_env,
                                    const void* model_data,
-                                   int model_data_len,
-                                   const Environment& session_env)
+                                   int model_data_len)
     : insert_cast_transformer_("CastFloat16Transformer") {
   model_proto_ = onnxruntime::make_unique<ONNX_NAMESPACE::ModelProto>();
   const bool result = model_proto_->ParseFromArray(model_data, model_data_len);
diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h
index 93c1e1b26bf59..ec23697b4ce49 100644
--- a/onnxruntime/core/session/inference_session.h
+++ b/onnxruntime/core/session/inference_session.h
@@ -121,12 +121,12 @@ class InferenceSession {
     This ctor will throw on encountering model parsing issues.
   */
   InferenceSession(const SessionOptions& session_options,
-                   const std::string& model_uri,
-                   const Environment& session_env);
+                   const Environment& session_env,
+                   const std::string& model_uri);
 #ifdef _WIN32
   InferenceSession(const SessionOptions& session_options,
-                   const std::wstring& model_uri,
-                   const Environment& session_env);
+                   const Environment& session_env,
+                   const std::wstring& model_uri);
 #endif
 
   /**
@@ -142,8 +142,8 @@ class InferenceSession {
     This ctor will throw on encountering model parsing issues.
   */
   InferenceSession(const SessionOptions& session_options,
-                   std::istream& model_istream,
-                   const Environment& session_env);
+                   const Environment& session_env,
+                   std::istream& model_istream);
 
   /**
     Create a new InferenceSession
@@ -159,9 +159,9 @@ class InferenceSession {
     This ctor will throw on encountering model parsing issues.
   */
   InferenceSession(const SessionOptions& session_options,
+                   const Environment& session_env,
                    const void* model_data,
-                   int model_data_len,
-                   const Environment& session_env);
+                   int model_data_len);
 
   virtual ~InferenceSession();
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 03dc991ffae58..d80d41e1265f5 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -424,7 +424,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const O
   try {
     sess = onnxruntime::make_unique<onnxruntime::InferenceSession>(
        options == nullptr ? onnxruntime::SessionOptions() : options->value,
-        model_path, env->GetEnvironment());
+        env->GetEnvironment(), model_path);
   } catch (const std::exception& e) {
     return OrtApis::CreateStatus(ORT_FAIL, e.what());
   }
@@ -439,7 +439,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In
   try {
     sess = onnxruntime::make_unique<onnxruntime::InferenceSession>(
        options == nullptr ? onnxruntime::SessionOptions() : options->value,
-        model_data, static_cast<int>(model_data_length), env->GetEnvironment());
+        env->GetEnvironment(), model_data, static_cast<int>(model_data_length));
   } catch (const std::exception& e) {
     return OrtApis::CreateStatus(ORT_FAIL, e.what());
   }
diff --git a/onnxruntime/core/session/ort_env.cc b/onnxruntime/core/session/ort_env.cc
index 662cf47bce7d0..2b027404cd3b1 100644
--- a/onnxruntime/core/session/ort_env.cc
+++ b/onnxruntime/core/session/ort_env.cc
@@ -82,4 +82,12 @@ void OrtEnv::Release(OrtEnv* env_ptr) {
     delete p_instance_;
     p_instance_ = nullptr;
   }
+}
+
+onnxruntime::logging::LoggingManager* OrtEnv::GetLoggingManager() const {
+  return value_->GetLoggingManager();
+}
+
+void OrtEnv::SetLoggingManager(std::unique_ptr<onnxruntime::logging::LoggingManager> logging_manager) {
+  value_->SetLoggingManager(std::move(logging_manager));
 }
\ No newline at end of file
diff --git a/onnxruntime/core/session/ort_env.h b/onnxruntime/core/session/ort_env.h
index c383639cc814b..81055874a82b6 100644
--- a/onnxruntime/core/session/ort_env.h
+++ b/onnxruntime/core/session/ort_env.h
@@ -52,9 +52,8 @@ struct OrtEnv {
     return *(value_.get());
   }
 
-  // onnxruntime::logging::LoggingManager* GetLoggingManager() const;
-
-  // void SetLoggingManager(std::unique_ptr<onnxruntime::logging::LoggingManager> logging_manager);
+  onnxruntime::logging::LoggingManager* GetLoggingManager() const;
+  void SetLoggingManager(std::unique_ptr<onnxruntime::logging::LoggingManager> logging_manager);
 
  private:
   static OrtEnv* p_instance_;
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index d4baca48da303..ead3ba637a5a1 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -234,28 +234,29 @@ void AddTensorAsPyObj(OrtValue& val, std::vector<py::object>& pyobjs) {
   GetPyObjFromTensor(rtensor, obj);
   pyobjs.push_back(obj);
 }
-
 class SessionObjectInitializer {
  public:
   typedef const SessionOptions& Arg1;
-  typedef logging::LoggingManager* Arg2;
+  // typedef logging::LoggingManager* Arg2;
+  static std::string default_logger_id;
 
   operator Arg1() {
     return GetDefaultCPUSessionOptions();
   }
 
-  operator Arg2() {
-    static std::string default_logger_id{"Default"};
-    static LoggingManager default_logging_manager{std::unique_ptr<ISink>{new CErrSink{}},
-                                                  Severity::kWARNING, false, LoggingManager::InstanceType::Default,
-                                                  &default_logger_id};
-    return &default_logging_manager;
-  }
+  // operator Arg2() {
+  //   static LoggingManager default_logging_manager{std::unique_ptr<ISink>{new CErrSink{}},
+  //                                                 Severity::kWARNING, false, LoggingManager::InstanceType::Default,
+  //                                                 &default_logger_id};
+  //   return &default_logging_manager;
+  // }
 
   static SessionObjectInitializer Get() {
     return SessionObjectInitializer();
   }
 };
 
+std::string SessionObjectInitializer::default_logger_id = "Default";
+
 inline void RegisterExecutionProvider(InferenceSession* sess, onnxruntime::IExecutionProviderFactory& f) {
   auto p = f.CreateProvider();
   OrtPybindThrowIfError(sess->RegisterExecutionProvider(std::move(p)));
@@ -350,17 +351,17 @@ void InitializeSession(InferenceSession* sess, const std::vector<std::string>& p
   OrtPybindThrowIfError(sess->Initialize());
 }
 
-void addGlobalMethods(py::module& m) {
+void addGlobalMethods(py::module& m, const Environment& env) {
   m.def("get_default_session_options", &GetDefaultCPUSessionOptions, "Return a default session_options instance.");
   m.def("get_session_initializer", &SessionObjectInitializer::Get, "Return a default session object initializer.");
   m.def(
       "get_device", []() -> std::string { return BACKEND_DEVICE; },
      "Return the device used to compute the prediction (CPU, MKL, ...)");
   m.def(
-      "set_default_logger_severity", [](int severity) {
+      "set_default_logger_severity", [&env](int severity) {
         ORT_ENFORCE(severity >= 0 && severity <= 4,
                     "Invalid logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal");
-        logging::LoggingManager* default_logging_manager = SessionObjectInitializer::Get();
+        logging::LoggingManager* default_logging_manager = env.GetLoggingManager();
         default_logging_manager->SetDefaultLoggerSeverity(static_cast<Severity>(severity));
       },
       "Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal");
@@ -546,7 +547,7 @@ void addOpSchemaSubmodule(py::module& m) {
 
 #endif  //onnxruntime_PYBIND_EXPORT_OPSCHEMA
 
-void addObjectMethods(py::module& m) {
+void addObjectMethods(py::module& m, Environment& env) {
   py::enum_<GraphOptimizationLevel>(m, "GraphOptimizationLevel")
       .value("ORT_DISABLE_ALL", GraphOptimizationLevel::ORT_DISABLE_ALL)
       .value("ORT_ENABLE_BASIC", GraphOptimizationLevel::ORT_ENABLE_BASIC)
@@ -658,72 +659,70 @@ including arg name, arg type (contains both type and shape).)pbdoc")
             return *(na.Type());
           },
           "node type")
-      .def(
-          "__str__", [](const onnxruntime::NodeArg& na) -> std::string {
-            std::ostringstream res;
-            res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape=";
-            auto shape = na.Shape();
-            std::vector<py::object> arr;
-            if (shape == nullptr || shape->dim_size() == 0) {
-              res << "[]";
+      .def("__str__", [](const onnxruntime::NodeArg& na) -> std::string {
+        std::ostringstream res;
+        res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape=";
+        auto shape = na.Shape();
+        std::vector<py::object> arr;
+        if (shape == nullptr || shape->dim_size() == 0) {
+          res << "[]";
+        } else {
+          res << "[";
+          for (int i = 0; i < shape->dim_size(); ++i) {
+            if (utils::HasDimValue(shape->dim(i))) {
+              res << shape->dim(i).dim_value();
+            } else if (utils::HasDimParam(shape->dim(i))) {
+              res << "'" << shape->dim(i).dim_param() << "'";
             } else {
-              res << "[";
-              for (int i = 0; i < shape->dim_size(); ++i) {
-                if (utils::HasDimValue(shape->dim(i))) {
-                  res << shape->dim(i).dim_value();
-                } else if (utils::HasDimParam(shape->dim(i))) {
-                  res << "'" << shape->dim(i).dim_param() << "'";
-                } else {
-                  res << "None";
-                }
-
-                if (i < shape->dim_size() - 1) {
-                  res << ", ";
-                }
-              }
-              res << "]";
+              res << "None";
             }
-            res << ")";
-            return std::string(res.str());
-          },
-          "converts the node into a readable string")
-      .def_property_readonly(
-          "shape", [](const onnxruntime::NodeArg& na) -> std::vector<py::object> {
-            auto shape = na.Shape();
-            std::vector<py::object> arr;
-            if (shape == nullptr || shape->dim_size() == 0) {
-              return arr;
+            if (i < shape->dim_size() - 1) {
+              res << ", ";
             }
+          }
+          res << "]";
+        }
+        res << ")";
 
-            arr.resize(shape->dim_size());
-            for (int i = 0; i < shape->dim_size(); ++i) {
-              if (utils::HasDimValue(shape->dim(i))) {
-                arr[i] = py::cast(shape->dim(i).dim_value());
-              } else if (utils::HasDimParam(shape->dim(i))) {
-                arr[i] = py::cast(shape->dim(i).dim_param());
-              } else {
-                arr[i] = py::none();
-              }
-            }
-            return arr;
-          },
-          "node shape (assuming the node holds a tensor)");
+        return std::string(res.str());
+      },
+           "converts the node into a readable string")
+      .def_property_readonly("shape", [](const onnxruntime::NodeArg& na) -> std::vector<py::object> {
+        auto shape = na.Shape();
+        std::vector<py::object> arr;
+        if (shape == nullptr || shape->dim_size() == 0) {
+          return arr;
+        }
+
+        arr.resize(shape->dim_size());
+        for (int i = 0; i < shape->dim_size(); ++i) {
+          if (utils::HasDimValue(shape->dim(i))) {
+            arr[i] = py::cast(shape->dim(i).dim_value());
+          } else if (utils::HasDimParam(shape->dim(i))) {
+            arr[i] = py::cast(shape->dim(i).dim_param());
+          } else {
+            arr[i] = py::none();
+          }
+        }
+        return arr;
+      },
+                             "node shape (assuming the node holds a tensor)");
 
   py::class_<SessionObjectInitializer>(m, "SessionObjectInitializer");
 
   py::class_<InferenceSession>(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc")
       // In Python3, a Python bytes object will be passed to C++ functions that accept std::string or char*
       // without any conversion. So this init method can be used for model file path (string)
       // and model content (bytes)
-      .def(py::init([](const SessionOptions& so, const std::string& arg, bool is_arg_file_name) {
+      .def(py::init([&env](const SessionOptions& so, const std::string& arg, bool is_arg_file_name) {
        // Given arg is the file path. Invoke the corresponding ctor().
         if (is_arg_file_name) {
-          return onnxruntime::make_unique<InferenceSession>(so, arg, SessionObjectInitializer::Get());
+          return onnxruntime::make_unique<InferenceSession>(so, env, arg);
         }
 
         // Given arg is the model content as bytes. Invoke the corresponding ctor().
         std::istringstream buffer(arg);
-        return onnxruntime::make_unique<InferenceSession>(so, buffer, SessionObjectInitializer::Get());
+        return onnxruntime::make_unique<InferenceSession>(so, env, buffer);
       }))
       .def(
           "load_model", [](InferenceSession* sess, std::vector<std::string>& provider_types) {
@@ -867,6 +866,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
 
 #endif
 
+  static std::unique_ptr<Environment> env;
   auto initialize = [&]() {
     // Initialization of the module
     ([]() -> void {
@@ -874,8 +874,11 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
       import_array1();
     })();
 
-    static std::unique_ptr<Environment> env;
-    OrtPybindThrowIfError(Environment::Create(env));
+    OrtPybindThrowIfError(Environment::Create(std::make_unique<LoggingManager>(
+                              std::unique_ptr<ISink>{new CErrSink{}},
+                              Severity::kWARNING, false, LoggingManager::InstanceType::Default,
+                              &SessionObjectInitializer::default_logger_id),
+                          env));
 
     static bool initialized = false;
     if (initialized) {
@@ -885,8 +888,8 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
   };
   initialize();
 
-  addGlobalMethods(m);
-  addObjectMethods(m);
+  addGlobalMethods(m, *env);
+  addObjectMethods(m, *env);
 
 #ifdef onnxruntime_PYBIND_EXPORT_OPSCHEMA
   addOpSchemaSubmodule(m);
diff --git a/onnxruntime/test/framework/cuda/fence_cuda_test.cc b/onnxruntime/test/framework/cuda/fence_cuda_test.cc
index 15e4ce84eefbc..288cf1cb690d2 100644
--- a/onnxruntime/test/framework/cuda/fence_cuda_test.cc
+++ b/onnxruntime/test/framework/cuda/fence_cuda_test.cc
@@ -37,7 +37,7 @@ typedef std::vector<onnxruntime::NodeArg*> ArgMap;
 
 class FenceCudaTestInferenceSession : public InferenceSession {
  public:
-  FenceCudaTestInferenceSession(const SessionOptions& so) : InferenceSession(so) {}
+  FenceCudaTestInferenceSession(const SessionOptions& so, const Environment& env) : InferenceSession(so, env) {}
   Status LoadModel(onnxruntime::Model& model) {
     auto model_proto = model.ToProto();
     auto st = Load(model_proto);
@@ -117,7 +117,7 @@ TEST(CUDAFenceTests, DISABLED_PartOnCPU) {
                                               DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
 
   SessionOptions so;
-  FenceCudaTestInferenceSession session(so);
+  FenceCudaTestInferenceSession session(so, GetEnvironment());
   LoadInferenceSessionFromModel(session, *model);
   CUDAExecutionProviderInfo xp_info;
   session.RegisterExecutionProvider(onnxruntime::make_unique<CUDAExecutionProvider>(xp_info));
diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc
index d26be7a40ed21..a2170ae4abe35 100644
--- a/onnxruntime/test/framework/inference_session_test.cc
+++ b/onnxruntime/test/framework/inference_session_test.cc
@@ -1536,7 +1536,7 @@ TEST(InferenceSessionTests, TestParallelExecutionWithCudaProvider) {
   SessionOptions so;
   so.execution_mode = ExecutionMode::ORT_PARALLEL;
   so.session_logid = "InferenceSessionTests.TestParallelExecutionWithCudaProvider";
-  InferenceSession session_object{so};
+  InferenceSession session_object{so, GetEnvironment()};
 
   CUDAExecutionProviderInfo epi;
   epi.device_id = 0;
@@ -1624,7 +1624,7 @@ TEST(InferenceSessionTests, LoadModelWithValidOrtConfigJson) {
   std::string model_path = "testdata/model_with_valid_ort_config_json.onnx";
 
   // Create session
-  InferenceSession session_object_1{so, model_path, GetEnvironment()};
+  InferenceSession session_object_1{so, GetEnvironment(), model_path};
 
   // Load() and Initialize() the session
   Status st;
@@ -1662,7 +1662,7 @@ TEST(InferenceSessionTests, LoadModelWithValidOrtConfigJson) {
   so.intra_op_num_threads = 2;
 
   // Create session
-  InferenceSession session_object_2{so, model_path, GetEnvironment()};
+  InferenceSession session_object_2{so, GetEnvironment(), model_path};
 
   // Load() and Initialize() the session
   ASSERT_TRUE((st = session_object_2.Load()).IsOK()) << st.ErrorMessage();
@@ -1691,7 +1691,7 @@ TEST(InferenceSessionTests, LoadModelWithInValidOrtConfigJson) {
 
   // Create session (should throw as the json within the model is invalid/improperly formed)
   try {
-    InferenceSession session_object_1{so, model_path, GetEnvironment()};
+    InferenceSession session_object_1{so, GetEnvironment(), model_path};
   } catch (const std::exception& e) {
     std::string e_message(std::string(e.what()));
     ASSERT_TRUE(e_message.find("Could not finalize session options while constructing the inference session. Error Message:") != std::string::npos);
@@ -1710,7 +1710,7 @@ TEST(InferenceSessionTests, LoadModelWithInValidOrtConfigJson) {
   so.intra_op_num_threads = 2;
 
   // Create session
-  InferenceSession session_object_2{so, model_path, GetEnvironment()};
+  InferenceSession session_object_2{so, GetEnvironment(), model_path};
 
   // Load() and Initialize() the session
   Status st;
@@ -1740,7 +1740,7 @@ TEST(InferenceSessionTests, LoadModelWithNoOrtConfigJson) {
   std::string model_path = "testdata/transform/abs-id-max.onnx";
 
   // Create session
-  InferenceSession session_object_1{so, model_path, GetEnvironment()};
+  InferenceSession session_object_1{so, GetEnvironment(), model_path};
 
   // Load() and Initialize() the session
   Status st;
@@ -1761,7 +1761,7 @@ TEST(InferenceSessionTests, LoadModelWithNoOrtConfigJson) {
 #endif
 
   // Create session
-  InferenceSession session_object_2{so, model_path, GetEnvironment()};  // so has inter_op_num_threads set to 2
+  InferenceSession session_object_2{so, GetEnvironment(), model_path};  // so has inter_op_num_threads set to 2
 
   // Load() and Initialize() the session
   ASSERT_TRUE((st = session_object_2.Load()).IsOK()) << st.ErrorMessage();
@@ -1785,7 +1785,7 @@ TEST(InferenceSessionTests, LoadModelWithEnvVarSetToUnsupportedVal) {
 
   // Create session (should throw because of the unsupported value for the env var - ORT_LOAD_CONFIG_FROM_MODEL)
   try {
-    InferenceSession session_object_1{so, model_path, GetEnvironment()};
+    InferenceSession session_object_1{so, GetEnvironment(), model_path};
   } catch (const std::exception& e) {
     std::string e_message(std::string(e.what()));
Error Message:") != std::string::npos); diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 015c4efed514f..3bcf1263d2a59 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -87,7 +87,7 @@ TEST(TensorrtExecutionProviderTest, FunctionTest) { RunOptions run_options; run_options.run_tag = so.session_logid; - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; TensorrtExecutionProviderInfo epi; epi.device_id = 0; @@ -201,7 +201,7 @@ TEST(TensorrtExecutionProviderTest, NodeIndexMappingTest) { RunOptions run_options; run_options.run_tag = so.session_logid; - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; TensorrtExecutionProviderInfo epi; epi.device_id = 0; @@ -215,11 +215,10 @@ TEST(TensorrtExecutionProviderTest, NodeIndexMappingTest) { // Now run status = session_object.Run(run_options, feeds, output_names, &fetches); ASSERT_TRUE(status.IsOK()); - std::vector fetche {fetches.back()}; + std::vector fetche{fetches.back()}; VerifyOutputs(fetche, expected_dims_mul_n, expected_values_mul_n); } - TEST(TensorrtExecutionProviderTest, RemoveCycleTest) { onnxruntime::Model model("graph_removecycleTest", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); diff --git a/winml/adapter/winml_adapter_environment.cpp b/winml/adapter/winml_adapter_environment.cpp index d74c35aa3344f..51b93a8cdb6bb 100644 --- a/winml/adapter/winml_adapter_environment.cpp +++ b/winml/adapter/winml_adapter_environment.cpp @@ -14,8 +14,8 @@ #include "abi_custom_registry_impl.h" #include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" #include "core/providers/dml/OperatorAuthorHelper/SchemaInferenceOverrider.h" -#endif USE_DML +#endif USE_DML namespace winmla = Windows::AI::MachineLearning::Adapter; class WinmlAdapterLoggingWrapper : public LoggingWrapper { diff --git a/winml/adapter/winml_adapter_session.cpp b/winml/adapter/winml_adapter_session.cpp index 329d713752f70..b95b2e6c8d762 100644 --- a/winml/adapter/winml_adapter_session.cpp +++ b/winml/adapter/winml_adapter_session.cpp @@ -42,7 +42,7 @@ ORT_API_STATUS_IMPL(winmla::CreateSessionWithoutModel, _In_ OrtEnv* env, _In_ co std::unique_ptr inference_session; try { // Create the inference session - inference_session = std::make_unique(options->value, env->GetLoggingManager()); + inference_session = std::make_unique(options->value, env->GetEnvironment()); } catch (const std::exception& e) { return OrtApis::CreateStatus(ORT_FAIL, e.what()); } @@ -171,7 +171,7 @@ GetLotusCustomRegistries(IMLOperatorRegistry* registry) { // Get the ORT registry return abi_custom_registry->GetRegistries(); -#endif // USE_DML +#endif // USE_DML } return {}; } @@ -195,7 +195,7 @@ ORT_API_STATUS_IMPL(winmla::CreateCustomRegistry, _Out_ IMLOperatorRegistry** re #ifdef USE_DML auto impl = wil::MakeOrThrow(); *registry = impl.Detach(); -#endif // USE_DML +#endif // USE_DML return nullptr; API_IMPL_END }