NeMo-UX: Add sgd optim (NVIDIA#11157)
* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* add sgd

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* remove stale args

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* remove clip grad arg

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Pass optimizer as a run.Partial instead

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* move param extraction out of PytorchOptimizerModule

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Add ParamsT import

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: akoumpa <[email protected]>
Co-authored-by: akoumpa <[email protected]>
2 people authored and HuiyingLi committed Nov 15, 2024
1 parent 3d33186 commit 228619e
Showing 5 changed files with 115 additions and 69 deletions.
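
The net effect of the change set: PytorchOptimizerModule no longer takes an optimizer class plus a kwargs dict; it takes a single optimizer_fn expressed as a run.Partial, and parameter-group extraction now lives in a module-level helper. A minimal before/after sketch assembled from the diffs below (illustrative values, not a verbatim excerpt from the repository):

import nemo_run as run
from torch.optim import Adam

from nemo.lightning.pytorch.optim import PytorchOptimizerModule

# Before this commit the module was configured with an optimizer class and a kwargs dict:
#   run.Config(PytorchOptimizerModule, optim_cls=Adam, config=run.Config(dict, lr=1e-5))

# After this commit it is configured with a run.Partial that is later called
# with the model's parameter groups to instantiate the optimizer:
optim_cfg = run.Config(
    PytorchOptimizerModule,
    optimizer_fn=run.Partial(Adam, lr=1e-5, weight_decay=0.1),
)
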
2 changes: 1 addition & 1 deletion examples/llm/peft/hf.py
@@ -96,7 +96,7 @@ def formatting_prompts_func(examples):
use_distributed_sampler=use_dist_samp,
logger=wandb,
),
optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(max_lr=1e-5, clip_grad=0.5)),
optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
log=None,
peft=llm.peft.LoRA(
target_modules=['*_proj'],
2 changes: 1 addition & 1 deletion examples/llm/sft/hf.py
@@ -90,7 +90,7 @@ def squad(tokenizer) -> pl.LightningDataModule:
use_distributed_sampler=use_dist_samp,
logger=wandb,
),
optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(max_lr=1e-5, clip_grad=0.5)),
optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
log=None,
)

34 changes: 10 additions & 24 deletions nemo/collections/llm/recipes/optim/adam.py
@@ -17,12 +17,7 @@
import nemo_run as run
from megatron.core.optimizer import OptimizerConfig

from nemo.lightning.pytorch.optim import (
CosineAnnealingScheduler,
MegatronOptimizerModule,
OptimizerModule,
PytorchOptimizerModule,
)
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler, MegatronOptimizerModule, PytorchOptimizerModule


@run.cli.factory
@@ -35,7 +30,7 @@ def distributed_fused_adam_with_cosine_annealing(
max_lr: float = 1e-4,
min_lr: Optional[float] = None,
clip_grad: float = 1.0,
) -> run.Config[OptimizerModule]:
) -> run.Config[PytorchOptimizerModule]:

opt_cfg = run.Config(
OptimizerConfig,
@@ -68,20 +63,17 @@ def distributed_fused_adam_with_cosine_annealing(

@run.cli.factory
def pytorch_adam_with_cosine_annealing(
precision: str = "bf16-mixed", # or "16-mixed"
warmup_steps: int = 2000,
constant_steps: int = 0,
max_lr: float = 1e-5,
min_lr: Optional[float] = None,
clip_grad: float = 1.0,
) -> run.Config[OptimizerModule]:
) -> run.Config[PytorchOptimizerModule]:
from torch.optim import Adam

return run.Config(
PytorchOptimizerModule,
optim_cls=Adam,
config=run.Config(
dict,
optimizer_fn=run.Partial(
Adam,
lr=max_lr,
weight_decay=0.1,
betas=(0.9, 0.95),
@@ -98,21 +90,15 @@ def pytorch_adam_with_cosine_annealing(

@run.cli.factory
def pytorch_adam_with_flat_lr(
precision: str = "bf16-mixed", # or "16-mixed"
warmup_steps: int = 2000,
constant_steps: int = 0,
max_lr: float = 1e-5,
min_lr: Optional[float] = None,
clip_grad: float = 1.0,
) -> run.Config[OptimizerModule]:
lr: float = 1e-5,
) -> run.Config[PytorchOptimizerModule]:
from torch.optim import Adam

return run.Config(
PytorchOptimizerModule,
optim_cls=Adam,
config=run.Config(
dict,
lr=max_lr,
optimizer_fn=run.Partial(
Adam,
lr=lr,
weight_decay=0.1,
betas=(0.9, 0.95),
eps=1e-8,
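
As a hedged usage sketch (values are illustrative; fdl is assumed to be the fiddle package, matching the example scripts), the updated factories are consumed by calling them and building the resulting config:

import fiddle as fdl  # assumed alias, as used in examples/llm/*/hf.py

from nemo.collections.llm.recipes.optim import adam

# Flat learning rate: a PytorchOptimizerModule whose optimizer_fn is a partial of torch.optim.Adam.
flat_adam = fdl.build(adam.pytorch_adam_with_flat_lr(lr=1e-5))

# Cosine annealing: the same module plus a CosineAnnealingScheduler driven by
# warmup_steps / constant_steps / min_lr.
annealed_adam = fdl.build(adam.pytorch_adam_with_cosine_annealing(max_lr=1e-4, warmup_steps=2000))
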
62 changes: 62 additions & 0 deletions nemo/collections/llm/recipes/optim/sgd.py
@@ -0,0 +1,62 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import nemo_run as run

from nemo.lightning.pytorch.optim import CosineAnnealingScheduler, PytorchOptimizerModule


@run.cli.factory
def pytorch_sgd_with_cosine_annealing(
warmup_steps: int = 2000,
constant_steps: int = 0,
max_lr: float = 1e-5,
min_lr: Optional[float] = None,
wd: float = 1e-4,
) -> run.Config[PytorchOptimizerModule]:
from torch.optim import SGD

return run.Config(
PytorchOptimizerModule,
optimizer_fn=run.Partial(
SGD,
lr=max_lr,
weight_decay=wd,
),
lr_scheduler=run.Config(
CosineAnnealingScheduler,
warmup_steps=warmup_steps,
constant_steps=constant_steps,
min_lr=min_lr or (0.1 * max_lr),
),
)


@run.cli.factory
def pytorch_sgd_with_flat_lr(
lr: float = 1e-5,
wd: float = 1e-4,
) -> run.Config[PytorchOptimizerModule]:
from torch.optim import SGD

return run.Config(
PytorchOptimizerModule,
optimizer_fn=run.Partial(
SGD,
lr=lr,
weight_decay=wd,
),
)
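
The new SGD factories mirror the Adam ones; a short usage sketch follows (hedged — keyword values are illustrative, and the fdl alias follows the example scripts rather than this diff):

import fiddle as fdl  # assumed alias for the fiddle package, as in the example scripts

from nemo.collections.llm.recipes.optim import sgd

# Flat LR: a PytorchOptimizerModule wrapping run.Partial(SGD, lr=..., weight_decay=...).
sgd_optim = fdl.build(sgd.pytorch_sgd_with_flat_lr(lr=1e-3, wd=1e-4))

# Cosine annealing: adds a CosineAnnealingScheduler; min_lr defaults to 0.1 * max_lr.
sgd_annealed = fdl.build(sgd.pytorch_sgd_with_cosine_annealing(max_lr=1e-3, warmup_steps=500))

Either object can then be passed as the optim= argument of a training or fine-tuning call, just as the updated example scripts pass the Adam module.
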
84 changes: 41 additions & 43 deletions nemo/lightning/pytorch/optim/pytorch.py
@@ -15,7 +15,9 @@
from typing import Callable, List, Optional

import pytorch_lightning as pl
import pytorch_lightning as L
from torch.optim import Optimizer
from torch.optim.optimizer import ParamsT

from nemo.lightning.megatron_parallel import MegatronParallel
from nemo.lightning.pytorch.optim.base import LRSchedulerModule, OptimizerModule
@@ -25,20 +27,43 @@ def _param_does_not_have_wd(param_name, param):
return 'bias' in param_name


def _extract_model_params_for_optim(model, weight_decay=0, no_weight_decay_cond=None):
params_with_wd, params_without_wd = [], []
if no_weight_decay_cond is not None:
for name, param in model.named_parameters():
if no_weight_decay_cond(name, param):
params_without_wd.append(param)
else:
params_with_wd.append(param)
else:
params_with_wd = model.parameters()

assert max(map(len, (params_with_wd, params_without_wd))) > 0, "Expected at least one optimizer with params"

return [
{'params': params, 'weight_decay': wd}
for params, wd in zip((params_with_wd, params_without_wd), (weight_decay, 0))
]


class PytorchOptimizerModule(OptimizerModule):
"""A OptimizerModule for pytorch optimizers.
Attributes:
config (OptimizerConfig): Configuration for the optimizer.
optimizer_fn (Callable[[ParamsT], Optimizer]): Configuration for the optimizer.
no_weight_decay_cond (Optional[Callable]): Condition for no weight decay.
scale_lr_cond (Optional[Callable]): Condition for scaling learning rate.
lr_mult (float): Learning rate multiplier.
Example::
config = OptimizerConfig(...)
optimizer_fn = run.Partial(
SGD,
lr=lr,
weight_decay=wd,
)
lr_scheduler = MyLRSchedulerModule(...)
optimizer_module = PytorchOptimizerModule(config, lr_scheduler)
optimizer_module = PytorchOptimizerModule(optimizer_fn, lr_scheduler)
Methods:
setup(model): Sets up the optimizer.
@@ -47,8 +72,7 @@ class PytorchOptimizerModule(OptimizerModule):

def __init__(
self,
optim_cls,
config: dict = {'lr': 3e-4},
optimizer_fn: Callable[[ParamsT], Optimizer],
lr_scheduler: Optional[LRSchedulerModule] = None,
no_weight_decay_cond: Optional[Callable] = _param_does_not_have_wd,
scale_lr_cond: Optional[Callable] = None,
@@ -57,20 +81,18 @@ def __init__(
"""Initializes the PytorchOptimizerModule.
Args:
config (OptimizerConfig): Configuration for the optimizer.
optimizer_fn (Callable[[ParamsT], Optimizer]): Configuration for the optimizer.
lr_scheduler (Optional[LRSchedulerModule]): The learning rate scheduler module.
no_weight_decay_cond (Optional[Callable]): Condition for no weight decay.
scale_lr_cond (Optional[Callable]): Condition for scaling learning rate.
lr_mult (float): Learning rate multiplier.
"""

super().__init__(lr_scheduler=lr_scheduler)
self.optim_cls = optim_cls
self.config = config
self.optimizer_fn = optimizer_fn
self.no_weight_decay_cond = no_weight_decay_cond
self.scale_lr_cond = scale_lr_cond
self.lr_mult = lr_mult
self.optim_cls = optim_cls

def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
# Noop
@@ -92,41 +114,17 @@ def optimizers(self, model) -> List[Optimizer]:
if isinstance(model, MegatronParallel):
raise ValueError("Model cannot be an instance of MegatronParallel")

params_with_wd, params_without_wd = [], []
if self.no_weight_decay_cond is not None:
for name, param in model.named_parameters():
if self.no_weight_decay_cond(name, param):
params_without_wd.append(param)
else:
params_with_wd.append(param)
else:
params_with_wd = model.parameters()

optimizers = []
if len(params_with_wd) > 0:
optimizers.append(
self.optim_cls(
params_with_wd,
**self.config,
)
)

if len(params_without_wd) > 0:
wd = self.config.get('weight_decay', None)
kwargs['weight_decay'] = 0
optimizers.append(
self.optim_cls(
params_without_wd,
**kwargs,
)
)
# restore value
if wd is not None:
kwargs['weight_decay'] = wd

assert len(optimizers) > 0, "Expected at least one optimizer with params"
return optimizers
wd = self.optimizer_fn.keywords.get('weight_decay', 0)
return self.optimizer_fn(_extract_model_params_for_optim(model, wd, self.no_weight_decay_cond))

def finalize_model_grads(self, *args, **kwargs):
# Noop
pass

def connect(self, model: L.LightningModule) -> None:
"""Connects the optimizer module to the model and trainer.
Args:
model (L.LightningModule): The model to which the optimizer module is being connected.
"""
model.configure_optimizers = lambda: self.optimizers(model)
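
For clarity, _extract_model_params_for_optim splits parameters into a weight-decay group and a zero-weight-decay group (biases by default, via _param_does_not_have_wd) before the partial is applied. A standalone sketch of that grouping logic in plain PyTorch, using a toy model for illustration:

from torch import nn
from torch.optim import SGD


def split_params(model: nn.Module, weight_decay: float, no_wd_cond):
    # Two parameter groups: regular params keep weight_decay, matched params get 0.
    with_wd, without_wd = [], []
    for name, param in model.named_parameters():
        (without_wd if no_wd_cond(name, param) else with_wd).append(param)
    return [
        {'params': with_wd, 'weight_decay': weight_decay},
        {'params': without_wd, 'weight_decay': 0.0},
    ]


model = nn.Linear(8, 2)
groups = split_params(model, weight_decay=1e-4, no_wd_cond=lambda name, p: 'bias' in name)
optimizer = SGD(groups, lr=1e-3)  # what optimizer_fn(param_groups) ends up producing
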
