Skip to content

Commit

Permalink
On-Device Training - Enable loading from buffer (#16417)
Browse files Browse the repository at this point in the history
  • Loading branch information
askhade authored Aug 23, 2023
1 parent ae62d75 commit 56102ec
Show file tree
Hide file tree
Showing 23 changed files with 524 additions and 140 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public struct OrtTrainingApi
public IntPtr LoadCheckpoint;
public IntPtr SaveCheckpoint;
public IntPtr CreateTrainingSession;
public IntPtr CreateTrainingSessionFromBuffer;
public IntPtr TrainingSessionGetTrainingModelOutputCount;
public IntPtr TrainingSessionGetEvalModelOutputCount;
public IntPtr TrainingSessionGetTrainingModelOutputName;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<Float8E5M2FNUZ>

int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType);

#ifdef ENABLE_TRAINING_CORE
#ifdef ENABLE_TRAINING
common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context);
#endif

Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""This file is used to generate test data for ort format model tests in
orttraining/orttraining/test/training_api/core/training_capi_tests.cc."""

import onnx
import torch
import torch.nn as nn

from onnxruntime.training import artifacts


class SimpleNet(nn.Module):
    """A minimal two-layer feed-forward network (Linear -> ReLU -> Linear).

    Used only to generate ORT-format training-artifact test data.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Single fused expression: fc1 -> relu -> fc2.
        return self.fc2(self.relu(self.fc1(x)))


def model_export(pt_model, model_path, input_size):
# Generate random input data
input_data = torch.randn(32, input_size)
torch.onnx.export(
pt_model,
input_data,
model_path,
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)


def main():
    """Generate ORT-format training artifacts for a small SimpleNet model."""
    # Layer dimensions for the toy network.
    input_size, hidden_size, output_size = 10, 20, 5

    pt_model = SimpleNet(input_size, hidden_size, output_size)

    train_model_path = "simplenet_training.onnx"
    model_export(pt_model, train_model_path, input_size)

    onnx_model = onnx.load(train_model_path)

    # Only the second linear layer is trainable; every other initializer
    # is marked frozen.
    requires_grad = ["fc2.weight", "fc2.bias"]
    frozen_params = [
        param.name for param in onnx_model.graph.initializer if param.name not in requires_grad
    ]

    # Generate the training artifacts (training/eval/optimizer graphs) in ORT format.
    artifacts.generate_artifacts(
        onnx_model,
        requires_grad=requires_grad,
        frozen_params=frozen_params,
        loss=artifacts.LossType.CrossEntropyLoss,
        optimizer=artifacts.OptimType.AdamW,
        ort_format=True,
    )


if __name__ == "__main__":
    main()
Binary file not shown.
10 changes: 6 additions & 4 deletions orttraining/orttraining/python/orttraining_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,11 @@ struct PyOptimizer {
PyOptimizer(const std::string optimizer_model_uri, onnxruntime::training::api::CheckpointState* state,
std::vector<std::shared_ptr<IExecutionProvider>> providers, PySessionOptions* session_options)
: optimizer_() {
auto model_identifiers = onnxruntime::training::api::ModelIdentifiers("", std::nullopt, optimizer_model_uri);
auto env = GetTrainingEnv().GetORTEnv();
// XXX: We hope that env will be around when optimizer needs it.
optimizer_ = std::make_shared<onnxruntime::training::api::Optimizer>(
optimizer_model_uri, state, session_options->value, *env, providers, session_options->custom_op_domains_);
model_identifiers, state, session_options->value, *env, providers, session_options->custom_op_domains_);
}

std::shared_ptr<onnxruntime::training::api::Optimizer> optimizer_;
Expand Down Expand Up @@ -941,9 +942,10 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
OrtDevice device, PySessionOptions* session_options) {
std::vector<std::shared_ptr<IExecutionProvider>> provider = GetExecutionProvidersForTrainingApis(device);
auto env = GetTrainingEnv().GetORTEnv();
return std::make_unique<onnxruntime::training::api::Module>(
model_uri, state, session_options->value, *env, provider, eval_model_uri,
session_options->custom_op_domains_);
auto model_identifiers = onnxruntime::training::api::ModelIdentifiers(model_uri, eval_model_uri, std::nullopt);
return std::make_unique<onnxruntime::training::api::Module>(model_identifiers,
state, session_options->value, *env, provider,
session_options->custom_op_domains_);
}))
.def("train_step",
[](onnxruntime::training::api::Module* model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,12 @@ TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad) {
#if defined(USE_CUDA)
providers.push_back(onnxruntime::test::DefaultCudaExecutionProvider());
#endif
auto model = std::make_unique<Module>(model_uri, &state, session_option,
auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri),
std::nullopt,
std::optional<std::string>(onnxruntime::ToUTF8String(optim_uri)));
auto model = std::make_unique<Module>(model_identifier, &state, session_option,
*env, providers);
auto optimizer = std::make_unique<Optimizer>(optim_uri, &state, session_option,
auto optimizer = std::make_unique<Optimizer>(model_identifier, &state, session_option,
*env, providers);

// Remove the temporary directory if it already exists.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,12 @@ void TestModuleExport(const std::vector<std::shared_ptr<IExecutionProvider>>& pr

std::unique_ptr<Environment> env;
ASSERT_STATUS_OK(Environment::Create(nullptr, env));
auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(training_model_uri),
std::optional<std::string>(onnxruntime::ToUTF8String(eval_model_uri)),
std::nullopt);
auto model = std::make_unique<onnxruntime::training::api::Module>(
ToUTF8String(training_model_uri), &state, onnxruntime::SessionOptions(),
*env, providers, ToUTF8String(eval_model_uri));
model_identifier, &state, onnxruntime::SessionOptions(),
*env, providers);

