diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index c52ca4d1a4631..ac790242409e3 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -15,6 +15,7 @@ public struct OrtTrainingApi public IntPtr LoadCheckpoint; public IntPtr SaveCheckpoint; public IntPtr CreateTrainingSession; + public IntPtr CreateTrainingSessionFromBuffer; public IntPtr TrainingSessionGetTrainingModelOutputCount; public IntPtr TrainingSessionGetEvalModelOutputCount; public IntPtr TrainingSessionGetTrainingModelOutputName; diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index 56f41154b719c..ea6a629f87cb8 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -223,7 +223,7 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType); -#ifdef ENABLE_TRAINING_CORE +#ifdef ENABLE_TRAINING common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context); #endif diff --git a/onnxruntime/test/testdata/training_api/ort_format/checkpoint b/onnxruntime/test/testdata/training_api/ort_format/checkpoint new file mode 100644 index 0000000000000..ab35c9ad5acde Binary files /dev/null and b/onnxruntime/test/testdata/training_api/ort_format/checkpoint differ diff --git a/onnxruntime/test/testdata/training_api/ort_format/eval_model.ort b/onnxruntime/test/testdata/training_api/ort_format/eval_model.ort new file mode 100644 index 0000000000000..69b2c7e029de0 Binary files /dev/null and b/onnxruntime/test/testdata/training_api/ort_format/eval_model.ort differ diff --git a/onnxruntime/test/testdata/training_api/ort_format/optimizer_model.ort b/onnxruntime/test/testdata/training_api/ort_format/optimizer_model.ort new file mode 100644 index 0000000000000..88f192462362d Binary files /dev/null and b/onnxruntime/test/testdata/training_api/ort_format/optimizer_model.ort differ diff --git a/onnxruntime/test/testdata/training_api/ort_format/prepare_artifacts.py b/onnxruntime/test/testdata/training_api/ort_format/prepare_artifacts.py new file mode 100644 index 0000000000000..70e8c4ac011a9 --- /dev/null +++ b/onnxruntime/test/testdata/training_api/ort_format/prepare_artifacts.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""This file is used to generate test data for ort format model tests in + orttraining/orttraining/test/training_api/core/training_capi_tests.cc.""" + +import onnx +import torch +import torch.nn as nn + +from onnxruntime.training import artifacts + + +class SimpleNet(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super().__init__() + self.fc1 = nn.Linear(input_size, hidden_size) + self.relu = nn.ReLU() + self.fc2 = nn.Linear(hidden_size, output_size) + + def forward(self, x): + out = self.fc1(x) + out = self.relu(out) + out = self.fc2(out) + return out + + +def model_export(pt_model, model_path, input_size): + # Generate random input data + input_data = torch.randn(32, input_size) + torch.onnx.export( + pt_model, + input_data, + model_path, + input_names=["input"], + output_names=["output"], + dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, + ) + + +def main(): + # Set the dimensions for input, hidden, and output layers + input_size = 10 + hidden_size = 20 + output_size = 5 + + # Create an instance of the neural network + pt_model = SimpleNet(input_size, hidden_size, output_size) + + train_model_path = "simplenet_training.onnx" + model_export(pt_model, train_model_path, input_size) + + onnx_model = onnx.load(train_model_path) + + requires_grad = ["fc2.weight", "fc2.bias"] + frozen_params = [param.name for param in onnx_model.graph.initializer if param.name not in requires_grad] + + # Generate the training artifacts. + artifacts.generate_artifacts( + onnx_model, + requires_grad=requires_grad, + frozen_params=frozen_params, + loss=artifacts.LossType.CrossEntropyLoss, + optimizer=artifacts.OptimType.AdamW, + ort_format=True, + ) + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/test/testdata/training_api/ort_format/training_model.ort b/onnxruntime/test/testdata/training_api/ort_format/training_model.ort new file mode 100644 index 0000000000000..94bda328a9f9f Binary files /dev/null and b/onnxruntime/test/testdata/training_api/ort_format/training_model.ort differ diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index eac17f3d4d2e8..3f3aa396e6ca0 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -174,10 +174,11 @@ struct PyOptimizer { PyOptimizer(const std::string optimizer_model_uri, onnxruntime::training::api::CheckpointState* state, std::vector> providers, PySessionOptions* session_options) : optimizer_() { + auto model_identifiers = onnxruntime::training::api::ModelIdentifiers("", std::nullopt, optimizer_model_uri); auto env = GetTrainingEnv().GetORTEnv(); // XXX: We hope that env will be around when optimizer needs it. optimizer_ = std::make_shared( - optimizer_model_uri, state, session_options->value, *env, providers, session_options->custom_op_domains_); + model_identifiers, state, session_options->value, *env, providers, session_options->custom_op_domains_); } std::shared_ptr optimizer_; @@ -941,9 +942,10 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn OrtDevice device, PySessionOptions* session_options) { std::vector> provider = GetExecutionProvidersForTrainingApis(device); auto env = GetTrainingEnv().GetORTEnv(); - return std::make_unique( - model_uri, state, session_options->value, *env, provider, eval_model_uri, - session_options->custom_op_domains_); + auto model_identifiers = onnxruntime::training::api::ModelIdentifiers(model_uri, eval_model_uri, std::nullopt); + return std::make_unique(model_identifiers, + state, session_options->value, *env, provider, + session_options->custom_op_domains_); })) .def("train_step", [](onnxruntime::training::api::Module* model, diff --git a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc index 4fa3844717ef9..1369c9c69865a 100644 --- a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc +++ b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc @@ -331,9 +331,12 @@ TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad) { #if defined(USE_CUDA) providers.push_back(onnxruntime::test::DefaultCudaExecutionProvider()); #endif - auto model = std::make_unique(model_uri, &state, session_option, + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::optional(onnxruntime::ToUTF8String(optim_uri))); + auto model = std::make_unique(model_identifier, &state, session_option, *env, providers); - auto optimizer = std::make_unique(optim_uri, &state, session_option, + auto optimizer = std::make_unique(model_identifier, &state, session_option, *env, providers); // Remove the temporary directory if it already exists. diff --git a/orttraining/orttraining/test/training_api/core/training_api_tests.cc b/orttraining/orttraining/test/training_api/core/training_api_tests.cc index ec0c7a1968ba4..2170f7957e6a6 100644 --- a/orttraining/orttraining/test/training_api/core/training_api_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_api_tests.cc @@ -76,9 +76,12 @@ void TestModuleExport(const std::vector>& pr std::unique_ptr env; ASSERT_STATUS_OK(Environment::Create(nullptr, env)); + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(training_model_uri), + std::optional(onnxruntime::ToUTF8String(eval_model_uri)), + std::nullopt); auto model = std::make_unique( - ToUTF8String(training_model_uri), &state, onnxruntime::SessionOptions(), - *env, providers, ToUTF8String(eval_model_uri)); + model_identifier, &state, onnxruntime::SessionOptions(), + *env, providers); auto test_dir = ORT_TSTR("export_model_for_inferencing_test_dir"); if (Env::Default().FolderExists(test_dir)) { @@ -141,7 +144,9 @@ TEST(TrainingApiTest, ModuleParametersSize) { onnxruntime::SessionOptions session_option; std::unique_ptr env; ASSERT_STATUS_OK(Environment::Create(nullptr, env)); - auto model = std::make_unique(ToUTF8String(model_uri), + auto model_identifiers = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, std::nullopt); + auto model = std::make_unique(model_identifiers, &state, session_option, *env, std::vector>()); size_t params_size = 0; @@ -164,7 +169,10 @@ TEST(TrainingApiTest, ModuleCopyBufferToParameters) { onnxruntime::SessionOptions session_option; std::unique_ptr env; ASSERT_STATUS_OK(Environment::Create(nullptr, env)); - auto model = std::make_unique(ToUTF8String(model_uri), + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::nullopt); + auto model = std::make_unique(model_identifier, &state, session_option, *env, std::vector>()); int64_t params_size = static_cast(model->GetParametersSize()); @@ -202,7 +210,10 @@ TEST(TrainingApiTest, ModuleTrainStep) { onnxruntime::SessionOptions session_option; std::unique_ptr env; ASSERT_STATUS_OK(Environment::Create(nullptr, env)); - auto model = std::make_unique(ToUTF8String(model_uri), + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::nullopt); + auto model = std::make_unique(model_identifier, &state, session_option, *env, std::vector>()); ASSERT_EQ(model->GetTrainingModelOutputCount(), 1); @@ -274,8 +285,12 @@ TEST(TrainingApiTest, OptimizerCreatedWithOptimizerCheckpointState) { ASSERT_STATUS_OK(Environment::Create(nullptr, env)); + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::optional(onnxruntime::ToUTF8String(optim_uri))); + std::shared_ptr model = std::make_shared( - ToUTF8String(model_uri), &state, session_option, + model_identifier, &state, session_option, *env, providers); // Load state dict from faked optimizer checkpoint state. @@ -285,7 +300,7 @@ TEST(TrainingApiTest, OptimizerCreatedWithOptimizerCheckpointState) { {"momentum0", "momentum1"}, external_optimizer_checkpoint_state)); std::shared_ptr optim = std::make_shared( - ToUTF8String(optim_uri), &new_state, session_option, *env, providers); + model_identifier, &new_state, session_option, *env, providers); ASSERT_TRUE(optim.get() != nullptr); } @@ -320,8 +335,12 @@ void TestLRSchduler(const std::basic_string& test_file_name, ASSERT_STATUS_OK(Environment::Create(nullptr, env)); + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::optional(onnxruntime::ToUTF8String(optim_uri))); + std::shared_ptr model = std::make_shared( - ToUTF8String(model_uri), &state, session_option, + model_identifier, &state, session_option, *env, providers); OrtValue input, target; @@ -351,7 +370,7 @@ void TestLRSchduler(const std::basic_string& test_file_name, } std::shared_ptr optim = std::make_shared( - ToUTF8String(optim_uri), &state, session_option, + model_identifier, &state, session_option, *env, providers); // KNOWN ISSUE: LinearLRScheduler by default use optim's states to calculate the first step's learning rate. @@ -445,11 +464,15 @@ TEST(TrainingApiTest, OptimStep) { providers.push_back(onnxruntime::test::DefaultCudaExecutionProvider()); #endif ASSERT_STATUS_OK(Environment::Create(nullptr, env)); + + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::optional(onnxruntime::ToUTF8String(optim_uri))); auto model = std::make_unique( - ToUTF8String(model_uri), &state, session_option, + model_identifier, &state, session_option, *env, providers); auto optim = std::make_unique( - ToUTF8String(optim_uri), &state, session_option, + model_identifier, &state, session_option, *env, providers); OrtValue input, target; diff --git a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc index e864f3b8632de..d734be8e3474b 100644 --- a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "gtest/gtest.h" +#include "gmock/gmock.h" #include "onnxruntime_c_api.h" #include "onnxruntime_training_c_api.h" @@ -16,6 +17,7 @@ namespace onnxruntime::training::test { #define MODEL_FOLDER ORT_TSTR("testdata/training_api/") +#define ORT_FORMAT_MODEL_FOLDER ORT_TSTR("testdata/training_api/ort_format/") TEST(TrainingCApiTest, SaveCheckpoint) { auto model_uri = MODEL_FOLDER "training_model.onnx"; @@ -220,4 +222,100 @@ TEST(TrainingCApiTest, RegisterCustomOps) { ASSERT_TRUE(loss.front().IsTensor()); } +TEST(TrainingCApiTest, LoadModelsAndCreateSession) { + auto model_path = MODEL_FOLDER "training_model.onnx"; + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, + Ort::SessionOptions(), + checkpoint_state, + model_path); +} + +TEST(TrainingCApiTest, LoadModelsAndCreateSession_ORTFormat) { + auto train_model_path = ORT_FORMAT_MODEL_FOLDER "training_model.ort"; + auto eval_train_model_path = ORT_FORMAT_MODEL_FOLDER "eval_model.ort"; + auto optimizer_model_path = ORT_FORMAT_MODEL_FOLDER "optimizer_model.ort"; + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(ORT_FORMAT_MODEL_FOLDER "checkpoint"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, + Ort::SessionOptions(), + checkpoint_state, + train_model_path, + eval_train_model_path, + optimizer_model_path); +} + +TEST(TrainingCApiTest, LoadONNXModelsFromBuffer) { + auto model_path = MODEL_FOLDER "training_model.onnx"; + size_t model_data_len = 0; + ASSERT_STATUS_OK(Env::Default().GetFileLength(model_path, model_data_len)); + std::vector train_model_data(model_data_len); + std::ifstream bytes_stream(model_path, std::ifstream::in | std::ifstream::binary); + bytes_stream.read(reinterpret_cast(train_model_data.data()), model_data_len); + ASSERT_TRUE(train_model_data.size() == model_data_len); + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, + Ort::SessionOptions(), + checkpoint_state, + train_model_data); +} + +TEST(TrainingCApiTest, LoadORTFormatModelsFromBuffer) { + auto train_model_path = ORT_FORMAT_MODEL_FOLDER "training_model.ort"; + auto eval_model_path = ORT_FORMAT_MODEL_FOLDER "eval_model.ort"; + auto optimizer_model_path = ORT_FORMAT_MODEL_FOLDER "optimizer_model.ort"; + size_t model_data_len = 0; + ASSERT_STATUS_OK(Env::Default().GetFileLength(train_model_path, model_data_len)); + std::vector train_model_data(model_data_len); + { + std::ifstream bytes_stream(train_model_path, std::ifstream::in | std::ifstream::binary); + bytes_stream.read(reinterpret_cast(train_model_data.data()), model_data_len); + ASSERT_TRUE(train_model_data.size() == model_data_len); + } + + model_data_len = 0; + ASSERT_STATUS_OK(Env::Default().GetFileLength(eval_model_path, model_data_len)); + std::vector eval_model_data(model_data_len); + { + std::ifstream bytes_stream(eval_model_path, std::ifstream::in | std::ifstream::binary); + bytes_stream.read(reinterpret_cast(eval_model_data.data()), model_data_len); + ASSERT_TRUE(eval_model_data.size() == model_data_len); + } + + model_data_len = 0; + ASSERT_STATUS_OK(Env::Default().GetFileLength(optimizer_model_path, model_data_len)); + std::vector optimizer_model_data(model_data_len); + { + std::ifstream bytes_stream(optimizer_model_path, std::ifstream::in | std::ifstream::binary); + bytes_stream.read(reinterpret_cast(optimizer_model_data.data()), model_data_len); + ASSERT_TRUE(optimizer_model_data.size() == model_data_len); + } + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(ORT_FORMAT_MODEL_FOLDER "checkpoint"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), + checkpoint_state, train_model_data, + eval_model_data, optimizer_model_data); +} + +TEST(TrainingCApiTest, LoadModelsFromBufferThrows) { + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + + try { + std::vector train_model_data; + Ort::TrainingSession training_session = Ort::TrainingSession(env, + Ort::SessionOptions(), + checkpoint_state, + train_model_data); + } catch (const std::exception& ex) { + ASSERT_THAT(ex.what(), + testing::HasSubstr("Training Session Creation failed. Train model data cannot be NULL.")); + } +} } // namespace onnxruntime::training::test diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h index b3042c449a50b..0af737074964d 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h @@ -190,7 +190,29 @@ struct OrtTrainingApi { ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path, _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path, - _Outptr_ OrtTrainingSession** out); + _Outptr_result_maybenull_ OrtTrainingSession** out); + + /** \brief Create a training session that can be used to begin or resume training. + * This api provides a way to load all the training artifacts from buffers instead of files. + * + * \param[in] env Environment to be used for the training session. + * \param[in] options Session options that the user can customize for this training session. + * \param[in] checkpoint_state Training states that the training session uses as a starting point for training. + * \param[in] train_model_data Buffer containing the model data to be used to perform training + * \param[in] train_data_length Length of the buffer containing train_model_data + * \param[in] eval_model_data Buffer containing the model data to be used to perform evaluation + * \param[in] eval_data_length Length of the buffer containing eval_model_data + * \param[in] optim_model_data Buffer containing the model data to be used to perform weight update + * \param[in] optim_data_length Length of the buffer containing optim_model_data + * \param[out] out Created training session. + * + */ + ORT_API2_STATUS(CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env, + _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const void* train_model_data, size_t train_data_length, + _In_ const void* eval_model_data, size_t eval_data_length, + _In_ const void* optim_model_data, size_t optim_data_length, + _Outptr_result_maybenull_ OrtTrainingSession** out); /// @} diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h index 5bfdfcc74e817..0edef20ba6da8 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h @@ -176,6 +176,20 @@ class TrainingSession : public detail::Base { const std::optional>& eval_model_path = std::nullopt, const std::optional>& optimizer_model_path = std::nullopt); + /** \brief Create a training session that can be used to begin or resume training. + * This constructor allows the users to load the models from buffers instead of files. + * + * \param[in] env Env to be used for the training session. + * \param[in] session_options SessionOptions that the user can customize for this training session. + * \param[in] checkpoint_state Training states that the training session uses as a starting point for training. + * \param[in] train_model_data Buffer containing training model data. + * \param[in] eval_model_data Buffer containing evaluation model data. + * \param[in] optim_model_data Buffer containing optimizer model (used for performing weight/parameter update). + * + */ + TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state, + const std::vector& train_model_data, const std::vector& eval_model_data = {}, + const std::vector& optim_model_data = {}); /// @} /// \name Implementing The Training Loop diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h index 393e5b01f7f85..066147708863f 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h @@ -24,6 +24,23 @@ inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& se ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_)); } +inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options, + CheckpointState& checkpoint_state, + const std::vector& train_model_data, + const std::vector& eval_model_data, + const std::vector& optim_model_data) { + ThrowOnError(GetTrainingApi().CreateTrainingSessionFromBuffer( + env, session_options, checkpoint_state, + train_model_data.data(), train_model_data.size(), + eval_model_data.data(), eval_model_data.size(), + optim_model_data.data(), optim_model_data.size(), + &p_)); + + ThrowOnError(GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(p_, &training_model_output_count_)); + + ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_)); +} + inline std::vector TrainingSession::TrainStep(const std::vector& input_values) { std::vector output_values; output_values.reserve(training_model_output_count_); diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index 29300bbb7e8ec..d1775e358163c 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -12,7 +12,6 @@ #include "core/graph/graph_utils.h" #include "orttraining/training_api/checkpoint.h" -#include "orttraining/training_api/utils.h" using namespace onnxruntime; @@ -150,12 +149,11 @@ Status Parameter::ResetGrad() { return Status::OK(); } -Module::Module(const std::string& train_model_path_or_bytes, +Module::Module(const ModelIdentifiers& model_identifiers, CheckpointState* state, const onnxruntime::SessionOptions& session_options, const Environment& env, const std::vector>& providers, - const std::optional& eval_model_path_or_bytes, [[maybe_unused]] gsl::span op_domains) : state_{state} { // Enforce weight prepacking is disabled @@ -176,7 +174,12 @@ Module::Module(const std::string& train_model_path_or_bytes, } #endif - ORT_THROW_IF_ERROR(train_sess_->Load(train_model_path_or_bytes)); + // Load the training model + ORT_THROW_IF_ERROR(std::holds_alternative(model_identifiers.train_model) + ? train_sess_->Load(std::get(model_identifiers.train_model)) + : train_sess_->Load(std::get>(model_identifiers.train_model).data(), + static_cast(std::get>(model_identifiers.train_model).size()))); + for (const auto& provider : providers) { ORT_THROW_IF_ERROR(train_sess_->RegisterExecutionProvider(provider)); } @@ -239,7 +242,6 @@ Module::Module(const std::string& train_model_path_or_bytes, // Copy ortvalue buffer from CPU to target_device for this "param_name" (based on graph partitioning) // Only copies data if the target device is not the same as the current device the buffer is placed on - OrtValue& param_data = params_iter->second->Data(); ORT_ENFORCE(param_data.IsTensor()); const Tensor& param_data_tensor = param_data.Get(); @@ -278,47 +280,57 @@ Module::Module(const std::string& train_model_path_or_bytes, } } - if (eval_model_path_or_bytes.has_value()) { + if (model_identifiers.IsEvalModelAvailable()) { eval_sess_ = std::make_unique(session_options, env); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) if (!op_domains.empty()) { ORT_THROW_IF_ERROR(eval_sess_->AddCustomOpDomains(op_domains)); } #endif - - ORT_THROW_IF_ERROR(eval_sess_->Load(eval_model_path_or_bytes.value())); - for (const auto& provider : providers) { - ORT_THROW_IF_ERROR(eval_sess_->RegisterExecutionProvider(provider)); - } - ORT_THROW_IF_ERROR(eval_sess_->Initialize()); - utils::GetGraphInputOutputNames(eval_sess_, eval_input_names_, eval_output_names_); - - // Eval model validation - // We are making certain assumptions: Like the order in which parameters occur will be same between train and eval - // graphs, and all the weights present in both graphs match. - // TODO: Add the checks instead of making assumptions?? - InlinedVector eval_user_input_names, eval_param_input_names; - for (const auto& input_name : eval_input_names_) { - if (state_->module_checkpoint_state.named_parameters.find(input_name) != - state_->module_checkpoint_state.named_parameters.end()) { - // it is a parameter - eval_param_input_names.emplace_back(input_name); - continue; - } else { - // It is user input. We handle user inputs separately in the eval - // because the eval graph might have different user inputs. - // Eg if loss is not a part of the eval graph, it won't have - // certain inputs like targets - eval_user_input_names.emplace_back(input_name); - } + if (std::holds_alternative>(model_identifiers.eval_model)) { + ORT_THROW_IF_ERROR(eval_sess_->Load(std::get>(model_identifiers.eval_model).value())); + } else { + auto model_data = std::get>(model_identifiers.eval_model); + ORT_THROW_IF_ERROR(eval_sess_->Load(model_data.data(), static_cast(model_data.size()))); } - eval_input_names_ = eval_user_input_names; - eval_user_input_count_ = eval_user_input_names.size(); - eval_input_names_.insert(eval_input_names_.end(), eval_param_input_names.begin(), eval_param_input_names.end()); + } else { + return; + } - // Keep a copy of the eval model path to be able to later export the model for inferencing. - // The inference model will be reconstructed from the eval model. - eval_model_path_ = eval_model_path_or_bytes.value(); + for (const auto& provider : providers) { + ORT_THROW_IF_ERROR(eval_sess_->RegisterExecutionProvider(provider)); + } + ORT_THROW_IF_ERROR(eval_sess_->Initialize()); + utils::GetGraphInputOutputNames(eval_sess_, eval_input_names_, eval_output_names_); + + // Eval model validation + // We are making certain assumptions: Like the order in which parameters occur will be same between train and eval + // graphs, and all the weights present in both graphs match. + // TODO(askhade): Add the checks instead of making assumptions?? + InlinedVector eval_user_input_names, eval_param_input_names; + for (const auto& input_name : eval_input_names_) { + if (state_->module_checkpoint_state.named_parameters.find(input_name) != + state_->module_checkpoint_state.named_parameters.end()) { + // it is a parameter + eval_param_input_names.emplace_back(input_name); + continue; + } else { + // It is user input. We handle user inputs separately in the eval + // because the eval graph might have different user inputs. + // Eg if loss is not a part of the eval graph, it won't have + // certain inputs like targets + eval_user_input_names.emplace_back(input_name); + } + } + eval_input_names_ = eval_user_input_names; + eval_user_input_count_ = eval_user_input_names.size(); + eval_input_names_.insert(eval_input_names_.end(), eval_param_input_names.begin(), eval_param_input_names.end()); + + // Keep a copy of the eval model path to be able to later export the model for inferencing. + // The inference model will be reconstructed from the eval model. + // TODO(askhade): Find a fix to export model for inference when the eval model is loaded from a buffer. + if (std::holds_alternative>(model_identifiers.eval_model)) { + eval_model_path_ = std::get>(model_identifiers.eval_model); } } @@ -486,14 +498,14 @@ Status Module::EvalStep(const std::vector& inputs, std::vector graph_output_names) const { - ORT_RETURN_IF(!eval_sess_ || eval_model_path_.empty(), + ORT_RETURN_IF(!eval_sess_ || !eval_model_path_.has_value(), "Eval model was not provided. Cannot export a model for inferencing."); ONNX_NAMESPACE::ModelProto eval_model; - ORT_THROW_IF_ERROR(Model::Load(ToPathString(eval_model_path_), eval_model)); + ORT_THROW_IF_ERROR(Model::Load(ToPathString(eval_model_path_.value()), eval_model)); // Clone the eval mode into an inference onnxruntime::Model. std::shared_ptr inference_model; diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index 9013ab22c124f..adb633343263e 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -3,7 +3,9 @@ #pragma once +#include #include "core/session/inference_session.h" +#include "orttraining/training_api/utils.h" namespace onnxruntime { namespace training { @@ -73,12 +75,12 @@ struct Module { public: // Initialize a module from an ORT inference session with loaded // training ONNX model and load parameters - Module(const std::string& train_model_path_or_bytes, + // The model and checkpoint state can be provided as a file path or a byte array + Module(const ModelIdentifiers& model_identifiers, CheckpointState* state, const onnxruntime::SessionOptions& session_options, const Environment& env, const std::vector>& providers, - const std::optional& eval_model_path_or_bytes = std::nullopt, gsl::span op_domains = gsl::span()); // Return the trainable/nontrainable parameters @@ -159,7 +161,7 @@ struct Module { CheckpointState* state_; // Non owning pointer to the state. bool accumulate_gradient_ = false; - std::string eval_model_path_; + std::optional eval_model_path_; size_t train_user_input_count_{0U}; size_t eval_user_input_count_{0U}; }; diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index b84009e7f3591..6693bba348648 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -13,6 +13,8 @@ #include "orttraining/training_api/ort_training_apis.h" #include "orttraining/training_api/training_session.h" +using namespace onnxruntime::training::api; + namespace { std::vector> CreateProviders( @@ -26,44 +28,85 @@ std::vector> CreateProviders( return execution_providers; } +static OrtStatus* CreateSessionAndLoadModel(_In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, + _Inout_ OrtCheckpointState* checkpoint_state, + const ModelIdentifiers& model_identifiers, + std::unique_ptr& train_sess) { + auto chkpt_state = reinterpret_cast(checkpoint_state); + + using ProvidersType = std::vector>; + train_sess = std::make_unique(env->GetEnvironment(), + options == nullptr ? onnxruntime::SessionOptions() : options->value, + options == nullptr + ? ProvidersType() + : CreateProviders(options->provider_factories), + chkpt_state, + model_identifiers, + options == nullptr + ? gsl::span() + : options->custom_op_domains_); + + return nullptr; +} + } // namespace ORT_API_STATUS_IMPL(OrtTrainingApis::CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path, _In_ const ORTCHAR_T* eval_model_path, - _In_ const ORTCHAR_T* optimizer_model_path, _Outptr_ OrtTrainingSession** out) { + _In_ const ORTCHAR_T* optimizer_model_path, _Outptr_result_maybenull_ OrtTrainingSession** out) { API_IMPL_BEGIN if (options != nullptr && options->value.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigUseEnvAllocators, "0") == "1") { return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Use Env Allocators is not supported for on device training."); } std::unique_ptr train_sess; - auto chkpt_state = reinterpret_cast(checkpoint_state); OrtStatus* status = nullptr; *out = nullptr; - ORT_TRY { - using ProvidersType = std::vector>; - train_sess = std::make_unique( - env->GetEnvironment(), - options == nullptr ? onnxruntime::SessionOptions() : options->value, - options == nullptr ? ProvidersType() : CreateProviders(options->provider_factories), - chkpt_state, - onnxruntime::training::api::ModelIdentifiers( - onnxruntime::ToUTF8String(train_model_path), - eval_model_path ? std::optional(onnxruntime::ToUTF8String(eval_model_path)) - : std::nullopt, - optimizer_model_path ? std::optional(onnxruntime::ToUTF8String(optimizer_model_path)) - : std::nullopt), - options == nullptr ? gsl::span() : options->custom_op_domains_); - - *out = reinterpret_cast(train_sess.release()); - } - ORT_CATCH(const std::exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - status = OrtApis::CreateStatus(ORT_FAIL, e.what()); - }); - } + ORT_ENFORCE(train_model_path != nullptr, + "Train model path is required to create TrainingSession, it cannot be empty."); + + auto model_identifiers = onnxruntime::training::api::ModelIdentifiers( + onnxruntime::ToUTF8String(train_model_path), + eval_model_path ? std::optional(onnxruntime::ToUTF8String(eval_model_path)) + : std::nullopt, + optimizer_model_path ? std::optional(onnxruntime::ToUTF8String(optimizer_model_path)) + : std::nullopt); + + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(env, options, checkpoint_state, model_identifiers, train_sess)); + *out = reinterpret_cast(train_sess.release()); + + return status; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtTrainingApis::CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env, + _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const void* train_model_data, size_t train_data_length, + _In_ const void* eval_model_data, size_t eval_data_length, + _In_ const void* optim_model_data, size_t optim_data_length, + _Outptr_result_maybenull_ OrtTrainingSession** out) { + API_IMPL_BEGIN + std::unique_ptr train_sess; + OrtStatus* status = nullptr; + *out = nullptr; + ORT_ENFORCE(train_model_data != nullptr && train_data_length != 0, + "Training Session Creation failed. Train model data cannot be NULL."); + + auto model_identifiers = ModelIdentifiers(gsl::make_span(reinterpret_cast(train_model_data), + train_data_length), + eval_data_length == 0 || eval_model_data == nullptr + ? gsl::span() + : gsl::make_span(reinterpret_cast(eval_model_data), + eval_data_length), + optim_data_length == 0 || optim_model_data == nullptr + ? gsl::span() + : gsl::make_span(reinterpret_cast(optim_model_data), + optim_data_length)); + + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(env, options, checkpoint_state, model_identifiers, train_sess)); + *out = reinterpret_cast(train_sess.release()); return status; API_IMPL_END } @@ -523,6 +566,7 @@ static constexpr OrtTrainingApi ort_training_api = { &OrtTrainingApis::LoadCheckpoint, &OrtTrainingApis::SaveCheckpoint, &OrtTrainingApis::CreateTrainingSession, + &OrtTrainingApis::CreateTrainingSessionFromBuffer, &OrtTrainingApis::TrainingSessionGetTrainingModelOutputCount, &OrtTrainingApis::TrainingSessionGetEvalModelOutputCount, &OrtTrainingApis::TrainingSessionGetTrainingModelOutputName, diff --git a/orttraining/orttraining/training_api/optimizer.cc b/orttraining/orttraining/training_api/optimizer.cc index a6b82f1d50fc0..7f583ce8f6e76 100644 --- a/orttraining/orttraining/training_api/optimizer.cc +++ b/orttraining/orttraining/training_api/optimizer.cc @@ -61,19 +61,10 @@ Status GraphInputsAreExpected(gsl::span actual_graph_inputs, } // namespace std::unique_ptr OptimizerAlorithmFactory::CreateInstance( - const std::string& optim_path, int32_t& group_count) { + std::shared_ptr model, int32_t& group_count) { std::map, int32_t> opt_type_to_freq_map; #if !defined(ORT_MINIMAL_BUILD) - if (const auto optim_path_str = ToPathString(optim_path); - fbs::utils::IsOrtFormatModel(optim_path_str)) { - // TODO (baijumeswani): Figure out the best way to extract the optimizer type - // from an ort format model. - opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1; - } else { - std::shared_ptr model; - ORT_ENFORCE(Model::Load(optim_path_str, model, nullptr, - logging::LoggingManager::DefaultLogger()) - .IsOK()); + if (model != nullptr) { Graph& graph = model->MainGraph(); for (auto& node : graph.Nodes()) { if (node.Domain() == kMSDomain && (node.OpType() == "AdamWOptimizer" || node.OpType() == "SGDOptimizerV2")) { @@ -85,33 +76,71 @@ std::unique_ptr OptimizerAlorithmFactory::CreateInstance opt_type_to_freq_map[domain_type_pair] += 1; } } - } + } else { #else - // TODO (baijumeswani): Figure out the best way to extract the optimizer type - // from the model (either onnx model or ort format model) or from the checkpoint. - // For now, assume that the optimizer type is AdamWOptimizer in a minimal build. - ORT_UNUSED_PARAMETER(optim_path); - - opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1; + ORT_UNUSED_PARAMETER(model); +#endif + // TODO(baijumeswani): Figure out the best way to extract the optimizer type + // from the model (either onnx model or ort format model) or from the checkpoint. + // For now, assume that the optimizer type is AdamWOptimizer when using ort format models. + opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1; +#if !defined(ORT_MINIMAL_BUILD) + } #endif ORT_ENFORCE(opt_type_to_freq_map.size() == 1U, "Only support one type of optimizer algorithm, but got: " + std::to_string(opt_type_to_freq_map.size())); auto opt_it = opt_type_to_freq_map.begin(); + auto& op_type = opt_it->first.second; group_count = opt_it->second; - auto& domain = opt_it->first.first; - auto& type = opt_it->first.second; + ORT_ENFORCE(group_count == 1, "Group count can only be 1, but got: " + std::to_string(group_count)); // TODO: to support multiple groups, need to create a mapping between each group to its parameter list. - if (domain == kMSDomain && type == "AdamWOptimizer") { + if (op_type == "AdamWOptimizer") { return std::make_unique(); - } else if (domain == kMSDomain && type == "SGDOptimizerV2") { + } else if (op_type == "SGDOptimizerV2") { return std::make_unique(); } else { ORT_NOT_IMPLEMENTED("Not implemented for optimizer algo: " + opt_it->first.second); } } +std::unique_ptr OptimizerAlorithmFactory::CreateInstance( + const PathString& optim_path, int32_t& group_count) { + std::shared_ptr model = nullptr; +#if !defined(ORT_MINIMAL_BUILD) + if (!fbs::utils::IsOrtFormatModel(optim_path)) { + ORT_ENFORCE(Model::Load(optim_path, model, nullptr, + logging::LoggingManager::DefaultLogger()) + .IsOK()); + } +#else + ORT_UNUSED_PARAMETER(optim_path); +#endif + return CreateInstance(model, group_count); +} + +std::unique_ptr OptimizerAlorithmFactory::CreateInstance( + const uint8_t* optim_model_data, size_t optim_model_data_len, int32_t& group_count) { + std::shared_ptr model = nullptr; +#if !defined(ORT_MINIMAL_BUILD) + if (!fbs::utils::IsOrtFormatModelBytes(optim_model_data, static_cast(optim_model_data_len))) { + ONNX_NAMESPACE::ModelProto model_proto; + ORT_ENFORCE(model_proto.ParseFromArray(optim_model_data, static_cast(optim_model_data_len)) == true, + "Failed to load model because protobuf parsing failed."); + + ORT_ENFORCE(Model::Load(std::move(model_proto), model, nullptr, + logging::LoggingManager::DefaultLogger(), ModelOptions(true, true)) + .IsOK()); + } +#else + ORT_UNUSED_PARAMETER(optim_model_data); + ORT_UNUSED_PARAMETER(optim_model_data_len); +#endif + + return CreateInstance(model, group_count); +} + Status Optimizer::GenerateMomentumNamedStates(OptimizerCheckpointState& optimizer_checkpoint_states) { auto group_optimizer_state_it = optimizer_checkpoint_states.group_named_optimizer_states.find(GROUP_ZERO_NAME); @@ -200,14 +229,14 @@ Status Optimizer::ConstructInputs() { return Status::OK(); } // namespace api -Optimizer::Optimizer(const std::string& optim_path_or_bytes, +Optimizer::Optimizer(const ModelIdentifiers& model_identifiers, CheckpointState* state, const onnxruntime::SessionOptions& session_options, const Environment& env, const std::vector>& providers, gsl::span op_domains) : optim_sess_(std::make_unique(session_options, env)), state_(state) { - Initialize(optim_path_or_bytes, providers, op_domains); + Initialize(model_identifiers, providers, op_domains); ORT_ENFORCE(state != nullptr, "Checkpoint state cannot be null."); auto g_it = state_->optimizer_checkpoint_state.group_named_optimizer_states.find(GROUP_ZERO_NAME); @@ -223,7 +252,7 @@ Optimizer::Optimizer(const std::string& optim_path_or_bytes, } } -void Optimizer::Initialize(const std::string& optim_path_or_bytes, +void Optimizer::Initialize(const ModelIdentifiers& model_identifiers, const std::vector>& providers, [[maybe_unused]] gsl::span op_domains) { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) @@ -236,7 +265,22 @@ void Optimizer::Initialize(const std::string& optim_path_or_bytes, ORT_THROW_IF_ERROR(optim_sess_->RegisterExecutionProvider(execution_provider)); } - ORT_THROW_IF_ERROR(optim_sess_->Load(optim_path_or_bytes)); + ORT_ENFORCE(model_identifiers.IsOptimizerModelAvailable(), "Optimizer model is not available."); + + if (std::holds_alternative>(model_identifiers.optim_model)) { + auto optimizer_model = std::get>(model_identifiers.optim_model); + // The above call to IsOptimizerModelAvailable() ensures that optimizer_model is not nullopt + ORT_THROW_IF_ERROR(optim_sess_->Load(optimizer_model.value())); + optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(ToWideString(optimizer_model.value()), group_count_); + } else { + auto optimizer_model = std::get>(model_identifiers.optim_model); + ORT_THROW_IF_ERROR(optim_sess_->Load(optimizer_model.data(), + static_cast(optimizer_model.size()))); + optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(optimizer_model.data(), + optimizer_model.size(), + group_count_); + } + ORT_THROW_IF_ERROR(optim_sess_->Initialize()); // Make sure that the checkpoint state can copy tensors @@ -244,10 +288,6 @@ void Optimizer::Initialize(const std::string& optim_path_or_bytes, utils::GetGraphInputOutputNames(optim_sess_, input_names_, output_names_); - optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(optim_path_or_bytes, group_count_); - ORT_ENFORCE(group_count_ == 1, "Group count can only be 1, but got: " + std::to_string(group_count_)); - ORT_ENFORCE(optimizer_algo_ptr_, "optimizer_algo_ptr_ should not be nullptr."); - InlinedVector all_input_names; all_input_names.reserve(CommonOptimizerInputs.size() + optimizer_algo_ptr_->optimizer_states_inputs.size()); all_input_names.insert(all_input_names.end(), CommonOptimizerInputs.begin(), diff --git a/orttraining/orttraining/training_api/optimizer.h b/orttraining/orttraining/training_api/optimizer.h index 36ce3297fe3c4..d9bc4870bb7ed 100644 --- a/orttraining/orttraining/training_api/optimizer.h +++ b/orttraining/orttraining/training_api/optimizer.h @@ -64,8 +64,11 @@ struct SGDOptimizerV2Algorithm : public OptimizerAlgorithmBase { }; struct OptimizerAlorithmFactory { - static std::unique_ptr CreateInstance(const std::string& optim_path_or_bytes, + static std::unique_ptr CreateInstance(const PathString& optim_path, int32_t& group_count); + static std::unique_ptr CreateInstance(const uint8_t* optim_model_data, + size_t optim_model_data_len, int32_t& group_count); + static std::unique_ptr CreateInstance(std::shared_ptr model, int32_t& group_count); }; struct CheckpointState; @@ -96,7 +99,7 @@ struct Optimizer { // Initialize an optimizer module from an ORT inference session with loaded // training ONNX model For each parameter, initialize the OptimizerState based // on the graph input's ValueInfoProto if the parameter doesn't have it already. - Optimizer(const std::string& optim_path_or_bytes, + Optimizer(const ModelIdentifiers& model_identifiers, CheckpointState* state, const onnxruntime::SessionOptions& session_options, const Environment& env, @@ -121,7 +124,7 @@ struct Optimizer { } private: - void Initialize(const std::string& optim_path_or_bytes, + void Initialize(const ModelIdentifiers& model_identifiers, const std::vector>& providers, gsl::span op_domains); diff --git a/orttraining/orttraining/training_api/ort_training_apis.h b/orttraining/orttraining/training_api/ort_training_apis.h index 2b383f3b9782a..c87108957c975 100644 --- a/orttraining/orttraining/training_api/ort_training_apis.h +++ b/orttraining/orttraining/training_api/ort_training_apis.h @@ -8,7 +8,14 @@ ORT_API(const OrtTrainingApi*, GetTrainingApi, uint32_t version); ORT_API_STATUS_IMPL(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path, _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path, - _Outptr_ OrtTrainingSession** out); + _Outptr_result_maybenull_ OrtTrainingSession** out); + +ORT_API_STATUS_IMPL(CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env, + _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const void* train_model_data, size_t train_data_length, + _In_ const void* eval_model_data, size_t eval_data_length, + _In_ const void* optim_model_data, size_t optim_data_length, + _Outptr_result_maybenull_ OrtTrainingSession** out); ORT_API_STATUS_IMPL(TrainingSessionGetTrainingModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); diff --git a/orttraining/orttraining/training_api/training_session.cc b/orttraining/orttraining/training_api/training_session.cc index 6915193a8ff7c..45f0f0ddcf7f4 100644 --- a/orttraining/orttraining/training_api/training_session.cc +++ b/orttraining/orttraining/training_api/training_session.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "orttraining/training_api/training_session.h" +#include "orttraining/training_api/utils.h" namespace onnxruntime::training::api { @@ -12,13 +13,12 @@ TrainingSession::TrainingSession(const Environment& session_env, const ModelIdentifiers& model_identifiers, gsl::span custom_op_domains) : state_{state}, - module_{std::make_unique(model_identifiers.train_model, state_, - session_options, session_env, providers, - model_identifiers.eval_model, custom_op_domains)}, - optimizer_{model_identifiers.optim_model.has_value() + module_{std::make_unique(model_identifiers, state_, + session_options, session_env, providers, custom_op_domains)}, + optimizer_{model_identifiers.IsOptimizerModelAvailable() ? std::make_unique( - model_identifiers.optim_model.value(), state_, - session_options, session_env, providers, custom_op_domains) + model_identifiers, state_, + session_options, session_env, providers) : std::unique_ptr()} {} Status TrainingSession::RegisterScheduler( diff --git a/orttraining/orttraining/training_api/training_session.h b/orttraining/orttraining/training_api/training_session.h index 1a16acd5115f0..13b0ae79093de 100644 --- a/orttraining/orttraining/training_api/training_session.h +++ b/orttraining/orttraining/training_api/training_session.h @@ -3,25 +3,17 @@ #pragma once #include "core/common/common.h" -#include "module.h" -#include "optimizer.h" -#include "lr_scheduler.h" -#include "checkpoint.h" +#include "orttraining/training_api/module.h" +#include "orttraining/training_api/optimizer.h" +#include "orttraining/training_api/lr_scheduler.h" +#include "orttraining/training_api/checkpoint.h" +#include "orttraining/training_api/utils.h" namespace onnxruntime { namespace training { namespace api { using namespace common; -struct ModelIdentifiers { - const std::string train_model; - const std::optional eval_model, optim_model; - ModelIdentifiers(const std::string& train_model_uri, - const std::optional& eval_model_uri, - const std::optional& optim_model_uri) - : train_model(train_model_uri), eval_model(eval_model_uri), optim_model(optim_model_uri) {} -}; - // Wrapper on top of module and optimizer classes and is the only class exposed via capis class TrainingSession { public: diff --git a/orttraining/orttraining/training_api/utils.h b/orttraining/orttraining/training_api/utils.h index e856554c971ec..f16f0f947fbd5 100644 --- a/orttraining/orttraining/training_api/utils.h +++ b/orttraining/orttraining/training_api/utils.h @@ -10,6 +10,40 @@ namespace onnxruntime { namespace training { namespace api { + +struct ModelIdentifiers { + // ModelIdentifiers struct enables an easy way to store and identify the models used for training, evaluation + // and model updates(optimizer model). + // The model can be specified by a path to the model file or by a span of bytes containing the model data. + // Training model is required, evaluation and optimizer models are optional. + std::variant> train_model; + std::variant, gsl::span> eval_model; + std::variant, gsl::span> optim_model; + + ModelIdentifiers(std::variant> training_model, + std::variant, gsl::span> evaluation_model, + std::variant, gsl::span> optimzer_model) + : train_model(training_model), eval_model(evaluation_model), optim_model(optimzer_model) {} + + bool IsModelAvailable(const std::variant, gsl::span>& model) const { + if ((std::holds_alternative>(model) && + std::get>(model).has_value()) || + (std::holds_alternative>(model) && + std::get>(model).size() > 0)) { + return true; + } + return false; + } + + bool IsEvalModelAvailable() const { + return IsModelAvailable(eval_model); + } + + bool IsOptimizerModelAvailable() const { + return IsModelAvailable(optim_model); + } +}; + namespace utils { // Get names of graph inputs and outputs