Enabling gradient checkpointing in eval() mode #9878
Conversation
Since all of the module implementations used the same
```diff
@@ -452,7 +452,7 @@ def forward(
         # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if self.gradient_checkpointing:
```
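As a rough illustration of the pattern being discussed (not the actual diffusers code; `SimpleBlockStack` and its fields are hypothetical), the checkpointing branch can be gated on the `gradient_checkpointing` flag together with `torch.is_grad_enabled()` rather than `self.training`, so it keeps working in `eval()` mode as long as gradients are being computed:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class SimpleBlockStack(nn.Module):
    """Hypothetical module illustrating the guarded checkpointing pattern."""

    def __init__(self, dim: int = 64, num_blocks: int = 4):
        super().__init__()
        self.transformer_blocks = nn.ModuleList(
            nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU())
            for _ in range(num_blocks)
        )
        self.gradient_checkpointing = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for block in self.transformer_blocks:
            # Gate on the flag plus torch.is_grad_enabled() instead of
            # self.training, so checkpointing still applies in eval() mode
            # whenever gradients are actually being computed.
            if self.gradient_checkpointing and torch.is_grad_enabled():
                hidden_states = checkpoint(block, hidden_states, use_reentrant=False)
            else:
                hidden_states = block(hidden_states)
        return hidden_states


# Usage: checkpointing stays active in eval() as long as autograd is enabled.
model = SimpleBlockStack().eval()
model.gradient_checkpointing = True
x = torch.randn(2, 64, requires_grad=True)
out = model(x)
out.sum().backward()
```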
oh thanks! why did we also remove the `torch.is_grad_enabled()` check? Gradient checkpointing isn't meaningful without gradients being computed, no?
added it back, thanks for pointing it out.
It doesn't break anything, but I found that it throws an annoying warning when `use_reentrant=True`.
> I found that it throws an annoying warning when `use_reentrant=True`

what do you mean by that?
`use_reentrant` is an argument passed to `torch.utils.checkpoint.checkpoint`. If it is `True`, one of the checks emits this warning:

```python
warnings.warn(
    "None of the inputs have requires_grad=True. Gradients will be None"
)
```

but diffusers uses `use_reentrant=False` anyway
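For context, a minimal standalone sketch of the warning path described above (the `Linear` layer and shapes are just placeholders, not anything from this PR) could look like this:

```python
import warnings

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)  # requires_grad=False, as in a typical eval/inference call

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # With use_reentrant=True and no tensor input requiring grad, torch warns:
    # "None of the inputs have requires_grad=True. Gradients will be None"
    _ = checkpoint(layer, x, use_reentrant=True)

print([str(w.message) for w in caught])

# With use_reentrant=False (what diffusers passes), that particular check
# is not hit, so the warning is not expected here.
_ = checkpoint(layer, x, use_reentrant=False)
```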
oh got it, thanks! So the warning is specific to using gradient checkpointing when gradients are not enabled.
hi @MikeTkachuk unfortunately we have to update the branch and resolve conflicts now...
…le_grckpt_in_eval
# Conflicts:
#	src/diffusers/models/controlnet_flux.py
#	src/diffusers/models/controlnet_sd3.py
done, also fixed in a few other places I missed
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
can you rebase again?
fixed
thanks!
Removed unnecessary `if self.training ...` check when using gradient checkpointing #9850