diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index e511b6df86..4e55691857 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -437,6 +437,7 @@ def __init__(
         self.bench_data_collator = bench_data_collator
         self.eval_data_collator = eval_data_collator
         self.dataset_tags = dataset_tags
+        self.optimizer = None
         super().__init__(*_args, **kwargs)
         self.train_data_collator = self.data_collator
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
@@ -457,7 +458,8 @@ def _wrap_model(self, model, training=True, dataloader=None):
 
     def create_optimizer(self):
         # For all other cases, use parent implementation
-        if (self.args.loraplus_lr_ratio is None
+        if (
+            self.args.loraplus_lr_ratio is None
             and self.args.embedding_lr_scale is None
             and self.args.embedding_lr is None
         ):
@@ -473,14 +475,12 @@ def create_optimizer(self):
             )
 
             if self.args.loraplus_lr_ratio is not None:
-                self.optimizer = (
-                    create_loraplus_optimizer(
-                        opt_model,
-                        optimizer_cls,
-                        loraplus_lr_ratio=self.args.loraplus_lr_ratio,
-                        loraplus_lr_embedding=self.args.loraplus_lr_embedding,
-                        **optimizer_kwargs,
-                    )
+                self.optimizer = create_loraplus_optimizer(
+                    opt_model,
+                    optimizer_cls,
+                    loraplus_lr_ratio=self.args.loraplus_lr_ratio,
+                    loraplus_lr_embedding=self.args.loraplus_lr_embedding,
+                    **optimizer_kwargs,
                 )
             elif (
                 self.args.embedding_lr_scale is not None
@@ -493,18 +493,16 @@ def create_optimizer(self):
                     embedding_lr_scale=self.args.embedding_lr_scale,
                     embedding_lr=self.args.embedding_lr,
                     decay_parameters=decay_parameters,
+                    weight_decay=self.args.weight_decay,
                     optimizer_cls=optimizer_cls,
                     optimizer_kwargs=optimizer_kwargs,
                 )
 
         if is_sagemaker_mp_enabled():
-            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
-                self.optimizer
-            )
+            self.optimizer = smp.DistributedOptimizer(self.optimizer)
 
         return self.optimizer
 
-
     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
         if self.args.sample_packing and not self.args.pretraining:
             if self.args.multipack_real_batches:
@@ -984,9 +982,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
     tag_names = ["axolotl", "dpo"]
 
     def __init__(self, *args, dataset_tags=None, **kwargs):
-        super().__init__(*args, **kwargs)
         self.dataset_tags = dataset_tags
         self.optimizer = None
+        super().__init__(*args, **kwargs)
 
     def create_optimizer(self):
         # For all other cases, use parent implementation
@@ -999,20 +997,16 @@ def create_optimizer(self):
             self.args,
             opt_model,
         )
-        self.optimizer = (  # pylint: disable=attribute-defined-outside-init
-            create_loraplus_optimizer(
-                opt_model,
-                optimizer_cls,
-                loraplus_lr_ratio=self.args.loraplus_lr_ratio,
-                loraplus_lr_embedding=self.args.loraplus_lr_embedding,
-                **optimizer_kwargs,
-            )
+        self.optimizer = create_loraplus_optimizer(
+            opt_model,
+            optimizer_cls,
+            loraplus_lr_ratio=self.args.loraplus_lr_ratio,
+            loraplus_lr_embedding=self.args.loraplus_lr_embedding,
+            **optimizer_kwargs,
         )
 
         if is_sagemaker_mp_enabled():
-            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
-                self.optimizer
-            )
+            self.optimizer = smp.DistributedOptimizer(self.optimizer)
 
         return self.optimizer
 
@@ -1780,7 +1774,7 @@ def build(self, total_num_steps):
 
             optimizer_cls = AdamWFp8
             optimizer_kwargs.update(adam_kwargs)
-        elif self.cfg.optimizer = "adopt_adamw":
+        elif self.cfg.optimizer == "adopt_adamw":
            from axolotl.utils.optimizers.adopt import ADOPT
 
            optimizer_cls = ADOPT
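
Below is a minimal, hedged sketch (not Axolotl code; the class names are invented stand-ins) of the pattern the `__init__` hunks above apply: declaring `self.optimizer = None` before delegating to `super().__init__()` means the attribute exists from the start, the parent constructor can still overwrite it, and the later assignments in `create_optimizer()` no longer need `# pylint: disable=attribute-defined-outside-init`.

```python
# Illustrative sketch only; BaseTrainer stands in for the upstream Trainer
# classes whose __init__ may itself read or set self.optimizer.
class BaseTrainer:
    def __init__(self, optimizers=(None, None), **kwargs):
        # The parent constructor may assign self.optimizer, e.g. from an
        # (optimizer, scheduler) tuple supplied by the caller.
        self.optimizer = optimizers[0]


class CustomTrainer(BaseTrainer):
    def __init__(self, *args, dataset_tags=None, **kwargs):
        self.dataset_tags = dataset_tags
        # Define the attribute before super().__init__() so it always exists
        # and a value set by the parent constructor is never clobbered back
        # to None afterwards.
        self.optimizer = None
        super().__init__(*args, **kwargs)

    def create_optimizer(self):
        if self.optimizer is None:
            self.optimizer = object()  # stand-in for a real torch optimizer
        return self.optimizer


trainer = CustomTrainer(dataset_tags=["demo"])
assert trainer.create_optimizer() is trainer.optimizer
```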