
FEAT: Add LOMO optimizer #2695

Merged · 8 commits merged into main · May 3, 2024

Conversation

younesbelkada (Contributor)

WIP - on par with huggingface/transformers#30178

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

muellerzr (Collaborator) left a comment:

This is a very good start; I left some design comments that are more in tune with what I'd expect for Accelerate :)

src/accelerate/accelerator.py (outdated; resolved)
Comment on lines 1958 to 1959
if is_lomo_available():
    from lomo_optim import AdaLomo, Lomo
Collaborator:

This should be done at the top of the file

younesbelkada (Contributor, Author):

Doing that leads to an error due to circular imports 😢 this is because lomo_optim imports stuff from transformers, which itself imports stuff from accelerate. Added a comment there
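For illustration, a minimal sketch of the deferred-import pattern being described; the helper name _prepare_lomo_optimizer is hypothetical, not the actual accelerate code:

from accelerate.utils import is_lomo_available  # availability check referenced in this PR's diffs

def _prepare_lomo_optimizer(optimizer):
    # Importing lomo_optim lazily, inside the function, avoids the circular chain
    # lomo_optim -> transformers -> accelerate that a module-level import would trigger.
    if is_lomo_available():
        from lomo_optim import AdaLomo, Lomo

        if isinstance(optimizer, (Lomo, AdaLomo)):
            # LOMO-specific handling would go here.
            pass
    return optimizer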

Comment on lines 2044 to 2045
elif learning_rate is not None and self._has_lomo_optimizer:
    self._lomo_backward(loss, learning_rate)
Collaborator:

Does this need to interact with self.scaler for AMP? That may be important!

younesbelkada (Contributor, Author):

Hmm, not sure here actually 🤯 I need to investigate a bit...
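For context only (this is not code from the PR): in the standard PyTorch AMP path, the GradScaler scales the loss before backward and unscales gradients before the optimizer step, which is what the question above is getting at. A generic sketch of that flow:

import torch

scaler = torch.cuda.amp.GradScaler()

def amp_training_step(model, optimizer, inputs, targets, loss_fn):
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = loss_fn(model(inputs), targets)
    scaler.scale(loss).backward()  # backward runs on the scaled loss
    scaler.step(optimizer)         # unscales gradients, skips the step on inf/NaN
    scaler.update()                # adjusts the scale factor for the next iteration
    optimizer.zero_grad()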

Comment on lines +137 to +138
if is_lomo_available():
    from lomo_optim import AdaLomo, Lomo
Collaborator:

Same thing as earlier, but let's move the detection for Lomo to happen in here rather than in the Accelerator. Then we can make an attribute self.is_lomo_optimizer
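A rough sketch of that suggestion, assuming the detection moves into the optimizer wrapper's __init__; the class body here is heavily simplified and only illustrates where the is_lomo_optimizer attribute would be set:

from accelerate.utils import is_lomo_available

class AcceleratedOptimizer:
    def __init__(self, optimizer, **kwargs):
        self.optimizer = optimizer
        # Detect LOMO here, in the wrapper, rather than in the Accelerator.
        self.is_lomo_optimizer = False
        if is_lomo_available():
            # Deferred import to avoid the circular-import issue discussed above.
            from lomo_optim import AdaLomo, Lomo

            self.is_lomo_optimizer = isinstance(optimizer, (Lomo, AdaLomo))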


@@ -2031,6 +2041,8 @@ def backward(self, loss, **kwargs):
            return
        elif self.scaler is not None:
            self.scaler.scale(loss).backward(**kwargs)
        elif learning_rate is not None and self._has_lomo_optimizer:
Collaborator:

Eventually, if we have more optimizers needing this, we may want to abstract this to a backward_func at some point

younesbelkada (Contributor, Author):

Yes, makes sense!
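A rough sketch of what such an abstraction might look like later on. The backward_func name comes from the comment above; everything else in this Accelerator-style method fragment (self.scaler, self._has_lomo_optimizer, self.lomo_backward) mirrors the names in this PR but is otherwise illustrative:

def backward(self, loss, **kwargs):
    # Illustrative only: pick one backward implementation up front instead of
    # extending the elif chain each time a new special-case optimizer appears.
    learning_rate = kwargs.pop("learning_rate", None)

    if self.scaler is not None:
        backward_func = lambda l: self.scaler.scale(l).backward(**kwargs)
    elif learning_rate is not None and self._has_lomo_optimizer:
        backward_func = lambda l: self.lomo_backward(l, learning_rate)
    else:
        backward_func = lambda l: l.backward(**kwargs)

    backward_func(loss)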

@@ -2031,6 +2041,8 @@ def backward(self, loss, **kwargs):
            return
        elif self.scaler is not None:
            self.scaler.scale(loss).backward(**kwargs)
        elif learning_rate is not None and self._has_lomo_optimizer:
            self._lomo_backward(loss, learning_rate)
Collaborator:

Let's make this lomo_backwards instead of hiding it.

Also, I'd rather see us raise an error if a learning rate isn't part of kwargs, and this is True, rather than silently failing.

younesbelkada (Contributor, Author):

Makes total sense, refactored things a bit in 3a518dc - LMK what you think!
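A rough sketch of the behaviour being converged on here: a public lomo_backward that raises instead of failing silently. It assumes lomo_optim exposes a fused_backward(loss, lr) entry point and that the Accelerator tracks its optimizers in self._optimizers; the actual refactor in 3a518dc may differ in the details:

def lomo_backward(self, loss, learning_rate):
    # Fail loudly rather than silently when the prerequisites are missing.
    if not is_lomo_available():
        raise ValueError("The `lomo_optim` package is required to call backward on LOMO optimizers.")
    if learning_rate is None:
        raise ValueError("A learning rate must be passed to `backward` when a LOMO optimizer is used.")

    from lomo_optim import AdaLomo, Lomo

    for optimizer in self._optimizers:
        if isinstance(optimizer.optimizer, (Lomo, AdaLomo)):
            # LOMO fuses the parameter update into the backward pass.
            optimizer.optimizer.fused_backward(loss, learning_rate)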

Comment on lines 3299 to 3302
if is_lomo_available():
    from lomo_optim import AdaLomo, Lomo
else:
    raise ValueError("`lomo_optim` package is needed to call backward on LOMO optimizers")
Collaborator:

This should already be caught in the prepare_optimizer portion, so it should be fine.

younesbelkada (Contributor, Author):

Makes sense, removed it!

younesbelkada marked this pull request as ready for review May 3, 2024 08:27
muellerzr (Collaborator) left a comment:

Nice! This is looking fantastic! The last bit I'd like is a quick example in examples/ using LOMO, but for the sake of the release today we can get this in.

younesbelkada (Contributor, Author):

Nice, thanks so much! OK, I will work on an example that uses accelerate + LOMO and open a follow-up PR!
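For reference, a rough sketch of what such a follow-up example could look like. The model, data, and hyperparameters are placeholders; it assumes lomo_optim's Lomo takes the model as its first argument and that accelerator.backward accepts the learning_rate keyword added in this PR:

import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from lomo_optim import Lomo

accelerator = Accelerator()

model = torch.nn.Linear(128, 2)  # placeholder model
dataset = TensorDataset(torch.randn(64, 128), torch.randint(0, 2, (64,)))
dataloader = DataLoader(dataset, batch_size=8)

lr = 1e-3
optimizer = Lomo(model, lr=lr)  # LOMO wraps the model so it can fuse gradient + update

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for inputs, labels in dataloader:
    outputs = model(inputs)
    loss = torch.nn.functional.cross_entropy(outputs, labels)
    # With a LOMO optimizer the update is fused into the backward pass,
    # so the learning rate is passed to backward instead of a separate step().
    accelerator.backward(loss, learning_rate=lr)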

younesbelkada merged commit 6ac27e2 into main on May 3, 2024 (25 checks passed).
younesbelkada deleted the add-lomo branch May 3, 2024 08:55
# transformers & accelerate
from lomo_optim import AdaLomo, Lomo

self.has_lomo_optimizer = isinstance(optimizer, (Lomo, AdaLomo))
Member:

For my understanding, we assume that the accelerator could have multiple self._optimizers, with some of them LOMO and others not. Wouldn't that create the issue that self.has_lomo_optimizer takes its value from whatever the last optimizer is? Wouldn't we have to set self.has_lomo_optimizer |= isinstance(optimizer, (Lomo, AdaLomo)) instead?
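A small sketch of the concern, written as a hypothetical stand-alone helper: plain assignment reflects only the last optimizer seen, while |= accumulates across all of them:

from lomo_optim import AdaLomo, Lomo

def has_any_lomo_optimizer(optimizers):
    # `|=` keeps the flag True once any LOMO optimizer has been seen;
    # a plain `=` inside the loop would only reflect the last optimizer.
    found = False
    for optimizer in optimizers:
        found |= isinstance(optimizer, (Lomo, AdaLomo))
    return found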
