FEAT / Trainer: LOMO optimizer support #30178
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
cc @muellerzr, would love to have a first review! 🙏
Thanks! This is a very interesting optimizer, so good to have this. Left some comments; I think it'd be best to downstream the support so accelerate can do `.backward()` easily and then bring it up to here.
src/transformers/trainer.py
Outdated
if not is_lomo_available():
    raise ImportError(
        "You need to install `galore_torch` in order to use GaLore optimizers"
        " install it with `pip install git+https://github.com/jiaweizzhao/GaLore`"
This can be updated to use their PyPI package, `galore-torch`.
src/transformers/trainer.py
Outdated
@@ -2111,6 +2135,8 @@ def _inner_training_loop(
         self._globalstep_last_logged = self.state.global_step
         model.zero_grad()
         grad_norm: Optional[float] = None
+        # LOMO has a slightly different optimizer API, see: https://github.com/OpenLMLab/LOMO/issues/73#issuecomment-2049612639
+        _is_lomo_optimizer = "Lomo" in self.optimizer.optimizer.__class__.__name__
Can you check if we can peek at `self.optimizer.optimizer.__module__` maybe? (Or similar.) That would be a bit more robust of a check than relying on class names.
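For illustration, a minimal sketch of such a module-based check, assuming the package name is `lomo_optim` (the import used later in this PR); `_looks_like_lomo` is a hypothetical helper name, not code from this PR:

```python
def _looks_like_lomo(optimizer) -> bool:
    # Unwrap Accelerate's wrapper if present, then inspect the module the optimizer
    # class was defined in rather than matching on its class name.
    inner = getattr(optimizer, "optimizer", optimizer)
    return inner.__class__.__module__.split(".")[0] == "lomo_optim"
```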
src/transformers/trainer.py
Outdated
@@ -2261,8 +2287,9 @@ def _inner_training_loop(
                 else:
                     grad_norm = _grad_norm

-                # Optimizer step
-                self.optimizer.step()
+                if not _is_lomo_optimizer:
Let's keep the comment here, and explain why lomo doesn't need the step
src/transformers/trainer.py
Outdated
if not is_lomo_optimizer:
    self.accelerator.backward(loss)
else:
    self.optimizer.optimizer.fused_backward(loss, self._get_learning_rate())
This makes me believe that we want to do this at the Accelerate level, to be quite honest, simply because `accelerator.backward` handles grad-accum dividing and a bunch of gradient scaling. Let's discuss this offline for a bit.
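For illustration, a rough sketch of the two backward paths being discussed; this is a simplified dispatch, not the Trainer's actual code. `fused_backward(loss, lr)` is the LOMO call quoted above, and `accelerator.backward` is the standard path that applies gradient-accumulation dividing and loss scaling:

```python
def backward_step(accelerator, optimizer, loss, lr, is_lomo: bool):
    """Illustrative only: dispatch between the standard and the LOMO backward styles."""
    if is_lomo:
        # LOMO fuses the parameter update into the backward pass, so there is no
        # separate optimizer.step() afterwards; calling it directly would bypass the
        # scaling logic that accelerator.backward normally applies.
        optimizer.fused_backward(loss, lr)
    else:
        # Standard path: Accelerate divides the loss for gradient accumulation and
        # applies mixed-precision scaling before calling loss.backward().
        accelerator.backward(loss)
```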
Yeah, makes sense. Happy to upstream that into accelerate to make things cleaner! OK, will ping you offline.
Co-authored-by: Zach Mueller <[email protected]>
Thanks @muellerzr! I offloaded most of the logic in huggingface/accelerate#2695 - wdyt? 🙏
Nice! Very well done. cc @LysandreJik for a final review
src/transformers/trainer.py
Outdated
self._is_lomo_optimizer = is_lomo_available() and isinstance(
    _unwrap_optimizer(self.optimizer), (Lomo, AdaLomo)
)
We can certainly do this, or just do `optimizer.optimizer` (since we know it'll be wrapped by accelerate). This is a bit safer, so seems good to me :)
Thanks for the work adding this!
Some general comments about the structure. Similar to badam, the fact that parts of the code need to know which optimizers are being used indicates we might be drawing the wrong boundaries around our abstractions.
@@ -3225,7 +3282,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
-            self.accelerator.backward(loss)
+            self.accelerator.backward(loss, **kwargs)
What happens if we pass the learning rate through when lomo isn't being used?
It will break .. 😢 but we:
1- raise an error if users do not have the correct accelerate version when init-ing the Trainer with LOMO
2- pass `learning_rate` only if the optimizer is a LOMO optimizer
3- removed `kwargs` in `training_step`
So hopefully this should be safe enough 🙏
src/transformers/trainer.py
Outdated
_is_lomo = False

if is_lomo_available():
    from lomo_optim import AdaLomo, Lomo

    _is_lomo = isinstance(_unwrap_optimizer(self.optimizer), (Lomo, AdaLomo))

# For LOMO optimizers you need to explicitly use the learning rate
if _is_lomo:
    kwargs["learning_rate"] = self._get_learning_rate()
If we have the optimizer set, don't we also have `self._is_lomo_optimizer`?
Yes indeed! Changed that.
src/transformers/trainer.py
Outdated
if isinstance(optimizer, AcceleratedOptimizer):
    optimizer = optimizer.optimizer
return optimizer
Is it guaranteed to only ever be one level of wrapping?
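For illustration, a sketch of an unwrap helper that would tolerate nested wrapping, in case more than one level were ever possible (assumes Accelerate's `AcceleratedOptimizer`, which is what the PR unwraps; this is not the PR's final code):

```python
from accelerate.optimizer import AcceleratedOptimizer


def _unwrap_optimizer(optimizer):
    # Peel off Accelerate wrappers until we reach the underlying optimizer.
    while isinstance(optimizer, AcceleratedOptimizer):
        optimizer = optimizer.optimizer
    return optimizer
```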
src/transformers/trainer.py
Outdated
if is_lomo_available():
    from lomo_optim import AdaLomo, Lomo

    _is_lomo = isinstance(_unwrap_optimizer(self.optimizer), (Lomo, AdaLomo))
I'd move the common logic out from here and L1086 into something like `_is_lomo_optimizer`, which handles importing Lomo and AdaLomo and unwrapping the optimizer.
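For illustration, a sketch of such a shared helper (assuming `is_lomo_available` and `_unwrap_optimizer` from this PR are in scope; the helper name simply mirrors the suggestion above):

```python
def _is_lomo_optimizer(optimizer) -> bool:
    # Centralize the optional import, the unwrapping, and the isinstance check so
    # both call sites can reuse it.
    if not is_lomo_available():
        return False
    from lomo_optim import AdaLomo, Lomo

    return isinstance(_unwrap_optimizer(optimizer), (Lomo, AdaLomo))
```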
src/transformers/trainer.py
Outdated
    Return:
        `torch.Tensor`: The tensor with training loss on this batch.
    """
    model.train()
    inputs = self._prepare_inputs(inputs)
    _is_lomo = False
hmmmm..... needing to have this in the training_step is a good indication that the abstractions here are leaky. Once we have the optimizer created, we shouldn't really need to know what type of optimizer it is in the rest of the code.
Nice catch .. I think it was old code; it should be much cleaner now!
Looks great! Thanks for all the work adding this feature + tests.
I just have one comment about the `self._is_lomo_optimizer` flag. LMKWYT
src/transformers/trainer.py
Outdated
@@ -398,6 +410,7 @@ def __init__(
         self.hp_name = None
         self.deepspeed = None
         self.is_in_train = False
+        self._is_lomo_optimizer = False
Thinking about this more, could we do something like `self._optimizer_type` instead, which then maps to e.g. enum values? This way, if new optimizers are added which require special logic, we just need one instance attribute.
Makes sense. In that case, what about just using `args.optim`, since its values can be used directly as enums:
transformers/src/transformers/trainer.py, line 1174 in 15c74a2:
elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
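For illustration, a sketch of what an `args.optim`-based check could look like (`OptimizerNames.LOMO` and `OptimizerNames.ADALOMO` are assumed to be the enum members added by this PR; the helper itself is hypothetical):

```python
from transformers.training_args import OptimizerNames


def uses_lomo(optim) -> bool:
    # Checks the configured optimizer name instead of inspecting the optimizer
    # instance, so no unwrapping is needed.
    return optim in (OptimizerNames.LOMO, OptimizerNames.ADALOMO)
```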
WDYT of 9d547be? With that I also removed the `_unwrap_optimizer` logic, which makes some parts of the code cleaner!
Looks great - thanks for iterating on this!
_ = trainer.train()

for name, param in tiny_llama.named_parameters():
    self.assertFalse(torch.allclose(param, previous_params[name].to(param.device), rtol=1e-12, atol=1e-12))
This tolerance is super small; do we expect optimizers to make changes on this order?
It is OK to put it higher; I decided to put it low so that even small changes would be captured by the test (sometimes higher tolerances would fail even though the weights are properly updated, even with a high learning rate), so this is just to be on the safe side.
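For illustration, the assertion pattern under discussion as a standalone sketch (hypothetical helper; a tight `rtol`/`atol` makes `torch.allclose` strict, so even tiny parameter updates count as a change, while a looser tolerance would only flag larger updates):

```python
import torch


def assert_all_params_changed(model, previous_params, rtol=1e-12, atol=1e-12):
    # Compare every parameter against its snapshot taken before training.
    for name, param in model.named_parameters():
        before = previous_params[name].to(param.device)
        assert not torch.allclose(param, before, rtol=rtol, atol=atol), f"{name} was not updated"
```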
Very nice - thanks for all the work adding this, including tests and iterating on the solution!
Thanks so much for the extensive review and help @amyeroberts @muellerzr!
What does this PR do?
Fixes: #29649
As requested by the community, this PR integrates LOMO optimizers into the HF Trainer: https://github.com/OpenLMLab/LOMO
I am facing some issues with respect to AdaLOMO, which seems to have DeepSpeed as a hard requirement in the optimizer's init: https://github.com/OpenLMLab/LOMO/blob/85d8105c48cbd676dbf6915ee755461cd241da9b/lomo_optim/adalomo.py#L85 - so I am leaving this as a draft for now until I figure out this issue.
cc @amyeroberts