fix: missing equal, initialize empty optimizer, and lint
NanoCode012 committed Dec 9, 2024
1 parent d54ed30 commit b373818
Showing 1 changed file with 20 additions and 26 deletions.
46 changes: 20 additions & 26 deletions src/axolotl/core/trainer_builder.py
@@ -437,6 +437,7 @@ def __init__(
         self.bench_data_collator = bench_data_collator
         self.eval_data_collator = eval_data_collator
         self.dataset_tags = dataset_tags
+        self.optimizer = None
         super().__init__(*_args, **kwargs)
         self.train_data_collator = self.data_collator
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
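A note on the added line (my annotation, not part of the commit): declaring self.optimizer = None before calling super().__init__() guarantees the attribute exists by the time the overridden create_optimizer checks or assigns it, and it keeps pylint's attribute-defined-outside-init warning from firing, which is presumably why the disable comments further down could be dropped. A minimal, self-contained sketch of the pattern with hypothetical class names:

# Minimal sketch with hypothetical names; not the axolotl classes.
class BaseTrainer:
    def __init__(self):
        # Base-class setup may call back into the subclass's create_optimizer.
        self.create_optimizer()

    def create_optimizer(self):
        self.optimizer = "base optimizer"


class CustomTrainer(BaseTrainer):
    def __init__(self):
        self.optimizer = None  # declared before super().__init__(), as in the commit
        super().__init__()

    def create_optimizer(self):
        # Safe: self.optimizer is guaranteed to exist when this runs.
        if self.optimizer is None:
            self.optimizer = "custom optimizer"
        return self.optimizer


print(CustomTrainer().optimizer)  # -> custom optimizer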
@@ -457,7 +458,8 @@ def _wrap_model(self, model, training=True, dataloader=None):

     def create_optimizer(self):
         # For all other cases, use parent implementation
-        if (self.args.loraplus_lr_ratio is None
+        if (
+            self.args.loraplus_lr_ratio is None
             and self.args.embedding_lr_scale is None
             and self.args.embedding_lr is None
         ):
@@ -473,14 +475,12 @@ def create_optimizer(self):
         )

         if self.args.loraplus_lr_ratio is not None:
-            self.optimizer = (
-                create_loraplus_optimizer(
-                    opt_model,
-                    optimizer_cls,
-                    loraplus_lr_ratio=self.args.loraplus_lr_ratio,
-                    loraplus_lr_embedding=self.args.loraplus_lr_embedding,
-                    **optimizer_kwargs,
-                )
+            self.optimizer = create_loraplus_optimizer(
+                opt_model,
+                optimizer_cls,
+                loraplus_lr_ratio=self.args.loraplus_lr_ratio,
+                loraplus_lr_embedding=self.args.loraplus_lr_embedding,
+                **optimizer_kwargs,
             )
         elif (
             self.args.embedding_lr_scale is not None
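For readers unfamiliar with the call being reshaped in the hunk above: create_loraplus_optimizer builds a LoRA+-style optimizer in which some adapter parameters train at a learning rate scaled by loraplus_lr_ratio. The sketch below is only a conceptual illustration of that idea under my own assumptions (grouping by a "lora_B" name match); it is not axolotl's helper or the upstream LoRA+ implementation:

# Conceptual LoRA+-style parameter grouping; the "lora_B" name check and the
# grouping rule are illustrative assumptions, not the real helper's logic.
import torch


def build_loraplus_param_groups(model, base_lr, loraplus_lr_ratio):
    default_group, scaled_group = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (scaled_group if "lora_B" in name else default_group).append(param)
    return [
        {"params": default_group, "lr": base_lr},
        {"params": scaled_group, "lr": base_lr * loraplus_lr_ratio},
    ]


class TinyLoraLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.base = torch.nn.Linear(4, 4)
        self.lora_A = torch.nn.Linear(4, 2, bias=False)
        self.lora_B = torch.nn.Linear(2, 4, bias=False)


groups = build_loraplus_param_groups(TinyLoraLayer(), base_lr=2e-4, loraplus_lr_ratio=16)
optimizer = torch.optim.AdamW(groups)
print([g["lr"] for g in optimizer.param_groups])  # [0.0002, 0.0032]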
@@ -493,18 +493,16 @@ def create_optimizer(self):
                 embedding_lr_scale=self.args.embedding_lr_scale,
                 embedding_lr=self.args.embedding_lr,
                 decay_parameters=decay_parameters,
+                weight_decay=self.args.weight_decay,
                 optimizer_cls=optimizer_cls,
                 optimizer_kwargs=optimizer_kwargs,
             )

         if is_sagemaker_mp_enabled():
-            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
-                self.optimizer
-            )
+            self.optimizer = smp.DistributedOptimizer(self.optimizer)

         return self.optimizer
-

     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
         if self.args.sample_packing and not self.args.pretraining:
             if self.args.multipack_real_batches:
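For orientation on the embedding_lr_scale / embedding_lr branch touched in the hunk above, the general idea (sketched under my own assumptions, not the real helper or its signature) is to put embedding and lm_head parameters in their own group whose learning rate is either embedding_lr, or the base learning rate times embedding_lr_scale, while the remaining parameters keep the base rate and weight decay:

# Illustrative sketch only: the name-based grouping and the zero weight decay
# on the embedding group are assumptions, not axolotl's actual helper.
import torch


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(16, 8)
        self.proj = torch.nn.Linear(8, 2)


def build_embedding_lr_groups(model, base_lr, weight_decay,
                              embedding_lr=None, embedding_lr_scale=None):
    # Assumption: an explicit embedding_lr overrides embedding_lr_scale.
    embed_lr = embedding_lr if embedding_lr is not None else base_lr * embedding_lr_scale
    embed_params, other_params = [], []
    for name, param in model.named_parameters():
        (embed_params if "embed" in name or "lm_head" in name else other_params).append(param)
    return [
        {"params": other_params, "lr": base_lr, "weight_decay": weight_decay},
        {"params": embed_params, "lr": embed_lr, "weight_decay": 0.0},
    ]


groups = build_embedding_lr_groups(TinyModel(), base_lr=2e-4, weight_decay=0.01,
                                   embedding_lr_scale=0.5)
print([g["lr"] for g in torch.optim.AdamW(groups).param_groups])  # [0.0002, 0.0001]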
@@ -984,9 +982,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
     tag_names = ["axolotl", "dpo"]

     def __init__(self, *args, dataset_tags=None, **kwargs):
-        super().__init__(*args, **kwargs)
         self.dataset_tags = dataset_tags
         self.optimizer = None
+        super().__init__(*args, **kwargs)

     def create_optimizer(self):
         # For all other cases, use parent implementation
@@ -999,20 +997,16 @@ def create_optimizer(self):
                 self.args,
                 opt_model,
             )
-            self.optimizer = (  # pylint: disable=attribute-defined-outside-init
-                create_loraplus_optimizer(
-                    opt_model,
-                    optimizer_cls,
-                    loraplus_lr_ratio=self.args.loraplus_lr_ratio,
-                    loraplus_lr_embedding=self.args.loraplus_lr_embedding,
-                    **optimizer_kwargs,
-                )
+            self.optimizer = create_loraplus_optimizer(
+                opt_model,
+                optimizer_cls,
+                loraplus_lr_ratio=self.args.loraplus_lr_ratio,
+                loraplus_lr_embedding=self.args.loraplus_lr_embedding,
+                **optimizer_kwargs,
             )

         if is_sagemaker_mp_enabled():
-            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
-                self.optimizer
-            )
+            self.optimizer = smp.DistributedOptimizer(self.optimizer)

         return self.optimizer

@@ -1780,7 +1774,7 @@ def build(self, total_num_steps):

                 optimizer_cls = AdamWFp8
                 optimizer_kwargs.update(adam_kwargs)
-            elif self.cfg.optimizer = "adopt_adamw":
+            elif self.cfg.optimizer == "adopt_adamw":
                 from axolotl.utils.optimizers.adopt import ADOPT

                 optimizer_cls = ADOPT
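The final hunk is the "missing equal" from the commit title: a bare assignment cannot appear in an if/elif condition in Python, so the old line was a SyntaxError that would break the module at import time, and the equality comparison == is what was intended. A tiny stand-alone check (my illustration, not part of the commit; the variable names are placeholders):

# Demonstrates that the pre-fix line cannot even be parsed.
import ast

broken = 'if optimizer == "adamw":\n    pass\nelif optimizer = "adopt_adamw":\n    pass\n'
fixed = broken.replace('optimizer = "adopt_adamw"', 'optimizer == "adopt_adamw"')

for label, source in (("broken", broken), ("fixed", fixed)):
    try:
        ast.parse(source)
        print(label, "-> parses")
    except SyntaxError as err:
        print(label, "-> SyntaxError:", err.msg)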