Skip to content

Commit

Permalink
set max_steps for lr decay through config (#5780)
Browse files Browse the repository at this point in the history
* set max_steps for lr decay through config

* added warning for optim sched max_steps config option

* reverted changes to modelPT and updated megatron_base_model

* added the experimental cosine annealing scheduler class

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update decay_steps for cosine annealing exp class

* added copyright

---------

Co-authored-by: ANMOL GUPTA <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <[email protected]>
  • Loading branch information
4 people authored Feb 1, 2023
1 parent 9cf0d64 commit 589ccb3
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
1 change: 1 addition & 0 deletions nemo/collections/nlp/parts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
# limitations under the License.


from nemo.collections.nlp.parts.megatron_lr_schedulers import CosineAnnealingExp
from nemo.collections.nlp.parts.utils_funcs import list2str, tensor2list
32 changes: 32 additions & 0 deletions nemo/collections/nlp/parts/megatron_lr_schedulers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.core.optim.lr_scheduler import AVAILABLE_SCHEDULERS, CosineAnnealing


class CosineAnnealingExp(CosineAnnealing):
    """
    Experimental variant of ``CosineAnnealing`` that allows overriding the
    number of steps used for the learning-rate decay schedule.

    Setting ``max_steps_for_lr_sched`` for this scheduler in the config is
    experimental and not recommended. The scheduler can use ``max_steps``
    automatically from ``trainer.max_steps``.
    """

    def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, max_steps_for_lr_sched=None, **kwargs):
        """
        Args:
            optimizer: Wrapped optimizer whose learning rate is scheduled.
            max_steps: Total number of training steps (normally taken from
                ``trainer.max_steps``); passed through to ``CosineAnnealing``.
            min_lr: Minimum learning rate at the end of annealing. Defaults to 0.
            last_epoch: Index of the last epoch when resuming. Defaults to -1.
            max_steps_for_lr_sched: Optional experimental override for the
                schedule length. When truthy, replaces ``max_steps`` for the
                purposes of the decay schedule only.
            **kwargs: Forwarded to ``CosineAnnealing`` (e.g. ``warmup_steps``,
                ``constant_steps``).
        """
        super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)
        if max_steps_for_lr_sched:
            self.max_steps = max_steps_for_lr_sched
            # The parent computed decay_steps from the original max_steps;
            # recompute it so the override actually changes the decay length.
            self.decay_steps = self.max_steps - (self.constant_steps + self.warmup_steps)


# Register the scheduler under its class name so it can be selected by
# string name from the optim/sched section of a model config.
AVAILABLE_SCHEDULERS['CosineAnnealingExp'] = CosineAnnealingExp

0 comments on commit 589ccb3

Please sign in to comment.