From e49f358686f7f8ab16e49e4989c5de15d359afa8 Mon Sep 17 00:00:00 2001
From: Adam Louly
Date: Thu, 22 Dec 2022 20:44:04 -0600
Subject: [PATCH] expose lr scheduler python bindings for on device training.
 (#13882)

### Description
Exposing LR Scheduler python bindings for on device training.

Co-authored-by: Baiju Meswani
---
 .../python/orttraining_pybind_common.h        | 19 ++++---
 .../python/orttraining_pybind_state.cc        | 51 ++++++++++++++-----
 .../python/training/api/__init__.py           |  1 +
 .../python/training/api/lr_scheduler.py       | 34 +++++++++++++
 .../orttraining_test_python_bindings.py       | 37 +++++++++++++-
 5 files changed, 119 insertions(+), 23 deletions(-)
 create mode 100644 orttraining/orttraining/python/training/api/lr_scheduler.py

diff --git a/orttraining/orttraining/python/orttraining_pybind_common.h b/orttraining/orttraining/python/orttraining_pybind_common.h
index 6c208cd6f07a5..c3a5422c22102 100644
--- a/orttraining/orttraining/python/orttraining_pybind_common.h
+++ b/orttraining/orttraining/python/orttraining_pybind_common.h
@@ -15,24 +15,23 @@ namespace py = pybind11;
 
 using namespace onnxruntime::logging;
 
-using ExecutionProviderMap = std::unordered_map<std::string, std::shared_ptr<IExecutionProvider> >;
-using ExecutionProviderLibInfoMap = std::unordered_map<std::string, std::pair<std::string, ProviderOptions> > ;
+using ExecutionProviderMap = std::unordered_map<std::string, std::shared_ptr<IExecutionProvider>>;
+using ExecutionProviderLibInfoMap = std::unordered_map<std::string, std::pair<std::string, ProviderOptions>>;
 
-
-class ORTTrainingPythonEnv{
-public:
+class ORTTrainingPythonEnv {
+ public:
   ORTTrainingPythonEnv();
 
   Environment& GetORTEnv();
 
   std::shared_ptr<IExecutionProvider> GetExecutionProviderInstance(const std::string& provider_type,
-                                                                    size_t hash);
+                                                                   size_t hash);
 
   void AddExecutionProvider(const std::string& provider_type,
                             size_t hash,
                             std::unique_ptr<IExecutionProvider> execution_provider);
 
-  void RegisterExtExecutionProviderInfo(const std::string& provider_type, 
+  void RegisterExtExecutionProviderInfo(const std::string& provider_type,
                                         const std::string& provider_lib_path,
                                         const ProviderOptions& default_options);
 
@@ -42,7 +41,7 @@ class ORTTrainingPythonEnv{
 
   void ClearExecutionProviderInstances();
 
-private:
+ private:
   std::string GetExecutionProviderMapKey(const std::string& provider_type,
                                          size_t hash);
 
@@ -51,5 +50,5 @@ class ORTTrainingPythonEnv{
   std::vector<std::string> available_training_eps_;
 };
 
-}
-}
+}  // namespace python
+}  // namespace onnxruntime
diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc
index 1dbeb6e6add58..689f3f6688ab2 100644
--- a/orttraining/orttraining/python/orttraining_pybind_state.cc
+++ b/orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -36,7 +36,7 @@
 
 #ifdef ENABLE_TRAINING_ON_DEVICE
 #include "orttraining/training_api/include/checkpoint.h"
-#include "core/providers/provider_factory_creators.h"
+#include "orttraining/training_api/include/lr_scheduler.h"
 #endif
 
 
@@ -164,6 +164,20 @@ struct TrainingConfigurationResult {
   optional<std::string> loss_scale_input_name;
 };
 
+#ifdef ENABLE_TRAINING_ON_DEVICE
+// Thin wrapper over internal C++ Optimizer
+struct PyOptimizer {
+  PyOptimizer(const std::string optimizer_model_uri,
+              onnxruntime::training::api::Module* model, std::vector<std::shared_ptr<IExecutionProvider>> provider)
+      : optimizer_(std::make_unique<onnxruntime::training::api::Optimizer>(optimizer_model_uri,
+                                                                           model->NamedParameters(), onnxruntime::SessionOptions(),
+                                                                           GetTrainingORTEnv(), provider)) {
+  }
+
+  std::shared_ptr<onnxruntime::training::api::Optimizer> optimizer_;
+};
+#endif
+
 struct PyGradientGraphBuilder {
   std::unique_ptr<GradientGraphBuilder> builder;
   std::shared_ptr<Model> model;
@@ -917,7 +931,7 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
         return state;
       }));
 
-  py::class_<onnxruntime::training::api::Optimizer>
+  py::class_<PyOptimizer>
       training_optimizer(m, "Optimizer", R"pbdoc(Training Optimizer.)pbdoc");
   training_optimizer.def(py::init([](
                              const std::string optimizer_model_uri,
@@ -925,21 +939,34 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
                              OrtDevice device) {
     onnxruntime::SessionOptions session_option;
     std::vector<std::shared_ptr<IExecutionProvider>> provider = GetExecutionProvidersForTrainingApis(device);
-    return std::make_unique<onnxruntime::training::api::Optimizer>(
+
+    return std::make_unique<PyOptimizer>(
         optimizer_model_uri,
-        model->NamedParameters(), session_option,
-        GetTrainingORTEnv(), provider);
+        model, provider);
   }))
-      .def("set_learning_rate", [](onnxruntime::training::api::Optimizer* optimizer, float lr) -> void {
-        ORT_THROW_IF_ERROR(optimizer->SetLearningRate(lr));
+      .def("optimizer_step", [](PyOptimizer* optimizer) -> void {
+        ORT_THROW_IF_ERROR(optimizer->optimizer_->Step());
       })
-      .def("get_learning_rate", [](onnxruntime::training::api::Optimizer* optimizer) -> float {
-        return optimizer->GetLearningRate();
+      .def("set_learning_rate", [](PyOptimizer* optimizer, float lr) -> void {
+        ORT_THROW_IF_ERROR(optimizer->optimizer_->SetLearningRate(lr));
       })
-      .def("optimizer_step", [](onnxruntime::training::api::Optimizer* optimizer) -> void {
-        ORT_THROW_IF_ERROR(optimizer->Step());
+      .def("get_learning_rate", [](PyOptimizer* optimizer) -> float {
+        return optimizer->optimizer_->GetLearningRate();
+      });
+  py::class_<onnxruntime::training::api::LinearLRScheduler>
+      lr_scheduler(m, "LinearLRScheduler", R"pbdoc(Learning Rate Scheduler.)pbdoc");
+  lr_scheduler.def(py::init([](PyOptimizer* optimizer,
+                               int64_t total_step_count,
+                               int64_t warmup_step_count,
+                               float initial_lr) {
+                ORT_THROW_IF_ERROR(optimizer->optimizer_->SetInitialLearningRate(initial_lr));
+
+                return std::make_unique<onnxruntime::training::api::LinearLRScheduler>(
+                    optimizer->optimizer_, warmup_step_count, total_step_count);
+              }))
+      .def("scheduler_step", [](onnxruntime::training::api::LinearLRScheduler* scheduler) -> void {
+        ORT_THROW_IF_ERROR(scheduler->Step());
       });
-
   m.def("save_checkpoint",
         [](const std::vector<py::bytes>& trainable_tensor_protos_pybytes,
            const std::vector<py::bytes>& non_trainable_tensor_protos_pybytes,
diff --git a/orttraining/orttraining/python/training/api/__init__.py b/orttraining/orttraining/python/training/api/__init__.py
index ecfc74f026b19..c92b8bdbfe46e 100644
--- a/orttraining/orttraining/python/training/api/__init__.py
+++ b/orttraining/orttraining/python/training/api/__init__.py
@@ -1,3 +1,4 @@
 from .checkpoint_state import CheckpointState
+from .lr_scheduler import LinearLRScheduler
 from .module import Module
 from .optimizer import Optimizer
diff --git a/orttraining/orttraining/python/training/api/lr_scheduler.py b/orttraining/orttraining/python/training/api/lr_scheduler.py
new file mode 100644
index 0000000000000..cff7eaaa14555
--- /dev/null
+++ b/orttraining/orttraining/python/training/api/lr_scheduler.py
@@ -0,0 +1,34 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# lr_scheduler.py
+
+from onnxruntime.capi import _pybind_state as C
+
+
+class LinearLRScheduler:
+    """
+    Linearly updates the learning rate in the optimizer
+
+    The linear learning rate scheduler decays the learning rate by a linearly updated
+    multiplicative factor from the initial learning rate set on the training session to 0. The decay
+    is performed after the initial warm up phase where the learning rate is linearly incremented
+    from 0 to the initial learning rate provided.
+
+    Args:
+        optimizer (:obj:`training_api.Optimizer`): User's onnxruntime training Optimizer
+        warmup_step_count (int): The number of steps in the warm up phase.
+        total_step_count (int): The total number of training steps.
+        initial_lr (float): The initial learning rate.
+    """
+
+    def __init__(self, optimizer, warmup_step_count, total_step_count, initial_lr) -> None:
+
+        self._scheduler = C.LinearLRScheduler(optimizer._optimizer, warmup_step_count, total_step_count, initial_lr)
+
+    def step(self):
+        """
+        The step method of the LinearLRScheduler class is used to update the learning rate of the optimizer according
+        to the scheduler's strategy.
+        This method should be called at each step of training to ensure that the learning rate is properly adjusted.
+        """
+        self._scheduler.scheduler_step()
diff --git a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py
index 78ae34606433c..5ccc6105c701b 100644
--- a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py
+++ b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py
@@ -7,7 +7,7 @@
 from orttraining_test_onnxblock import _get_models
 
 import onnxruntime.training.onnxblock as onnxblock
-from onnxruntime.training.api import CheckpointState, Module, Optimizer
+from onnxruntime.training.api import CheckpointState, LinearLRScheduler, Module, Optimizer
 
 
 class SimpleModelWithCrossEntropyLoss(onnxblock.TrainingModel):
@@ -168,6 +168,41 @@ def test_get_and_set_lr():
     assert lr != new_lr
 
 
+def test_scheduler_step():
+    # Initialize Models
+    simple_model, onnx_model, optimizer_model, _, _ = _create_training_models()
+
+    # Generating random data for testing.
+    inputs = torch.randn(64, 784).numpy()
+    labels = torch.randint(high=10, size=(64,), dtype=torch.int32).numpy()
+    forward_inputs = [inputs, labels]
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # Save models & checkpoint files to load them later.
+        checkpoint_file_path, model_file_path, optimizer_file_path = _get_test_models_path(
+            temp_dir, simple_model, onnx_model, optimizer_model=optimizer_model
+        )
+        # Create Checkpoint State.
+        state = CheckpointState(checkpoint_file_path)
+        # Create a Module and Optimizer.
+        model = Module(model_file_path, state)
+        optimizer = Optimizer(optimizer_file_path, model)
+        scheduler = LinearLRScheduler(optimizer, 1, 2, 0.2)
+
+        # The learning rate starts at zero before the warm up phase begins.
+        lr = optimizer.get_learning_rate()
+        assert np.allclose(lr, 0.0)
+
+        model.train()
+        model(forward_inputs)
+        optimizer.step()
+        scheduler.step()
+
+        # Get new learning rate.
+        new_lr = optimizer.get_learning_rate()
+        assert new_lr != lr
+
+
 def test_training_module_checkpoint():
     # Initialize Models
     simple_model, onnx_model, _, _, _ = _create_training_models()
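
### Example usage

The snippet below is an illustrative sketch of how the new `LinearLRScheduler` binding is driven from Python, mirroring `test_scheduler_step` above; it is not part of the patch. The file paths, step counts, learning rate, and the `data_loader` iterable are placeholders, and the checkpoint/training/optimizer artifacts are assumed to have been generated with `onnxruntime.training.onnxblock` as in the test.

```python
from onnxruntime.training.api import CheckpointState, LinearLRScheduler, Module, Optimizer

# Load the checkpoint and the training/optimizer graphs produced offline (paths are placeholders).
state = CheckpointState("checkpoint")
model = Module("training_model.onnx", state)
optimizer = Optimizer("optimizer_model.onnx", model)

# Warm the learning rate up linearly over the first 1000 steps,
# then decay it linearly toward 0 by step 10000.
scheduler = LinearLRScheduler(optimizer, warmup_step_count=1000, total_step_count=10000, initial_lr=1e-4)

model.train()
for inputs, labels in data_loader:  # data_loader: user-provided iterable of numpy batches (placeholder)
    model([inputs, labels])         # run the training graph (forward + backward)
    optimizer.step()                # apply the gradients
    scheduler.step()                # advance the schedule (calls scheduler_step in the binding)
```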