
Enabling gradient checkpointing in eval() mode #9878

Merged: 7 commits into huggingface:main, Nov 8, 2024

Conversation

MikeTkachuk (Contributor) opened this pull request:

Removed the unnecessary `if self.training ...` check when using gradient checkpointing. (#9850)

@MikeTkachuk (author):

Since all of the module implementations used the same `if self.training and self.gradient_checkpointing:` clause, I wonder if there is a reference template for `def forward` that others are using? We need to make sure future implementations use the fixed clause.

```diff
@@ -452,7 +452,7 @@ def forward(

     # 3. Transformer blocks
     for i, block in enumerate(self.transformer_blocks):
-        if self.training and self.gradient_checkpointing:
+        if self.gradient_checkpointing:
```
Collaborator:

Oh thanks! But why did we also remove the `torch.is_grad_enabled()` check? Gradient checkpointing isn't meaningful without gradients being computed, no?

@MikeTkachuk (author):

Added it back, thanks for pointing it out.
It does not break anything, but I found that it throws an annoying warning when `use_reentrant=True`.

Collaborator:

> but found that it throws an annoying warning when use_reentrant=True

What do you mean by that?

@MikeTkachuk (author):

`use_reentrant` is an argument passed to `torch.utils.checkpoint.checkpoint`.

If it is `True`, one of the internal checks will print this to stderr:

```python
warnings.warn(
    "None of the inputs have requires_grad=True. Gradients will be None"
)
```

But diffusers uses `use_reentrant=False` anyway.

Collaborator:

Oh got it, thanks. So the warning is specific to using gradient checkpointing when gradients are not enabled.

@yiyixuxu (Collaborator) commented Nov 7, 2024:

Hi @MikeTkachuk, unfortunately we have to update the branch and resolve conflicts now...
Would you be able to do that? I'm fine with the change otherwise and can merge once it is synced with main.

Merge commit: …le_grckpt_in_eval

# Conflicts:
#	src/diffusers/models/controlnet_flux.py
#	src/diffusers/models/controlnet_sd3.py
@MikeTkachuk (author):

Done; I also fixed it in a few other places I had missed.


@yiyixuxu (Collaborator) commented Nov 8, 2024:

Can you rebase again?
I saw a bunch of changes that were not made in this PR, and the commit history includes commits that have already been merged into main (from another PR).

@MikeTkachuk (author):

Fixed.

@yiyixuxu (Collaborator) left a review:

thanks!

yiyixuxu merged commit 5b972fb into huggingface:main on Nov 8, 2024. 15 checks passed.
sayakpaul pushed a commit that referenced this pull request on Dec 23, 2024.