auto test_dir = ORT_TSTR("export_model_for_inferencing_test_dir");
if (Env::Default().FolderExists(test_dir)) {
Expand Down Expand Up @@ -141,7 +144,9 @@ TEST(TrainingApiTest, ModuleParametersSize) {
onnxruntime::SessionOptions session_option;
std::unique_ptr<Environment> env;
ASSERT_STATUS_OK(Environment::Create(nullptr, env));
auto model = std::make_unique<onnxruntime::training::api::Module>(ToUTF8String(model_uri),
auto model_identifiers = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri),
std::nullopt, std::nullopt);
auto model = std::make_unique<onnxruntime::training::api::Module>(model_identifiers,
&state, session_option,
*env, std::vector<std::shared_ptr<IExecutionProvider>>());
size_t params_size = 0;
Expand All @@ -164,7 +169,10 @@ TEST(TrainingApiTest, ModuleCopyBufferToParameters) {
onnxruntime::SessionOptions session_option;
std::unique_ptr<Environment> env;
ASSERT_STATUS_OK(Environment::Create(nullptr, env));
auto model = std::make_unique<onnxruntime::training::api::Module>(ToUTF8String(model_uri),
auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri),
std::nullopt,
std::nullopt);
auto model = std::make_unique<onnxruntime::training::api::Module>(model_identifier,
&state, session_option,
*env, std::vector<std::shared_ptr<IExecutionProvider>>());
int64_t params_size = static_cast<int64_t>(model->GetParametersSize());
Expand Down Expand Up @@ -202,7 +210,10 @@ TEST(TrainingApiTest, ModuleTrainStep) {
onnxruntime::SessionOptions session_option;
std::unique_ptr<Environment> env;
ASSERT_STATUS_OK(Environment::Create(nullptr, env));
auto model = std::make_unique<onnxruntime::training::api::Module>(ToUTF8String(model_uri),
auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri),
std::nullopt,
std::nullopt);
auto model = std::make_unique<onnxruntime::training::api::Module>(model_identifier,
&state, session_option,
*env, std::vector<std::shared_ptr<IExecutionProvider>>());
ASSERT_EQ(model->GetTrainingModelOutputCount(), 1);
Expand Down Expand Up @@ -274,8 +285,12 @@ TEST(TrainingApiTest, OptimizerCreatedWithOptimizerCheckpointState) {

ASSERT_STATUS_OK(Environment::Create(nullptr, env));

auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri),
std::nullopt,
std::optional<std::string>(onnxruntime::ToUTF8String(optim_uri)));

std::shared_ptr<Module> model = std::make_shared<Module>(
ToUTF8String(model_uri), &state, session_option,
model_identifier, &state, session_option,
*env, providers);

// Load state dict from faked optimizer checkpoint state.
Expand All @@ -285,7 +300,7 @@ TEST(TrainingApiTest, OptimizerCreatedWithOptimizerCheckpointState) {
{"momentum0", "momentum1"},
external_optimizer_checkpoint_state));
std::shared_ptr<Optimizer> optim = std::make_shared<Optimizer>(
ToUTF8String(optim_uri), &new_state, session_option, *env, providers);
model_identifier, &new_state, session_option, *env, providers);

ASSERT_TRUE(optim.get() != nullptr);
}
Expand Down Expand Up @@ -320,8 +335,12 @@ void TestLRSchduler(const std::basic_string<ORTCHAR_T>& test_file_name,

ASSERT_STATUS_OK(Environment::Create(nullptr, env));

auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri),
std::nullopt,
std::optional<std::string>(onnxruntime::ToUTF8String(optim_uri)));

