schedulefree optimizers #30079
Conversation
FYI this will need huggingface/accelerate#2631, as we need to upstream accelerate's ability to call train/eval on a wrapped optimizer.
Some thoughts:
I'm just a little bit reserved for now since the authors themselves aren't providing any transformer benchmarks, nor have they compared their CNN baselines to superconvergence, which is the go-to standard for fast training of CNNs. Likewise https://parameterfree.com/2023/08/30/yet-another-icml-award-fiasco/ wasn't pleasant.
Should be very easy to test this on Phi-2 or TinyLlama when the implementation works?
Great work @winglian ! 🤩 I left one minor comment, wdyt?
src/transformers/trainer.py
Outdated
@@ -3117,6 +3145,9 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
        `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        if "ScheduleFree" in self.optimizer.__class__.__name__:
Maybe instead of checking the class name here we could inject an attribute `_hf_schedule_free_optim`, to make sure we can support that in the future for other schedule-free optimizers. What do you think?
that would be on the Trainer class, right?
So the place that makes the most sense to set that would be in `get_optimizer_cls_and_kwargs`, but that is a `@staticmethod`, so it has no access to the trainer object. We could do something along the lines of `setattr(self.optimizer, "_hf_schedule_free_optim", True)` after we instantiate the `optimizer_cls`, but we would still have to do some sort of class-name detection. Alternatively, we could pass another value in the return tuple specific to schedule-free optimizers (but that feels worse).
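For illustration, a rough sketch of that approach as it might sit in `Trainer.create_optimizer` (the `_hf_schedule_free_optim` attribute is the hypothetical marker discussed above, and the surrounding lines are simplified, not the actual PR diff):

```python
# Sketch: inside Trainer.create_optimizer, after the optimizer class has been resolved.
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

# Some class-name detection is still needed to decide when to set the marker.
if "ScheduleFree" in optimizer_cls.__name__:
    setattr(self.optimizer, "_hf_schedule_free_optim", True)

# Elsewhere (e.g. in training_step), the marker would replace the class-name check:
if getattr(self.optimizer, "_hf_schedule_free_optim", False):
    self.optimizer.train()
```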
ahh good point yeah, in that case this is probably already fine I would say, thanks for investigating @winglian !
Rather than have it as a stateful attribute, could we instead move this logic out to a module-level function e.g.:
def _is_schedule_free_optimizer(optimizer):
    return "ScheduleFree" in optimizer.__class__.__name__
?
This way:
- The check is a bit more explicit within the code logic
- We can easily adapt the checking in one place, rather than throughout the code, if we end up introducing e.g. an `_is_schedule_free` attribute or there are schedule-free optimizers with slightly different names
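For illustration, a minimal sketch of that helper and where it might be called from (the call sites are assumptions about the Trainer flow, not the merged diff):

```python
def _is_schedule_free_optimizer(optimizer) -> bool:
    # Keep schedule-free detection in a single place so it can later be swapped
    # for an attribute check (e.g. a hypothetical `_is_schedule_free` flag).
    return "ScheduleFree" in optimizer.__class__.__name__

# Illustrative call sites inside Trainer:
#   training_step:           if _is_schedule_free_optimizer(self.optimizer): self.optimizer.train()
#   before evaluate()/save:  if _is_schedule_free_optimizer(self.optimizer): self.optimizer.eval()
```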
This PR should maybe also add a few lines to the README about "how to use this".
We've merged the accelerate portion in, so if anyone is trying this out in a distributed fashion, you can do …
Is there any chance of this making it into the main branch? I and others have confirmed that the results are real. Thank you @winglian
Super useful addition of schedule-free optimizers @winglian! It would be great to document the usage along with a minimal example.
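For reference, a minimal usage sketch along those lines. The `optim` values `schedule_free_adamw`/`schedule_free_sgd` and the constant LR scheduler are my reading of this PR and its docs; double-check against the merged documentation:

```python
from transformers import Trainer, TrainingArguments

# `model` and `train_dataset` are assumed to be defined as usual.
args = TrainingArguments(
    output_dir="out",
    optim="schedule_free_adamw",   # or "schedule_free_sgd"
    lr_scheduler_type="constant",  # schedule-free optimizers take over the role of the LR schedule
    learning_rate=2e-4,
    per_device_train_batch_size=8,
)

trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()
```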
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Is there any remaining work I could contribute towards getting this PR merged? Cheers
Force-pushed from 87b9651 to 9e654bb
@pacman100 @muellerzr @younesbelkada Can we get a new review to get this merged? Since the last check, I rebased, added some fixes and docs.
Thanks! Overall LG2M, let's pin schedulefree as a `>=` however.
Can you also run the quality checks? Afterwards, at least from my end it looks good to merge.
@muellerzr ran the
Thanks a bunch! cc @LysandreJik for final review
Thanks a lot !
Thanks for adding!
Main comment is about the `getattr` logic in `get_optimizer_cls_and_kwargs`
src/transformers/trainer.py
Outdated
additional_optim_kwargs["warmup_steps"] = args.warmup_steps
additional_optim_kwargs.update(
    {
        "weight_lr_power": float(getattr(torch, optim_args.get("weight_lr_power", 2.0))),
This doesn't seem right:
- If we get `"weight_lr_power"` from `optim_args`, I'm presuming it's a float as a string, e.g. `"2.0"`? I don't think `torch.2.0` exists?
- If `optim_args` doesn't have `"weight_lr_power"`, then the second argument to `getattr` is a float, which isn't compatible
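To make the failure modes concrete, a small standalone illustration (not code from the PR; it presumes `optim_args` values arrive as strings):

```python
import torch

optim_args = {"weight_lr_power": "2.0"}  # values parsed from the optim_args string are presumably strings

# Value present as a string: torch has no attribute named "2.0".
# getattr(torch, optim_args.get("weight_lr_power", 2.0))   # raises AttributeError

# Key missing: the default 2.0 is a float, and getattr needs a string attribute name.
# getattr(torch, {}.get("weight_lr_power", 2.0))           # raises TypeError

# The suggested fix converts the value directly:
weight_lr_power = float(optim_args.get("weight_lr_power", 2.0))  # -> 2.0
```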
Suggested change:
-    "weight_lr_power": float(getattr(torch, optim_args.get("weight_lr_power", 2.0))),
+    "weight_lr_power": float(optim_args.get("weight_lr_power", 2.0)),
src/transformers/trainer.py
Outdated
additional_optim_kwargs.update(
    {
        "weight_lr_power": float(getattr(torch, optim_args.get("weight_lr_power", 2.0))),
        "r": float(getattr(torch, optim_args.get("r", 0.0))),
Same here
Suggested change:
-    "r": float(getattr(torch, optim_args.get("r", 0.0))),
+    "r": float(optim_args.get("r", 0.0)),
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Will get back to this soon. Not stale 😅
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@winglian please don't let it die
Force-pushed from 7f5516e to 2361182
@amyeroberts I addressed your comments. LMK what else is required to push this through!
src/transformers/trainer.py
Outdated
if _is_schedule_free_optimizer(self.optimizer):
    self.optimizer.train()
We shouldn't need to have optimizer-specific logic in the main logic loops: this makes our training logic hard to handle and will quickly become too cluttered if many optimizers have to have specific logic.
cc @muellerzr who's been working on refactoring a lot of similar logic to this. Ideally all optimizers would use the same API and would support calling `.train` or `.eval` on the class, if this is required, but this is a large piece of work. Is this customization as-is acceptable?
Would it be preferable to check if the optimizer has the eval/train methods and call them, instead of checking the class name?
cc @adefazio
Yes, the way I do it in my code is by checking for the eval/train methods, i.e. duck typing.
is this approach acceptable?
Suggested change:
-    if _is_schedule_free_optimizer(self.optimizer):
-        self.optimizer.train()
+    if hasattr(self.optimizer, 'train') and callable(getattr(self.optimizer, 'train')):
+        self.optimizer.train()
Yes, I think this would be preferable as it's more extensible to other optimizers, and then puts the responsibility for management on implementation / checking on the optimizer implementation, rather than us within trainer
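As an illustration of that duck-typed direction, a minimal sketch (the helper name and its placement are hypothetical, not the exact code that was merged):

```python
def _optimizer_to_mode(optimizer, mode: str) -> None:
    # Duck typing: only optimizers exposing train()/eval() (e.g. schedule-free
    # optimizers) receive the call; all other optimizers are left untouched.
    method = getattr(optimizer, mode, None)
    if callable(method):
        method()

# Illustrative Trainer call sites:
#   _optimizer_to_mode(self.optimizer, "train")  # at the start of training_step
#   _optimizer_to_mode(self.optimizer, "eval")   # before evaluation or checkpointing
```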
Force-pushed from 2361182 to 5964f5d
Thanks for iterating - looks good to me!
cc @muellerzr to confirm the conditional check on the `eval` and `train` methods is OK
@winglian It looks like we'll need a bit more iteration on the
Looks like we need to revisit huggingface/accelerate#2631, which I've done in huggingface/accelerate#3055 👍
Force-pushed from b2b7725 to 179232c
@amyeroberts Can we kick off the tests again? Does the build pick up the latest accelerate release automatically?
Thanks! Just a request for one more additional check for users who don't have the right minimum accelerate version. cc @LysandreJik for final 🤗
Co-authored-by: Aman Gupta Karmani <[email protected]>
Thanks for the PR @winglian!
LGTM
If I use this, can I ignore the lr that gets printed out by trainer? (It looks like the default linear decay.) Or do I need to change the trainer's lr scheduler to constant?
* schedulefree optimizers
* fix train instead of eval for optimizer
* fixes and update docs
* chore: lint
* add tests and drop overly-verbose _32bit suffix
* chore: lint
* fix for docs
* fix code review issues
* use duck-typing to avoid per-optimizer patches
* fixup style
* fixup style
* warn if incorrect accelerate version with schedule free

Co-authored-by: Aman Gupta Karmani <[email protected]>
Co-authored-by: Aman Karmani <[email protected]>
What does this PR do?
Integrates Meta's https://github.com/facebookresearch/schedule_free for AdamW & SGD.
https://twitter.com/aaron_defazio/status/1776320004465582331
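For background, the upstream library exposes optimizers with explicit train/eval modes, which is what motivates the corresponding calls in Trainer. A minimal standalone sketch using the schedulefree package directly (toy model and data, not part of this PR):

```python
# pip install schedulefree
import torch
import schedulefree

model = torch.nn.Linear(10, 1)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

optimizer.train()  # switch the optimizer to training mode before taking steps
for _ in range(10):
    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

optimizer.eval()  # switch back before evaluation or saving the model
```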
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@muellerzr @younesbelkada @pacman100