NeMo-UX: Add sgd optim #11157

Merged · 12 commits · Nov 5, 2024
Changes from all commits
2 changes: 1 addition & 1 deletion examples/llm/peft/hf.py
@@ -96,7 +96,7 @@ def formatting_prompts_func(examples):
             use_distributed_sampler=use_dist_samp,
             logger=wandb,
         ),
-        optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(max_lr=1e-5, clip_grad=0.5)),
+        optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
         log=None,
         peft=llm.peft.LoRA(
             target_modules=['*_proj'],
2 changes: 1 addition & 1 deletion examples/llm/sft/hf.py
@@ -90,7 +90,7 @@ def squad(tokenizer) -> pl.LightningDataModule:
             use_distributed_sampler=use_dist_samp,
             logger=wandb,
         ),
-        optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(max_lr=1e-5, clip_grad=0.5)),
+        optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
         log=None,
     )
34 changes: 10 additions & 24 deletions nemo/collections/llm/recipes/optim/adam.py
@@ -17,12 +17,7 @@
 import nemo_run as run
 from megatron.core.optimizer import OptimizerConfig

-from nemo.lightning.pytorch.optim import (
-    CosineAnnealingScheduler,
-    MegatronOptimizerModule,
-    OptimizerModule,
-    PytorchOptimizerModule,
-)
+from nemo.lightning.pytorch.optim import CosineAnnealingScheduler, MegatronOptimizerModule, PytorchOptimizerModule


 @run.cli.factory
@@ -35,7 +30,7 @@ def distributed_fused_adam_with_cosine_annealing(
     max_lr: float = 1e-4,
     min_lr: Optional[float] = None,
     clip_grad: float = 1.0,
-) -> run.Config[OptimizerModule]:
+) -> run.Config[PytorchOptimizerModule]:

     opt_cfg = run.Config(
         OptimizerConfig,
@@ -68,20 +63,17 @@ def distributed_fused_adam_with_cosine_annealing(

 @run.cli.factory
 def pytorch_adam_with_cosine_annealing(
-    precision: str = "bf16-mixed",  # or "16-mixed"
     warmup_steps: int = 2000,
     constant_steps: int = 0,
     max_lr: float = 1e-5,
     min_lr: Optional[float] = None,
-    clip_grad: float = 1.0,
-) -> run.Config[OptimizerModule]:
+) -> run.Config[PytorchOptimizerModule]:
     from torch.optim import Adam

     return run.Config(
         PytorchOptimizerModule,
-        optim_cls=Adam,
-        config=run.Config(
-            dict,
+        optimizer_fn=run.Partial(
+            Adam,
             lr=max_lr,
             weight_decay=0.1,
             betas=(0.9, 0.95),
@@ -98,21 +90,15 @@ def pytorch_adam_with_cosine_annealing(

 @run.cli.factory
 def pytorch_adam_with_flat_lr(
-    precision: str = "bf16-mixed",  # or "16-mixed"
-    warmup_steps: int = 2000,
-    constant_steps: int = 0,
-    max_lr: float = 1e-5,
-    min_lr: Optional[float] = None,
-    clip_grad: float = 1.0,
-) -> run.Config[OptimizerModule]:
+    lr: float = 1e-5,
+) -> run.Config[PytorchOptimizerModule]:
     from torch.optim import Adam

     return run.Config(
         PytorchOptimizerModule,
-        optim_cls=Adam,
-        config=run.Config(
-            dict,
-            lr=max_lr,
+        optimizer_fn=run.Partial(
+            Adam,
+            lr=lr,
             weight_decay=0.1,
             betas=(0.9, 0.95),
             eps=1e-8,
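Note on the adam.py refactor above: the PyTorch-backed factories now hand PytorchOptimizerModule a run.Partial of torch.optim.Adam via optimizer_fn instead of an optim_cls plus a config dict, and pytorch_adam_with_flat_lr is reduced to a single lr argument. A minimal usage sketch, mirroring the updated call sites in the examples/llm scripts above (fdl is fiddle and llm is nemo.collections.llm, as in those scripts):

    import fiddle as fdl
    from nemo.collections import llm

    # Builds a PytorchOptimizerModule whose optimizer_fn is a partial of torch.optim.Adam
    # with lr=1e-5; weight decay, betas, and eps come from the factory defaults.
    optim = fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5))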
62 changes: 62 additions & 0 deletions nemo/collections/llm/recipes/optim/sgd.py
@@ -0,0 +1,62 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import nemo_run as run

from nemo.lightning.pytorch.optim import CosineAnnealingScheduler, PytorchOptimizerModule


@run.cli.factory
def pytorch_sgd_with_cosine_annealing(
warmup_steps: int = 2000,
constant_steps: int = 0,
max_lr: float = 1e-5,
min_lr: Optional[float] = None,
wd: float = 1e-4,
) -> run.Config[PytorchOptimizerModule]:
from torch.optim import SGD

return run.Config(
PytorchOptimizerModule,
optimizer_fn=run.Partial(
SGD,
lr=max_lr,
weight_decay=wd,
),
lr_scheduler=run.Config(
CosineAnnealingScheduler,
warmup_steps=warmup_steps,
constant_steps=constant_steps,
min_lr=min_lr or (0.1 * max_lr),
),
)


@run.cli.factory
def pytorch_sgd_with_flat_lr(
lr: float = 1e-5,
wd: float = 1e-4,
) -> run.Config[PytorchOptimizerModule]:
from torch.optim import SGD

return run.Config(
PytorchOptimizerModule,
optimizer_fn=run.Partial(
SGD,
lr=lr,
weight_decay=wd,
),
)
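The new SGD factories follow the same optimizer_fn pattern. A minimal sketch of how one could be consumed, assuming the same fiddle-based build flow as the Adam examples above; the finetune call is hypothetical and only shows where the built module would plug in:

    import fiddle as fdl
    from nemo.collections.llm.recipes.optim.sgd import pytorch_sgd_with_flat_lr

    # Build a PytorchOptimizerModule that wraps torch.optim.SGD with a flat learning rate.
    optim = fdl.build(pytorch_sgd_with_flat_lr(lr=1e-3, wd=1e-4))

    # Hypothetical call site, analogous to examples/llm/sft/hf.py:
    # llm.api.finetune(..., optim=optim, log=None)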
84 changes: 41 additions & 43 deletions nemo/lightning/pytorch/optim/pytorch.py
@@ -15,7 +15,9 @@
 from typing import Callable, List, Optional

 import pytorch_lightning as pl
+import pytorch_lightning as L
 from torch.optim import Optimizer
+from torch.optim.optimizer import ParamsT

 from nemo.lightning.megatron_parallel import MegatronParallel
 from nemo.lightning.pytorch.optim.base import LRSchedulerModule, OptimizerModule
@@ -25,20 +27,43 @@ def _param_does_not_have_wd(param_name, param):
     return 'bias' in param_name


+def _extract_model_params_for_optim(model, weight_decay=0, no_weight_decay_cond=None):
+    params_with_wd, params_without_wd = [], []
+    if no_weight_decay_cond is not None:
+        for name, param in model.named_parameters():
+            if no_weight_decay_cond(name, param):
+                params_without_wd.append(param)
+            else:
+                params_with_wd.append(param)
+    else:
+        params_with_wd = model.parameters()
+
+    assert max(map(len, (params_with_wd, params_without_wd))) > 0, "Expected at least one optimizer with params"
+
+    return [
+        {'params': params, 'weight_decay': wd}
+        for params, wd in zip((params_with_wd, params_without_wd), (weight_decay, 0))
+    ]
+
+
 class PytorchOptimizerModule(OptimizerModule):
     """A OptimizerModule for pytorch optimizers.

     Attributes:
-        config (OptimizerConfig): Configuration for the optimizer.
+        optimizer_fn (Callable[[ParamsT], Optimizer]): Configuration for the optimizer.
         no_weight_decay_cond (Optional[Callable]): Condition for no weight decay.
         scale_lr_cond (Optional[Callable]): Condition for scaling learning rate.
         lr_mult (float): Learning rate multiplier.

     Example::

-        config = OptimizerConfig(...)
+        optimizer_fn = run.Partial(
+            SGD,
+            lr=lr,
+            weight_decay=wd,
+        )
         lr_scheduler = MyLRSchedulerModule(...)
-        optimizer_module = PytorchOptimizerModule(config, lr_scheduler)
+        optimizer_module = PytorchOptimizerModule(optimizer_fn, lr_scheduler)

     Methods:
         setup(model): Sets up the optimizer.
@@ -47,8 +72,7 @@ class PytorchOptimizerModule(OptimizerModule):

     def __init__(
         self,
-        optim_cls,
-        config: dict = {'lr': 3e-4},
+        optimizer_fn: Callable[[ParamsT], Optimizer],
         lr_scheduler: Optional[LRSchedulerModule] = None,
         no_weight_decay_cond: Optional[Callable] = _param_does_not_have_wd,
         scale_lr_cond: Optional[Callable] = None,
@@ -57,20 +81,18 @@ def __init__(
         """Initializes the PytorchOptimizerModule.

         Args:
-            config (OptimizerConfig): Configuration for the optimizer.
+            optimizer_fn (Callable[[ParamsT], Optimizer]): Configuration for the optimizer.
             lr_scheduler (Optional[LRSchedulerModule]): The learning rate scheduler module.
             no_weight_decay_cond (Optional[Callable]): Condition for no weight decay.
             scale_lr_cond (Optional[Callable]): Condition for scaling learning rate.
             lr_mult (float): Learning rate multiplier.
         """

         super().__init__(lr_scheduler=lr_scheduler)
-        self.optim_cls = optim_cls
-        self.config = config
+        self.optimizer_fn = optimizer_fn
         self.no_weight_decay_cond = no_weight_decay_cond
         self.scale_lr_cond = scale_lr_cond
         self.lr_mult = lr_mult
-        self.optim_cls = optim_cls

     def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
         # Noop
@@ -92,41 +114,17 @@ def optimizers(self, model) -> List[Optimizer]:
         if isinstance(model, MegatronParallel):
             raise ValueError("Model cannot be an instance of MegatronParallel")

-        params_with_wd, params_without_wd = [], []
-        if self.no_weight_decay_cond is not None:
-            for name, param in model.named_parameters():
-                if self.no_weight_decay_cond(name, param):
-                    params_without_wd.append(param)
-                else:
-                    params_with_wd.append(param)
-        else:
-            params_with_wd = model.parameters()
-
-        optimizers = []
-        if len(params_with_wd) > 0:
-            optimizers.append(
-                self.optim_cls(
-                    params_with_wd,
-                    **self.config,
-                )
-            )
-
-        if len(params_without_wd) > 0:
-            wd = self.config.get('weight_decay', None)
-            kwargs['weight_decay'] = 0
-            optimizers.append(
-                self.optim_cls(
-                    params_without_wd,
-                    **kwargs,
-                )
-            )
-            # restore value
-            if wd is not None:
-                kwargs['weight_decay'] = wd
-
-        assert len(optimizers) > 0, "Expected at least one optimizer with params"
-        return optimizers
+        wd = self.optimizer_fn.keywords.get('weight_decay', 0)
+        return self.optimizer_fn(_extract_model_params_for_optim(model, wd, self.no_weight_decay_cond))

     def finalize_model_grads(self, *args, **kwargs):
         # Noop
         pass
+
+    def connect(self, model: L.LightningModule) -> None:
+        """Connects the optimizer module to the model and trainer.
+
+        Args:
+            model (L.LightningModule): The model to which the optimizer module is being connected.
+        """
+        model.configure_optimizers = lambda: self.optimizers(model)
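For reference, the parameter grouping that the new _extract_model_params_for_optim helper builds can be reproduced with plain PyTorch. The sketch below is illustrative only (toy nn.Linear model, hypothetical weight decay of 0.1): with the default _param_does_not_have_wd condition, weights keep the configured weight decay while bias parameters get zero.

    import torch
    from torch import nn

    model = nn.Linear(4, 2)
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        # Mirrors _param_does_not_have_wd: parameters whose name contains 'bias'
        # are excluded from weight decay.
        (no_decay if 'bias' in name else decay).append(param)

    param_groups = [
        {'params': decay, 'weight_decay': 0.1},
        {'params': no_decay, 'weight_decay': 0},
    ]
    optimizer = torch.optim.SGD(param_groups, lr=1e-3)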