std::shared_ptr<Module> model = std::make_shared<Module>(
ToUTF8String(model_uri), &state, session_option,
model_identifier, &state, session_option,
*env, providers);

OrtValue input, target;
Expand Down Expand Up @@ -351,7 +370,7 @@ void TestLRSchduler(const std::basic_string<ORTCHAR_T>& test_file_name,
}

std::shared_ptr<Optimizer> optim = std::make_shared<Optimizer>(
ToUTF8String(optim_uri), &state, session_option,
model_identifier, &state, session_option,
*env, providers);

// KNOWN ISSUE: LinearLRScheduler by default use optim's states to calculate the first step's learning rate.
Expand Down Expand Up @@ -445,11 +464,15 @@ TEST(TrainingApiTest, OptimStep) {
providers.push_back(onnxruntime::test::DefaultCudaExecutionProvider());
#endif
ASSERT_STATUS_OK(Environment::Create(nullptr, env));

auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri),
std::nullopt,
std::optional<std::string>(onnxruntime::ToUTF8String(optim_uri)));
auto model = std::make_unique<onnxruntime::training::api::Module>(
ToUTF8String(model_uri), &state, session_option,
model_identifier, &state, session_option,
*env, providers);
auto optim = std::make_unique<onnxruntime::training::api::Optimizer>(
ToUTF8String(optim_uri), &state, session_option,
model_identifier, &state, session_option,
*env, providers);

OrtValue input, target;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "gtest/gtest.h"
#include "gmock/gmock.h"

#include "onnxruntime_c_api.h"
#include "onnxruntime_training_c_api.h"
Expand All @@ -16,6 +17,7 @@
namespace onnxruntime::training::test {

#define MODEL_FOLDER ORT_TSTR("testdata/training_api/")
#define ORT_FORMAT_MODEL_FOLDER ORT_TSTR("testdata/training_api/ort_format/")

TEST(TrainingCApiTest, SaveCheckpoint) {
auto model_uri = MODEL_FOLDER "training_model.onnx";
Expand Down Expand Up @@ -220,4 +222,100 @@ TEST(TrainingCApiTest, RegisterCustomOps) {
ASSERT_TRUE(loss.front().IsTensor());
}

TEST(TrainingCApiTest, LoadModelsAndCreateSession) {
  // Creating a training session directly from an ONNX model path should succeed.
  Ort::Env env;
  auto checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
  auto model_path = MODEL_FOLDER "training_model.onnx";

  auto training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_path);
}

TEST(TrainingCApiTest, LoadModelsAndCreateSession_ORTFormat) {
  // ORT-format (.ort) train/eval/optimizer models loaded from paths should
  // also produce a valid training session.
  Ort::Env env;
  auto checkpoint_state = Ort::CheckpointState::LoadCheckpoint(ORT_FORMAT_MODEL_FOLDER "checkpoint");

  auto train_model_path = ORT_FORMAT_MODEL_FOLDER "training_model.ort";
  auto eval_train_model_path = ORT_FORMAT_MODEL_FOLDER "eval_model.ort";
  auto optimizer_model_path = ORT_FORMAT_MODEL_FOLDER "optimizer_model.ort";

  auto training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state,
                                               train_model_path, eval_train_model_path,
                                               optimizer_model_path);
}

TEST(TrainingCApiTest, LoadONNXModelsFromBuffer) {
  // Read the ONNX training model into an in-memory buffer and build the
  // training session from that buffer instead of a file path.
  auto model_path = MODEL_FOLDER "training_model.onnx";
  size_t model_data_len = 0;
  ASSERT_STATUS_OK(Env::Default().GetFileLength(model_path, model_data_len));
  std::vector<uint8_t> train_model_data(model_data_len);
  std::ifstream bytes_stream(model_path, std::ifstream::in | std::ifstream::binary);
  bytes_stream.read(reinterpret_cast<char*>(train_model_data.data()), model_data_len);
  // Comparing the vector's size to model_data_len proves nothing — the vector
  // is constructed with that size regardless of whether the read succeeded.
  // Verify the stream state and the number of bytes actually read instead.
  ASSERT_TRUE(bytes_stream);
  ASSERT_EQ(bytes_stream.gcount(), static_cast<std::streamsize>(model_data_len));

  Ort::Env env;
  Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
  Ort::TrainingSession training_session = Ort::TrainingSession(env,
                                                               Ort::SessionOptions(),
                                                               checkpoint_state,
                                                               train_model_data);
}

