Commit 9752704
fix
Signed-off-by: Alexandros Koumparoulis <[email protected]>
akoumpa committed Nov 5, 2024
1 parent 030076e commit 9752704
Showing 2 changed files with 6 additions and 7 deletions.
nemo/collections/llm/recipes/optim/adam.py (3 additions, 4 deletions)
@@ -20,7 +20,6 @@
 from nemo.lightning.pytorch.optim import (
     CosineAnnealingScheduler,
     MegatronOptimizerModule,
-    OptimizerModule,
     PytorchOptimizerModule,
 )
 
@@ -35,7 +34,7 @@ def distributed_fused_adam_with_cosine_annealing(
     max_lr: float = 1e-4,
     min_lr: Optional[float] = None,
     clip_grad: float = 1.0,
-) -> run.Config[OptimizerModule]:
+) -> run.Config[PytorchOptimizerModule]:
 
     opt_cfg = run.Config(
         OptimizerConfig,
@@ -72,7 +71,7 @@ def pytorch_adam_with_cosine_annealing(
     constant_steps: int = 0,
     max_lr: float = 1e-5,
     min_lr: Optional[float] = None,
-) -> run.Config[OptimizerModule]:
+) -> run.Config[PytorchOptimizerModule]:
     from torch.optim import Adam
 
     return run.Config(
@@ -96,7 +95,7 @@ def pytorch_adam_with_cosine_annealing(
 @run.cli.factory
 def pytorch_adam_with_flat_lr(
     lr: float = 1e-5,
-) -> run.Config[OptimizerModule]:
+) -> run.Config[PytorchOptimizerModule]:
     from torch.optim import Adam
 
     return run.Config(
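The annotation change above is not just cosmetic: the commit also drops OptimizerModule from the import list, and Python evaluates return annotations at function-definition time (unless the module opts into deferred annotations via "from __future__ import annotations"), so a leftover "-> run.Config[OptimizerModule]" would raise a NameError as soon as the module was imported. A minimal sketch of the corrected pattern, with a hypothetical body (the optimizer_fn/run.Partial wiring is an assumption for illustration, not the file's exact contents):

import nemo_run as run
from nemo.lightning.pytorch.optim import PytorchOptimizerModule


@run.cli.factory
def pytorch_adam_with_flat_lr(
    lr: float = 1e-5,
) -> run.Config[PytorchOptimizerModule]:  # annotation now matches the import
    # Lazy import: torch is only required when the factory actually runs.
    from torch.optim import Adam

    # Hypothetical wiring, for illustration only: hand the optimizer class
    # and its keyword arguments to the PytorchOptimizerModule config.
    return run.Config(
        PytorchOptimizerModule,
        optimizer_fn=run.Partial(Adam, lr=lr),
    )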
nemo/collections/llm/recipes/optim/sgd.py (3 additions, 3 deletions)
@@ -16,7 +16,7 @@
 
 import nemo_run as run
 
-from nemo.lightning.pytorch.optim import CosineAnnealingScheduler, OptimizerModule, PytorchOptimizerModule
+from nemo.lightning.pytorch.optim import CosineAnnealingScheduler, PytorchOptimizerModule
 
 
 @run.cli.factory
@@ -26,7 +26,7 @@ def pytorch_sgd_with_cosine_annealing(
     max_lr: float = 1e-5,
     min_lr: Optional[float] = None,
     wd: float = 1e-4,
-) -> run.Config[OptimizerModule]:
+) -> run.Config[PytorchOptimizerModule]:
     from torch.optim import SGD
 
     return run.Config(
@@ -49,7 +49,7 @@ def pytorch_sgd_with_cosine_annealing(
 def pytorch_sgd_with_flat_lr(
     lr: float = 1e-5,
     wd: float = 1e-4,
-) -> run.Config[OptimizerModule]:
+) -> run.Config[PytorchOptimizerModule]:
     from torch.optim import SGD
 
     return run.Config(
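For completeness, a hedged sketch of consuming one of these factories. Calling the factory directly should yield the annotated config object; how a recipe then uses it is an assumption for illustration, not taken from this commit:

import nemo_run as run
from nemo.collections.llm.recipes.optim.sgd import pytorch_sgd_with_flat_lr
from nemo.lightning.pytorch.optim import PytorchOptimizerModule

# The factory returns a run.Config[PytorchOptimizerModule], which a recipe
# would typically attach to its optimizer field before launching.
optim: run.Config[PytorchOptimizerModule] = pytorch_sgd_with_flat_lr(lr=3e-4, wd=1e-4)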
