
Add a static cache that offloads to the CPU or other device #32161

Merged: 2 commits into huggingface:main on Aug 29, 2024

Conversation

@gerbenvv (Contributor) commented Jul 23, 2024

What does this PR do?

This PR adds a static cache that offloads to another device.

Fixes #32179
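
For context, the new cache can be constructed directly and handed to `generate`; a minimal sketch is below. The `OffloadedStaticCache` name comes from this PR, but the exact constructor arguments (`max_batch_size`, `offload_device`, etc.) are assumptions here and may differ slightly in the merged version.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import OffloadedStaticCache

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
inputs = tokenizer("Offloading the static KV cache means", return_tensors="pt").to("cuda")

# The full, pre-allocated KV tensors live on the offload device (CPU here); only the
# layer currently being computed needs to sit on the GPU.
past_key_values = OffloadedStaticCache(
    config=model.config,
    max_batch_size=1,                     # assumed argument name
    max_cache_len=512,
    device=model.device,                  # device the model runs on
    dtype=model.dtype,
    offload_device=torch.device("cpu"),   # assumed argument name: where the cache is kept
)
out = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```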

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @gante @n17s

Performance tests

Performance tested it with:

  • L40S (48 GB)
  • Llama 3 70B (4 bit quantized)
  • sdpa attention
  • torch.compile(model)
  • batch size of 1000
  • sequence size of 25 (no prompt, fully generative)

and I am getting a throughput of about 535 tokens/s (the non-offloaded static cache OOMs at this batch size).

Also with:

  • batch size of 1
  • sequence size of 100 (no prompt, fully generative)

which gets 10.6 tokens/s (12.8 tokens/s static).

And with:

  • batch size of 1
  • sequence size of 100 (fully prompt)

which does 98.8 tokens/s (106.8 tokens/s static).

And with:

  • batch size of 50
  • sequence size of 100 (fully prompt)

which does 939.5 tokens/s (995.6 tokens/s static).
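
For reference, the rough shape of the script behind these numbers (a reconstruction under assumptions: the exact model id, the 4-bit loading path via bitsandbytes, and the timing code are mine, not taken from this PR):

```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "meta-llama/Meta-Llama-3-70B-Instruct"  # assumed; the PR only says "Llama 3 70B (4 bit quantized)"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    attn_implementation="sdpa",
    device_map="auto",
)
model = torch.compile(model)

batch_size, max_new_tokens = 1000, 25
# "No prompt, fully generative": every row starts from just the BOS token.
inputs = tokenizer([""] * batch_size, return_tensors="pt").to("cuda")

start = time.perf_counter()
out = model.generate(**inputs, max_new_tokens=max_new_tokens, cache_implementation="offloaded_static")
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
print(f"~{batch_size * max_new_tokens / elapsed:.1f} generated tokens/s")
```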

@gerbenvv gerbenvv changed the title Add a static cache that offloads to the CPU or other device [WIP] Add a static cache that offloads to the CPU or other device Jul 23, 2024
@gerbenvv gerbenvv force-pushed the offloaded-shared-cache branch 4 times, most recently from 87541c5 to 8b862ab on July 25, 2024 11:30
@gerbenvv gerbenvv changed the title [WIP] Add a static cache that offloads to the CPU or other device Add a static cache that offloads to the CPU or other device Jul 25, 2024
@gerbenvv gerbenvv force-pushed the offloaded-shared-cache branch 2 times, most recently from 386f231 to 3d413cb on July 25, 2024 12:13
@ArthurZucker ArthurZucker self-requested a review July 26, 2024 10:25
@ArthurZucker (Collaborator) left a comment:

Already looks great! IMO it would be nice to add a snippet of how to use it in the class's docstring, and once #32150 is merged, also mark this as compatible with torch.compile (while the non-static version won't be).

Really like the findings you posted on the issue, will be useful for everyone I think!
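
Something like the following could serve as that docstring snippet once #32150 lands (a sketch; the `cache_implementation="offloaded_static"` string follows this PR, while the compile flags and the example model are my assumptions):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# Because the offloaded cache is *static* (fixed shapes), the decode step can be compiled;
# the dynamic OffloadedCache cannot make that guarantee.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Static offloaded caches let us", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=20, cache_implementation="offloaded_static")
print(tokenizer.decode(out[0], skip_special_tokens=True))
```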

@ArthurZucker (Collaborator) commented:

It would also be interesting to test this / showcase its potential for huge beam search!
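
To illustrate that point: beam search keeps `num_beams` candidate sequences alive, so the KV cache scales with the beam width. A sketch of what such a showcase could look like (model and generation arguments are assumptions, and this presumes beam search works with the static cache path in the installed version):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
inputs = tokenizer("Summarize the theory of relativity in one sentence:", return_tensors="pt").to("cuda")

# With 32 beams the cache is effectively 32x as large, which is exactly where
# keeping it on the CPU and streaming layers to the GPU pays off.
out = model.generate(
    **inputs,
    num_beams=32,
    num_return_sequences=4,
    max_new_tokens=64,
    cache_implementation="offloaded_static",
)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```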

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante (Member) left a comment:

Very very cool! 🔥

As Arthur wrote, this PR is now missing the complementary parts:

  • An example in the docstring
  • Some benchmarks in the PR, for future reference
  • An integration test to prevent regressions (like the ones here)

src/transformers/cache_utils.py
Comment on lines 1434 to 1762
# For backwards compatibility.
self._seen_tokens = 0
Member:

Suggested change (remove these lines):
- # For backwards compatibility.
- self._seen_tokens = 0

this one is unused throughout the code

@gerbenvv (Contributor, Author) commented Aug 6, 2024:

See comment below.

Comment on lines 1522 to 1854
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model."""

return self._seen_tokens
Member:

let's use the fn from the static cache (since we want to remove self._seen_tokens)

@gerbenvv (Contributor, Author):

That would be a performance degradation compared to the current integer update of self._seen_tokens. Is there a plan to remove get_seq_length? In that case, I would remove it then. Otherwise, this will be a lot slower, since it will have to do a (synced) CPU operation on the offloaded cache. Let me know what you think.

Collaborator:

Yeah, get_seq_length is deprecated in favor of cache_positions, which should not need CPU operations since you only use them on device.

@gerbenvv (Contributor, Author):

Okay, but shouldn't we just remove the get_seq_length methods, including all necessary variables (_seen_tokens), once it gets removed from the API? There are still quite a few usages throughout the codebase.

What I meant was that get_seq_length will be much less performant in the meantime if I switch it over to the StaticCache implementation. If you want that in the meantime, I'm fine with it as well.
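
For readers following along, the trade-off in this thread can be sketched as follows (illustrative pseudocode only, not the actual transformers implementations; shapes and names are simplified):

```python
import torch

class OffloadedStaticCacheSketch:
    """Contrasts two ways of answering get_seq_length() for an offloaded cache."""

    def __init__(self, num_layers: int = 2, max_cache_len: int = 128, head_dim: int = 8):
        # The pre-allocated KV tensors live on the offload device (CPU).
        self.key_cache = [
            torch.zeros(1, 1, max_cache_len, head_dim, device="cpu") for _ in range(num_layers)
        ]
        # Plain Python counter, bumped once per update(); no device work involved.
        self._seen_tokens = 0  # TODO @gante: remove together with get_seq_length

    def get_seq_length_counter(self) -> int:
        # O(1) on the host; this is the fast path the PR keeps for now.
        return self._seen_tokens

    def get_seq_length_scan(self, layer_idx: int = 0) -> int:
        # StaticCache-style: count non-empty slots by scanning the cache tensor.
        # On an offloaded cache this reads CPU memory (and may have to wait for an
        # in-flight transfer), which is the slowdown discussed above.
        return int(self.key_cache[layer_idx][0, 0].any(dim=-1).sum())
```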

@ArthurZucker (Collaborator) commented Aug 6, 2024:

Ah sorry, no need to use the one from static. And actually, yeah, we should probably just prevent users from using it -> no offloading for them?
Let's go with keeping _seen_tokens for now, and add a comment saying # TODO @gante remove this.

@gerbenvv (Contributor, Author):

Done. 👍

@gerbenvv (Contributor, Author) commented Aug 7, 2024:

> we should probably just prevent user from using it -> no offloading for them?

The method get_seq_length works fine, but it's still used internally. Hence my hesitation to remove it / revert it to the slower StaticCache version.

Member:

Sounds good!

Arthur and I will handle the deprecation of both on all cache types + internal usage afterwards 🤗

@gerbenvv (Contributor, Author):

@ArthurZucker Please review and merge when happy.

src/transformers/cache_utils.py (outdated)
@gerbenvv (Contributor, Author) commented Aug 6, 2024:

@ArthurZucker @gante Thanks for reviewing! Made some fixes and added unit tests in a new commit. Also added the performance test results to the PR description.

@gerbenvv gerbenvv force-pushed the offloaded-shared-cache branch 5 times, most recently from fd41a90 to 890db71 on August 7, 2024 12:42
@gante (Member) left a comment:

LGTM 🤗

Thank you for the cool feature and for iterating with us!

@@ -238,6 +238,24 @@ For more examples with Static Cache and JIT compilation, take a look at [StaticC
"Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of"
```

Like [`~OffloadedCache`] exists for offloading a "DynamicCache", there is also an offloaded static cache. Just
Member:

Suggested change:
- Like [`~OffloadedCache`] exists for offloading a "DynamicCache", there is also an offloaded static cache. Just
+ ### Offloaded Static Cache
+ Like [`~OffloadedCache`] exists for offloading a "DynamicCache", there is also an offloaded static cache. Just

I think it deserves a subsection of its own 🤗
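
The new subsection would presumably carry an example along these lines (a sketch of the one-line switch; the exact prose and arguments in the merged docs may differ):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")

# Mirroring the OffloadedCache/DynamicCache pairing described above, the static
# variant is selected via the cache_implementation switch on the generation config.
model.generation_config.cache_implementation = "offloaded_static"
out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```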

@gerbenvv (Contributor, Author):

Will do. I will also add it to the overview table.

@gerbenvv (Contributor, Author):

Done! Although I do wonder: what does "Initialization Recommended" mean in the overview table?

Collaborator:

Whether it should be initialized outside generate and passed to it, or not. cc @zucchini-nlp on this!

Member:

Some cache classes are recommended to be initialized outside generation; e.g. StaticCache with compilation had some issues when we initialized the cache while compiling.

Also, some cache types are not handled automatically by our API, e.g. SinkCache, so the user has no option but to initialize it and pass past_key_values.
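
In code, "initialization recommended" boils down to the following pattern (a sketch; StaticCache's argument names have shifted between versions, so treat `max_batch_size` vs `batch_size` as version-dependent):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
inputs = tokenizer("Static caches are pre-allocated, so", return_tensors="pt").to("cuda")

# Recommended: build the cache once outside generate() so compilation never traces
# the cache construction, then hand it over via past_key_values.
past_key_values = StaticCache(
    config=model.config,
    max_batch_size=1,   # may be `batch_size` depending on the transformers version
    max_cache_len=256,
    device=model.device,
    dtype=model.dtype,
)
out = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```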

@ArthurZucker (Collaborator) left a comment:

Well, sorry for the late review; very, very nice. Let's make sure the slow tests pass and it's good to go! Can you try to run them locally? 🤗

self.dtype = dtype if dtype is not None else torch.float32

# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
Collaborator:

Suggested change:
- head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)

@ArthurZucker (Collaborator):

@gerbenvv I can also merge like this and commit later if you are busy!

@gerbenvv (Contributor, Author):

> @gerbenvv I can also merge like this and commit later if you are busy!

Yeah, I'm trying to run the tests now with `RUN_SLOW=1 pytest tests/utils/test_cache_utils.py`, but it's taking quite a while.
I also tried to install the full dev environment, but it failed to install av due to some compile error.

I'll try running just the file with `pytest tests/utils/test_cache_utils.py`, and if that doesn't work for me, just the tests that I have changed/added. Since it is basically a new class, that should be fine.

@gerbenvv (Contributor, Author):

Hmm, the tests have literally crashed the whole machine. Are they supposed to use all the GPUs on the machine?

I am struggling a bit to run this.

@gerbenvv (Contributor, Author):

`RUN_SLOW=1 pytest tests/utils/test_cache_utils.py` gave me:

========================================================================================================= short test summary info ==========================================================================================================
FAILED tests/utils/test_cache_utils.py::CacheTest::test_static_cache_exportability - OSError: You are trying to access a gated repo.                                                                                                        
FAILED tests/utils/test_cache_utils.py::CacheIntegrationTest::test_dynamic_cache_batched - OSError: You are trying to access a gated repo.                                                                                                  
FAILED tests/utils/test_cache_utils.py::CacheIntegrationTest::test_dynamic_cache_beam_search - OSError: You are trying to access a gated repo.                                                                                              
FAILED tests/utils/test_cache_utils.py::CacheIntegrationTest::test_dynamic_cache_hard - OSError: You are trying to access a gated repo.                                                                                                     
FAILED tests/utils/test_cache_utils.py::CacheIntegrationTest::test_hybrid_cache_n_sequences - OSError: You are trying to access a gated repo.                                                                                               
FAILED tests/utils/test_cache_utils.py::CacheIntegrationTest::test_offloaded_cache_equivalent_to_dynamic_cache - ImportError: Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`                
FAILED tests/utils/test_cache_utils.py::CacheIntegrationTest::test_offloaded_cache_uses_less_memory_than_dynamic_cache - ImportError: Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`        
FAILED tests/utils/test_cache_utils.py::CacheIntegrationTest::test_sink_cache_iterative_prompts - ImportError: Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`                               
===================================================================================== 8 failed, 14 passed, 2 skipped, 24 warnings in 627.82s (0:10:27) ====================================================================================

Then I tried to fix those errors by passing the token & installing accelerate but then the machine crashed. But we at least know that:

test_static_cache_greedy_decoding_pad_left
test_static_cache_greedy_decoding_pad_right
test_static_cache_extra_left_padding

ran successfully, and those were the ones that I have changed.

@gerbenvv (Contributor, Author) commented Aug 28, 2024:

Okay, making progress ;-)

New output of `CUDA_VISIBLE_DEVICES=0 RUN_SLOW=1 pytest tests/utils/test_cache_utils.py` is:

========================================================================================================= short test summary info ==========================================================================================================
FAILED tests/utils/test_cache_utils.py::CacheTest::test_static_cache_exportability - torch._export.verifier.SpecViolationError: Node.meta _enter_autocast is missing val field.
FAILED tests/utils/test_cache_utils.py::CacheIntegrationTest::test_hybrid_cache_n_sequences - ValueError: Greedy methods without beam search do not support `num_return_sequences` different than 1 (got 2).
FAILED tests/utils/test_cache_utils.py::CacheIntegrationTest::test_sink_cache_iterative_prompts - AssertionError: False is not true
===================================================================================== 3 failed, 19 passed, 2 skipped, 28 warnings in 756.35s (0:12:36) =====================================================================================

So no tests regarding the static, dynamic, or offloaded caches are failing. I think this should be good enough to get this merged, right?

@ArthurZucker (Collaborator):

Yep let's go! 🔥

@ArthurZucker ArthurZucker merged commit 5129671 into huggingface:main Aug 29, 2024
24 checks passed
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Aug 30, 2024
…ace#32161)

* Add a static cache that offloads to the CPU or other device

* Fix PR comments, add unit-tests
itazap pushed a commit to NielsRogge/transformers that referenced this pull request Sep 20, 2024
…ace#32161)

* Add a static cache that offloads to the CPU or other device

* Fix PR comments, add unit-tests
Closes: Static KV cache with CPU offloading (#32179)