-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
expose lr scheduler python bindings for on device training. (#13882)
### Description Exposing LR Scheduler python bindings for on device training. Co-authored-by: Baiju Meswani <[email protected]>
- Loading branch information
1 parent
b999022
commit e49f358
Showing
5 changed files
with
119 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .checkpoint_state import CheckpointState | ||
from .lr_scheduler import LinearLRScheduler | ||
from .module import Module | ||
from .optimizer import Optimizer |
34 changes: 34 additions & 0 deletions
34
orttraining/orttraining/python/training/api/lr_scheduler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
# lr_scheduler.py | ||
|
||
from onnxruntime.capi import _pybind_state as C | ||
|
||
|
||
class LinearLRScheduler: | ||
""" | ||
Linearly updates the learning rate in the optimizer | ||
The linear learning rate scheduler decays the learning rate by linearly updated | ||
multiplicative factor from the initial learning rate set on the training session to 0. The decay | ||
is performed after the initial warm up phase where the learning rate is linearly incremented | ||
from to the initial learning rate provided. | ||
Args: | ||
optimizer (:obj:`training_api.Optimizer`): User's onnxruntime training Optimizer | ||
warmup_step_count (int): The number of steps in the warm up phase. | ||
total_step_count (int): The total number of training steps. | ||
initial_lr (float): The initial learning rate. | ||
""" | ||
|
||
def __init__(self, optimizer, warmup_step_count, total_step_count, initial_lr) -> None: | ||
|
||
self._scheduler = C.LinearLRScheduler(optimizer._optimizer, warmup_step_count, total_step_count, initial_lr) | ||
|
||
def step(self): | ||
""" | ||
The step method of the LinearLRScheduler class is used to update the learning rate of the optimizer according | ||
to the scheduler's strategy. | ||
This method should be called at each step of training to ensure that the learning rate is properly adjusted. | ||
""" | ||
self._scheduler.scheduler_step() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters