Add a static cache that offloads to the CPU or other device #32161
Conversation
Force-pushed from 87541c5 to 8b862ab
Force-pushed from 386f231 to 3d413cb
Already looks great! IMO it would be nice to add a snippet of how to use it in the doc of the class, and once #32150 is merged, also mark this as compatible with torch.compile (while the non-static version won't be).
Really like the findings you posted on the issue, they will be useful for everyone I think!
Also it would be interesting to test this / showcase the potential for huge beam search!
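A minimal sketch of the kind of usage snippet being requested, assuming the cache ends up exposed through `generate`'s `cache_implementation` argument (the model name here is just an example, not from the PR):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, device_map="cuda"
)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
# The offloaded static cache keeps the preallocated per-layer key/value
# buffers on the CPU and streams each layer's buffers to the GPU on demand.
out = model.generate(**inputs, max_new_tokens=20, cache_implementation="offloaded_static")
print(tokenizer.decode(out[0], skip_special_tokens=True))
```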
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Very very cool! 🔥
As Arthur wrote, this PR is now missing the complementary parts:
- An example in the docstring
- Some benchmarks in the PR, for future reference
- An integration test to prevent regressions (like the ones here); see the sketch below
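For the integration test item, a hedged sketch of the sort of regression check that could work (the test name, tiny model, and exact-match criterion are assumptions, not the PR's actual test):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def test_offloaded_static_cache_matches_static():
    model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
    model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    inputs = tokenizer("Hello there", return_tensors="pt")

    # Greedy decoding so the comparison is deterministic.
    ref = model.generate(**inputs, max_new_tokens=10, do_sample=False,
                         cache_implementation="static")
    out = model.generate(**inputs, max_new_tokens=10, do_sample=False,
                         cache_implementation="offloaded_static")
    # Offloading should change where the K/V tensors live, not what is generated.
    assert torch.equal(ref, out)
```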
src/transformers/cache_utils.py
Outdated

    # For backwards compatibility.
    self._seen_tokens = 0
Suggested change:

    # For backwards compatibility.
    self._seen_tokens = 0
this one is unused throughout the code
See comment below.
src/transformers/cache_utils.py
Outdated
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: | ||
"""Returns the sequence length of the cached states that were seen by the model.""" | ||
|
||
return self._seen_tokens |
let's use the fn from the static cache (since we want to remove `self._seen_tokens`)
That would be a performance degradation compared to the current integer update of `self._seen_tokens`. Is there a plan to remove `get_seq_length`? In that case, I would remove it then. Otherwise, this will be a lot slower since it will have to do a (synced) CPU operation on the offloaded cache. Let me know what you think.
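For context, a rough sketch of the two implementations under discussion (method bodies paraphrased, not copied from the PR; the names are labels, both would really be called `get_seq_length`):

```python
# (1) Integer counter, as this PR keeps it: O(1), no tensor ops involved.
def get_seq_length_counter(self, layer_idx=0):
    return self._seen_tokens  # plain Python int, bumped on every update()

# (2) StaticCache-style: count the non-empty key rows in the preallocated
#     buffer. On an offloaded cache that buffer lives on the CPU, so every
#     call becomes a synchronous CPU reduction, hence the concern above.
def get_seq_length_static(self, layer_idx=0):
    return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
```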
Yeah, `get_seq_length` is deprecated in favor of `cache_positions`, which should not need CPU operations as you only use them on device.
Okay, but shouldn't we just remove the `get_seq_length` methods, including all necessary variables (`_seen_tokens`), once it gets removed from the API? Since there are still quite a few usages throughout the codebase.
What I meant was that `get_seq_length` will be much less performant in the meantime if I switch it over to the `StaticCache` implementation. Which, if you want that in the meantime, I'm fine with as well.
Ah sorry, no need to use the one from static. And actually, yeah, we should probably just prevent users from using it -> no offloading for them?
Let's go with keeping `_seen_tokens` for now, and add a comment saying `# TODO @gante remove this`.
Done. 👍
> we should probably just prevent users from using it -> no offloading for them?

The method `get_seq_length` works fine, but it's still used internally. Hence my hesitation to remove it / revert it to the slower `StaticCache` version.
Sounds good!
Arthur and I will handle the deprecation of both on all cache types + internal usage afterwards 🤗
@ArthurZucker Please review and merge when happy.
Force-pushed from 3d413cb to b470b63
@ArthurZucker @gante Thanks for reviewing! Made some fixes and added unit tests in a new commit. Also added the performance test results to the PR description.
Force-pushed from fd41a90 to 890db71
LGTM 🤗
Thank you for the cool feature and for iterating with us!
docs/source/en/kv_cache.md
Outdated

    @@ -238,6 +238,24 @@ For more examples with Static Cache and JIT compilation, take a look at [StaticC
    "Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of"
    ```

    Like [`~OffloadedCache`] exists for offloading a "DynamicCache", there is also an offloaded static cache. Just
Suggested change:

    ### Offloaded Static Cache

    Like [`~OffloadedCache`] exists for offloading a "DynamicCache", there is also an offloaded static cache. Just
I think it deserves a subsection of its own 🤗
Will do. Also, I will add it to the overview table.
Done! Although I do wonder what "Initialization Recommended" means in the overview table?
Whether it should be initialized outside `generate` and passed to `generate` or not. cc @zucchini-nlp on this!
Some cache classes are recommended to be initialized outside generation, e.g. `StaticCache` with compilation had some issues when we initialized the cache while compiling.
Also, some cache types are not handled automatically by our API, e.g. `SinkCache`, so the user has no option but to initialize it and pass it in as `past_key_values`.
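A small sketch of what "initialize outside and pass in" looks like in practice (the model name and hyperparameters here are illustrative, not from the thread):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, SinkCache

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("The key to life is", return_tensors="pt").to(model.device)

# SinkCache is not created automatically by `generate`, so the user must
# construct it up front and hand it in via `past_key_values`.
past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
out = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=20)
```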
Force-pushed from 890db71 to dc8d226
Well, sorry for the late review. Very very nice! Let's make sure the slow tests pass and it's good to go. Can you try to run them locally? 🤗
    self.dtype = dtype if dtype is not None else torch.float32

    # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
    head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
Suggested change:

    - head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
    + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
@gerbenvv I can also merge like this and commit later if you are busy!
Yeah, trying to run the tests now. I'll try running just the file.
Hmm, the tests have literally crashed the whole machine. Are they supposed to use all the GPUs on the machine? I am struggling a bit to run this.
Then I tried to fix those errors by passing the token & installing; that ran successfully, and those were the ones that I have changed.
Okay, making progress ;-) New output: no tests regarding the static, dynamic, or offloaded caches are failing. I think this should be good enough to get this merged, right?
Yep let's go! 🔥
…ace#32161)
* Add a static cache that offloads to the CPU or other device
* Fix PR comments, add unit-tests
What does this PR do?
This PR adds a static cache that offloads to another device.
Fixes #32179
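For reviewers skimming, a rough sketch of the per-layer offloading idea (illustrative only; the class name is made up, and the PR's actual implementation handles streams, devices, and cache updates in more detail):

```python
import torch

class TinyOffloadedStaticCache:
    """Toy illustration: static K/V buffers live on the CPU and are copied
    to the compute device one layer at a time during the forward pass."""

    def __init__(self, num_layers: int, shape: tuple, device: str = "cuda"):
        self.device = torch.device(device)
        # Preallocated (static) buffers on the CPU; pinned memory lets the
        # host-to-device copies run asynchronously and overlap with compute.
        self.key_cache = [torch.zeros(shape, pin_memory=True) for _ in range(num_layers)]
        self.value_cache = [torch.zeros(shape, pin_memory=True) for _ in range(num_layers)]

    def fetch_layer(self, idx: int):
        # Ideally issued on a side CUDA stream so that layer i+1's K/V
        # transfer overlaps with layer i's attention computation.
        k = self.key_cache[idx].to(self.device, non_blocking=True)
        v = self.value_cache[idx].to(self.device, non_blocking=True)
        return k, v
```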
Who can review?
@ArthurZucker @gante @n17s
Performance tests
Performance tested it with:
torch.compile(model)
and I am getting a throughput of about 535 tokens/s (the non-offloaded static cache OOMs).
Also with:
which gets 10.6 tokens/s (12.8 tokens/s static).
And with:
which does 98.8 tokens/s (106.8 tokens/s static).
And with:
which does 939.5 tokens/s (995.6 tokens/s static).
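The exact generation configs behind each measurement were not preserved above; as a hedged sketch, a throughput measurement of this kind generally has the following shape (model, prompt, and lengths are placeholders):

```python
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder model
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)

start = time.perf_counter()
out = model.generate(**inputs, max_new_tokens=256, cache_implementation="offloaded_static")
torch.cuda.synchronize()  # make sure all GPU work is done before timing
elapsed = time.perf_counter() - start

new_tokens = out.shape[1] - inputs["input_ids"].shape[1]
print(f"{new_tokens / elapsed:.1f} tokens/s")
```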