TEST(TrainingCApiTest, LoadORTFormatModelsFromBuffer) {
  // Read all three ORT-format models (train/eval/optimizer) into in-memory
  // buffers and build the training session from buffers only.
  //
  // NOTE: the size checks below verify the stream state and gcount() rather
  // than buffer.size() == model_data_len, which is a tautology (the vectors
  // are constructed with that size regardless of whether the read succeeded).
  auto train_model_path = ORT_FORMAT_MODEL_FOLDER "training_model.ort";
  auto eval_model_path = ORT_FORMAT_MODEL_FOLDER "eval_model.ort";
  auto optimizer_model_path = ORT_FORMAT_MODEL_FOLDER "optimizer_model.ort";

  size_t model_data_len = 0;
  ASSERT_STATUS_OK(Env::Default().GetFileLength(train_model_path, model_data_len));
  std::vector<uint8_t> train_model_data(model_data_len);
  {
    std::ifstream bytes_stream(train_model_path, std::ifstream::in | std::ifstream::binary);
    bytes_stream.read(reinterpret_cast<char*>(train_model_data.data()), model_data_len);
    ASSERT_TRUE(bytes_stream);
    ASSERT_EQ(bytes_stream.gcount(), static_cast<std::streamsize>(model_data_len));
  }

  model_data_len = 0;
  ASSERT_STATUS_OK(Env::Default().GetFileLength(eval_model_path, model_data_len));
  std::vector<uint8_t> eval_model_data(model_data_len);
  {
    std::ifstream bytes_stream(eval_model_path, std::ifstream::in | std::ifstream::binary);
    bytes_stream.read(reinterpret_cast<char*>(eval_model_data.data()), model_data_len);
    ASSERT_TRUE(bytes_stream);
    ASSERT_EQ(bytes_stream.gcount(), static_cast<std::streamsize>(model_data_len));
  }

  model_data_len = 0;
  ASSERT_STATUS_OK(Env::Default().GetFileLength(optimizer_model_path, model_data_len));
  std::vector<uint8_t> optimizer_model_data(model_data_len);
  {
    std::ifstream bytes_stream(optimizer_model_path, std::ifstream::in | std::ifstream::binary);
    bytes_stream.read(reinterpret_cast<char*>(optimizer_model_data.data()), model_data_len);
    ASSERT_TRUE(bytes_stream);
    ASSERT_EQ(bytes_stream.gcount(), static_cast<std::streamsize>(model_data_len));
  }

  Ort::Env env;
  Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(ORT_FORMAT_MODEL_FOLDER "checkpoint");
  Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(),
                                                               checkpoint_state, train_model_data,
                                                               eval_model_data, optimizer_model_data);
}

TEST(TrainingCApiTest, LoadModelsFromBufferThrows) {
  // Constructing a training session from an empty model buffer must throw
  // with a descriptive error message.
  Ort::Env env;
  Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");

  try {
    std::vector<uint8_t> train_model_data;
    Ort::TrainingSession training_session = Ort::TrainingSession(env,
                                                                 Ort::SessionOptions(),
                                                                 checkpoint_state,
                                                                 train_model_data);
    // Without this, the test silently passes when no exception is thrown,
    // which is exactly the regression this test exists to catch.
    FAIL() << "TrainingSession creation should have thrown for an empty train model buffer.";
  } catch (const std::exception& ex) {
    ASSERT_THAT(ex.what(),
                testing::HasSubstr("Training Session Creation failed. Train model data cannot be NULL."));
  }
}
} // namespace onnxruntime::training::test
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,29 @@ struct OrtTrainingApi {
ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
_Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
_In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
_Outptr_ OrtTrainingSession** out);
_Outptr_result_maybenull_ OrtTrainingSession** out);

/** \brief Create a training session that can be used to begin or resume training.
* This API provides a way to load all the training artifacts from buffers instead of files.
*
* \param[in] env Environment to be used for the training session.
* \param[in] options Session options that the user can customize for this training session.
* \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
* \param[in] train_model_data Buffer containing the model data to be used to perform training
* \param[in] train_data_length Length of the buffer containing train_model_data
* \param[in] eval_model_data Buffer containing the model data to be used to perform evaluation
* \param[in] eval_data_length Length of the buffer containing eval_model_data
* \param[in] optim_model_data Buffer containing the model data to be used to perform weight update
* \param[in] optim_data_length Length of the buffer containing optim_model_data
* \param[out] out Created training session.
*
*/
ORT_API2_STATUS(CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env,
_In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state,
_In_ const void* train_model_data, size_t train_data_length,
_In_ const void* eval_model_data, size_t eval_data_length,
_In_ const void* optim_model_data, size_t optim_data_length,
_Outptr_result_maybenull_ OrtTrainingSession** out);

/// @}

Expand Down
Loading

0 comments on commit 56102ec

Please sign in to comment.