
[core / modeling] Fix training bug with PEFT + GC #28031

Merged 1 commit into huggingface:main from fix-training-bug on Dec 14, 2023

Conversation

@younesbelkada (Contributor) commented Dec 14, 2023

What does this PR do?

Fixes #28023

4.36.0 introduced a bug for users who train with GC (gradient checkpointing): use_cache should be force-set to False here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1008 - yet it ends up force-set to True during the backward pass for some reason, only in the case where one uses PEFT + GC.

The fix is to force-set use_cache to False before computing past_key_value_length here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1042
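Concretely, the intent of the change is sketched below (a minimal sketch, not the verbatim diff; the logger.warning_once message is illustrative and follows the pattern already used elsewhere in modeling_llama.py):

        # Sketch: disable the cache under gradient checkpointing *before*
        # past_key_values_length is derived from past_key_values.
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
                )
                use_cache = False

        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)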

cc @amyeroberts

@younesbelkada changed the title from "[core / modeling] Fix training bug with new cache refactor" to "[core / modeling] Fix training bug with PEFT + GC" on Dec 14, 2023
@amyeroberts (Collaborator) left a comment:

Thanks for fixing!

@@ -578,6 +578,13 @@ def forward(
seq_length_with_past = seq_length
past_key_values_length = 0

if self.gradient_checkpointing and self.training:
Collaborator:

I don't see why moving the logic here should make a difference?

Contributor (Author):

It is very unusual, but apparently when using GC + training, the backward pass calls the forward pass again with a non-None past_key_value. I think use_cache is set to True by default on all configs; therefore, preventing the block

        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

from being entered seems to be the right fix. I will add tests directly in PEFT since this is PEFT-related.
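For context, a minimal sketch of the PEFT + GC training setup that exercises this path (the checkpoint id and LoRA settings are illustrative placeholders, not taken from this PR or its tests):

    # Illustrative PEFT + gradient-checkpointing repro; model id and LoRA
    # hyperparameters are placeholders, not taken from the PR.
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "meta-llama/Llama-2-7b-hf"  # any LlamaForCausalLM checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    model = get_peft_model(model, LoraConfig(r=8, target_modules=["q_proj", "v_proj"]))
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()  # so checkpointed inputs require grad
    model.train()

    inputs = tokenizer("hello world", return_tensors="pt")
    loss = model(**inputs, labels=inputs["input_ids"]).loss
    loss.backward()  # on 4.36.0 this re-ran forward with a non-None past_key_value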

Collaborator:

It would be nice to also understand why we have to do this (why the forward re-runs), but static cache might actually do the same.
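For reference, the re-run comes from activation checkpointing itself: torch.utils.checkpoint discards intermediate activations and re-executes the wrapped forward during backward to recompute them. A standalone toy illustration (not transformers code; assumes a recent PyTorch that accepts use_reentrant):

    # Toy illustration: checkpoint() re-runs the wrapped function during backward.
    import torch
    from torch.utils.checkpoint import checkpoint

    calls = 0

    def layer(x):
        global calls
        calls += 1
        return torch.sin(x) * x

    x = torch.randn(4, requires_grad=True)
    y = checkpoint(layer, x, use_reentrant=False).sum()
    print(calls)  # 1: forward ran once
    y.backward()
    print(calls)  # 2: forward ran again while computing gradients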

@younesbelkada merged commit 73de510 into huggingface:main on Dec 14, 2023
19 checks passed
@younesbelkada deleted the fix-training-bug branch on December 14, 2023 at 11:20
iantbutler01 pushed a commit to BismuthCloud/transformers that referenced this pull request Dec 16, 2023
amyeroberts pushed a commit that referenced this pull request Dec 18, 2023
staghado pushed a commit to staghado/transformers that referenced this pull request Jan 15, 2024
@DailyCasual commented:

I am having this issue with Kohya Dreammaker as well. Downgrading to 4.27.1 seems to have helped. How do I go about actually fixing this issue? I have all the prerequisites installed and my models work, but I have no idea how to use Python to fix this.

@amyeroberts (Collaborator) commented:

@DailyCasual Could you open a new issue detailing the error encountered and what you've experimented with, e.g. versions and their behaviour? This helps us better track which bugs are new and which have been resolved.

Development

Successfully merging this pull request may close these issues.

PEFT+gradient checkpointing causes attention mask shape mismatch during backward pass
4 participants