[core / modeling] Fix training bug with PEFT + GC #28031
Conversation
Thanks for fixing!
```diff
@@ -578,6 +578,13 @@ def forward(
         seq_length_with_past = seq_length
         past_key_values_length = 0

+        if self.gradient_checkpointing and self.training:
```
I don't see why moving the logic here should make a difference?
It is very unusual, but apparently when using GC + training, the backward pass calls the forward pass again, this time with a non-None `past_key_value`.

I think `use_cache` is set to `True` by default in all configs; therefore, preventing the block

```python
if use_cache:
    use_legacy_cache = not isinstance(past_key_values, Cache)
    if use_legacy_cache:
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    past_key_values_length = past_key_values.get_usable_length(seq_length)
```

from being called seems to be the right fix. I will add tests directly in PEFT since this is PEFT-related.
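For readers following along, here is a hedged sketch of the ordering described above (paraphrased from the discussion, not the exact committed diff; `self`, `logger`, `Cache`, and `DynamicCache` are the names already present in `modeling_llama.py`): force `use_cache = False` whenever gradient checkpointing is active during training, before the cache-conversion block is reached.

```python
# Sketch only: paraphrases the intended ordering, not the exact committed diff.
if self.gradient_checkpointing and self.training:
    if use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
        )
        use_cache = False

# With use_cache forced to False above, this block is skipped during GC + training,
# so past_key_values_length is not derived from the cache left over by the first forward.
if use_cache:
    use_legacy_cache = not isinstance(past_key_values, Cache)
    if use_legacy_cache:
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    past_key_values_length = past_key_values.get_usable_length(seq_length)
```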
It would be nice to also understand why we have to do this (why the forward re-runs), but the static cache might actually do the same.
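On the "why the forward re-runs" question, here is a minimal standalone sketch (plain PyTorch, not part of this PR) showing that `torch.utils.checkpoint` re-executes the wrapped function during the backward pass in order to recompute the discarded activations:

```python
# Minimal illustration, independent of transformers/PEFT.
import torch
from torch.utils.checkpoint import checkpoint

call_count = 0

def block(x):
    # Counts how many times this "forward" actually runs.
    global call_count
    call_count += 1
    return x * x

x = torch.randn(4, requires_grad=True)
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
print(call_count)  # 2: once in the forward pass, once recomputed during backward
```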
I am having this issue with Kohya Dreammaker as well. Downgrading to 4.27.1 seems to have helped. How do I go about actually fixing this issue? I have all the prerequisites installed and my models work, but I have no idea how to use Python to fix this.
@DailyCasual Could you open a new issue detailing the error encountered and what you've experimented with, e.g. versions and their behaviour? This helps us better track which bugs are new and which have been resolved.
What does this PR do?
Fixes #28023
Transformers 4.36.0 introduced a bug for users training with GC, which should force-set `use_cache` to `False` here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1008 - yet for some reason it ends up being `True` again during the backward pass, only when one uses PEFT + GC.

The fix is to force-set `use_cache` to `False` before computing `past_key_value_length` here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1042

cc @amyeroberts
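For reference, a minimal reproduction sketch of the PEFT + GC training setup described above (the checkpoint name, LoRA settings, and dummy input are illustrative assumptions, not taken from this PR or the linked issue):

```python
# Illustrative sketch only: checkpoint name, LoRA config and input are assumptions.
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # any Llama-style checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id)

model.gradient_checkpointing_enable()  # GC: the forward is re-executed during backward
model.enable_input_require_grads()     # commonly needed when combining GC with frozen base weights
model = get_peft_model(
    model, LoraConfig(r=8, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM")
)
model.train()

input_ids = torch.randint(0, model.config.vocab_size, (1, 16))
out = model(input_ids=input_ids, labels=input_ids)
out.loss.backward()  # on 4.36.0 this backward failed with PEFT + GC; this PR avoids it
```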