expose lr scheduler python bindings for on device training. (#13882)
### Description
Exposing the LR scheduler Python bindings for on-device training.

Co-authored-by: Baiju Meswani <[email protected]>
AdamLouly and baijumeswani authored Dec 23, 2022
1 parent b999022 commit e49f358
Showing 5 changed files with 119 additions and 23 deletions.
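For orientation, here is a minimal sketch of how the newly exposed bindings are intended to be used together, mirroring the new `test_scheduler_step` test in the diff below. The artifact paths and the `data_loader` are placeholders, not part of this commit:

```python
# Hedged usage sketch of the LinearLRScheduler exposed by this commit.
# Artifact paths and data_loader are placeholders; see test_scheduler_step
# in the diff for how the training artifacts are generated via onnxblock.
from onnxruntime.training.api import CheckpointState, LinearLRScheduler, Module, Optimizer

state = CheckpointState("checkpoint")                 # load the training checkpoint
model = Module("training_model.onnx", state)          # training module (forward + backward)
optimizer = Optimizer("optimizer_model.onnx", model)

# Warm up for 1000 steps, then decay linearly to 0 over the remaining steps.
scheduler = LinearLRScheduler(
    optimizer, warmup_step_count=1000, total_step_count=10000, initial_lr=1e-3
)

model.train()
for inputs, labels in data_loader:                    # data_loader is hypothetical
    model([inputs, labels])                           # run one training step
    optimizer.step()                                  # apply gradients
    scheduler.step()                                  # update the learning rate
```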
19 changes: 9 additions & 10 deletions orttraining/orttraining/python/orttraining_pybind_common.h
@@ -15,24 +15,23 @@ namespace py = pybind11;

using namespace onnxruntime::logging;

using ExecutionProviderMap = std::unordered_map<std::string, std::shared_ptr<IExecutionProvider> >;
using ExecutionProviderLibInfoMap = std::unordered_map<std::string, std::pair<std::string, ProviderOptions> > ;
using ExecutionProviderMap = std::unordered_map<std::string, std::shared_ptr<IExecutionProvider>>;
using ExecutionProviderLibInfoMap = std::unordered_map<std::string, std::pair<std::string, ProviderOptions>>;


class ORTTrainingPythonEnv{
public:
class ORTTrainingPythonEnv {
public:
ORTTrainingPythonEnv();

Environment& GetORTEnv();

std::shared_ptr<IExecutionProvider> GetExecutionProviderInstance(const std::string& provider_type,
size_t hash);
size_t hash);

void AddExecutionProvider(const std::string& provider_type,
size_t hash,
std::unique_ptr<IExecutionProvider> execution_provider);

void RegisterExtExecutionProviderInfo(const std::string& provider_type,
void RegisterExtExecutionProviderInfo(const std::string& provider_type,
const std::string& provider_lib_path,
const ProviderOptions& default_options);

@@ -42,7 +41,7 @@ class ORTTrainingPythonEnv{

void ClearExecutionProviderInstances();

private:
private:
std::string GetExecutionProviderMapKey(const std::string& provider_type,
size_t hash);

@@ -51,5 +50,5 @@ class ORTTrainingPythonEnv{
std::vector<std::string> available_training_eps_;
};

}
}
} // namespace python
} // namespace onnxruntime
51 changes: 39 additions & 12 deletions orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -36,7 +36,7 @@

#ifdef ENABLE_TRAINING_ON_DEVICE
#include "orttraining/training_api/include/checkpoint.h"
#include "core/providers/provider_factory_creators.h"
#include "orttraining/training_api/include/lr_scheduler.h"

#endif

@@ -164,6 +164,20 @@ struct TrainingConfigurationResult {
optional<std::string> loss_scale_input_name;
};

#ifdef ENABLE_TRAINING_ON_DEVICE
// Thin wrapper over internal C++ Optimizer
struct PyOptimizer {
PyOptimizer(const std::string optimizer_model_uri,
onnxruntime::training::api::Module* model, std::vector<std::shared_ptr<IExecutionProvider>> provider)
: optimizer_(std::make_unique<onnxruntime::training::api::Optimizer>(optimizer_model_uri,
model->NamedParameters(), onnxruntime::SessionOptions(),
GetTrainingORTEnv(), provider)) {
}

std::shared_ptr<onnxruntime::training::api::Optimizer> optimizer_;
};
#endif

struct PyGradientGraphBuilder {
std::unique_ptr<GradientGraphBuilder> builder;
std::shared_ptr<Model> model;
@@ -917,29 +931,42 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
return state;
}));

py::class_<onnxruntime::training::api::Optimizer>
py::class_<PyOptimizer>
training_optimizer(m, "Optimizer", R"pbdoc(Training Optimizer.)pbdoc");
training_optimizer.def(py::init([](
const std::string optimizer_model_uri,
onnxruntime::training::api::Module* model,
OrtDevice device) {
onnxruntime::SessionOptions session_option;
std::vector<std::shared_ptr<IExecutionProvider>> provider = GetExecutionProvidersForTrainingApis(device);
return std::make_unique<onnxruntime::training::api::Optimizer>(

return std::make_unique<PyOptimizer>(
optimizer_model_uri,
model->NamedParameters(), session_option,
GetTrainingORTEnv(), provider);
model, provider);
}))
.def("set_learning_rate", [](onnxruntime::training::api::Optimizer* optimizer, float lr) -> void {
ORT_THROW_IF_ERROR(optimizer->SetLearningRate(lr));
.def("optimizer_step", [](PyOptimizer* optimizer) -> void {
ORT_THROW_IF_ERROR(optimizer->optimizer_->Step());
})
.def("get_learning_rate", [](onnxruntime::training::api::Optimizer* optimizer) -> float {
return optimizer->GetLearningRate();
.def("set_learning_rate", [](PyOptimizer* optimizer, float lr) -> void {
ORT_THROW_IF_ERROR(optimizer->optimizer_->SetLearningRate(lr));
})
.def("optimizer_step", [](onnxruntime::training::api::Optimizer* optimizer) -> void {
ORT_THROW_IF_ERROR(optimizer->Step());
.def("get_learning_rate", [](PyOptimizer* optimizer) -> float {
return optimizer->optimizer_->GetLearningRate();
});
py::class_<onnxruntime::training::api::LinearLRScheduler>
lr_scheduler(m, "LinearLRScheduler", R"pbdoc(Learning Rate Scheduler.)pbdoc");
lr_scheduler.def(py::init([](PyOptimizer* optimizer,
int64_t total_step_count,
int64_t warmup_step_count,
float initial_lr) {
ORT_THROW_IF_ERROR(optimizer->optimizer_->SetInitialLearningRate(initial_lr));

return std::make_unique<onnxruntime::training::api::LinearLRScheduler>(
optimizer->optimizer_, warmup_step_count, total_step_count);
}))
.def("scheduler_step", [](onnxruntime::training::api::LinearLRScheduler* scheduler) -> void {
ORT_THROW_IF_ERROR(scheduler->Step());
});

m.def("save_checkpoint",
[](const std::vector<py::bytes>& trainable_tensor_protos_pybytes,
const std::vector<py::bytes>& non_trainable_tensor_protos_pybytes,
1 change: 1 addition & 0 deletions orttraining/orttraining/python/training/api/__init__.py
@@ -1,3 +1,4 @@
from .checkpoint_state import CheckpointState
from .lr_scheduler import LinearLRScheduler
from .module import Module
from .optimizer import Optimizer
34 changes: 34 additions & 0 deletions orttraining/orttraining/python/training/api/lr_scheduler.py
@@ -0,0 +1,34 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# lr_scheduler.py

from onnxruntime.capi import _pybind_state as C


class LinearLRScheduler:
"""
Linearly updates the learning rate in the optimizer.

The linear learning rate scheduler decays the learning rate by a linearly updated
multiplicative factor from the initial learning rate set on the training session to 0. The decay
is performed after an initial warm-up phase, during which the learning rate is linearly
incremented from 0 to the initial learning rate provided.
Args:
optimizer (:obj:`training_api.Optimizer`): User's onnxruntime training Optimizer
warmup_step_count (int): The number of steps in the warm up phase.
total_step_count (int): The total number of training steps.
initial_lr (float): The initial learning rate.
"""

def __init__(self, optimizer, warmup_step_count, total_step_count, initial_lr) -> None:

self._scheduler = C.LinearLRScheduler(optimizer._optimizer, warmup_step_count, total_step_count, initial_lr)

def step(self):
"""
The step method of the LinearLRScheduler class is used to update the learning rate of the optimizer according
to the scheduler's strategy.
This method should be called at each step of training to ensure that the learning rate is properly adjusted.
"""
self._scheduler.scheduler_step()
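The warm-up/decay behavior described in the docstring above can be illustrated with the conventional linear schedule. This is an illustrative sketch only, under the assumption of the standard linear warmup formula, not the C++ `LinearLRScheduler` implementation itself:

```python
# Illustrative sketch (assumption): the standard linear warmup + linear decay
# multiplier that the LinearLRScheduler docstring describes.
def linear_lr_factor(step: int, warmup_step_count: int, total_step_count: int) -> float:
    if step < warmup_step_count:
        # Warm-up: factor grows linearly from 0 to 1.
        return step / max(1, warmup_step_count)
    # Decay: factor shrinks linearly from 1 to 0 over the remaining steps.
    return max(0.0, (total_step_count - step) / max(1, total_step_count - warmup_step_count))

# Effective learning rate at a step is roughly initial_lr * linear_lr_factor(step, ...).
# With the values used in the new test (warmup=1, total=2, initial_lr=0.2):
#   step 0 -> 0.0 (matches the 0.0 read right after construction)
#   step 1 -> 0.2 (the learning rate changes after one scheduler step)
```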
@@ -7,7 +7,7 @@
from orttraining_test_onnxblock import _get_models

import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training.api import CheckpointState, Module, Optimizer
from onnxruntime.training.api import CheckpointState, LinearLRScheduler, Module, Optimizer


class SimpleModelWithCrossEntropyLoss(onnxblock.TrainingModel):
@@ -168,6 +168,41 @@ def test_get_and_set_lr():
assert lr != new_lr


def test_scheduler_step():
# Initialize Models
simple_model, onnx_model, optimizer_model, _, _ = _create_training_models()

# Generating random data for testing.
inputs = torch.randn(64, 784).numpy()
labels = torch.randint(high=10, size=(64,), dtype=torch.int32).numpy()
forward_inputs = [inputs, labels]

with tempfile.TemporaryDirectory() as temp_dir:
# Save models & checkpoint files to load them later.
checkpoint_file_path, model_file_path, optimizer_file_path = _get_test_models_path(
temp_dir, simple_model, onnx_model, optimizer_model=optimizer_model
)
# Create Checkpoint State.
state = CheckpointState(checkpoint_file_path)
# Create a Module and Optimizer.
model = Module(model_file_path, state)
optimizer = Optimizer(optimizer_file_path, model)
scheduler = LinearLRScheduler(optimizer, 1, 2, 0.2)

# Test get and set learning rate.
lr = optimizer.get_learning_rate()
assert np.allclose(lr, 0.0)

model.train()
model(forward_inputs)
optimizer.step()
scheduler.step()

# Get new learning rate.
new_lr = optimizer.get_learning_rate()
assert new_lr != lr


def test_training_module_checkpoint():
# Initialize Models
simple_model, onnx_model, _, _, _ = _create_training_models()
