Fix PR comments, add unit-tests
gerbenvv committed Aug 6, 2024
1 parent 5e2f832 commit fd41a90
Showing 6 changed files with 97 additions and 30 deletions.
5 changes: 5 additions & 0 deletions docs/source/en/internal/generation_utils.md
@@ -396,6 +396,11 @@ A [`Constraint`] can be used to force the generation to include specific tokens
- get_seq_length
- reset

[[autodoc]] OffloadedStaticCache
- update
- get_seq_length
- reset

[[autodoc]] HybridCache
- update
- reset
24 changes: 21 additions & 3 deletions docs/source/en/kv_cache.md
@@ -142,7 +142,7 @@ I like rock music because it's loud and energetic. It's a great way to express m
I like rock music because it's loud and energetic. I like to listen to it when I'm feeling
```

## OffloadedCache
## Offloaded Cache

Similar to KV cache quantization, the [`~OffloadedCache`] strategy aims to reduce GPU VRAM usage.
It does so by moving the KV cache for most layers to the CPU.
@@ -154,7 +154,8 @@ Thus, it can serve as a drop-in replacement or a fallback for it.
Depending on your model and the characteristics of your generation task (size of context, number of generated tokens, number of beams, etc.)
you may notice a small degradation in generation throughput compared to the default KV cache implementation.

To enable KV cache offloading, pass `cache_implementation="offloaded"` in the `generation_config` or directky to the `generate()` call.
To enable KV cache offloading, pass `cache_implementation="offloaded"` in the `generation_config` or directly to the `generate()` call.
Use `cache_implementation="offloaded-static"` for an offloaded static cache (see also [Static Cache](#static-cache) below).

```python
>>> import torch
@@ -216,7 +217,6 @@ retrying with cache_implementation='offloaded'
before successfully generating 40 beams.
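
If you hit out-of-memory errors with the default cache, the retry flow described above can be wrapped in a small helper. The sketch below is illustrative (the helper name and printed message are not part of the library API):

```python
>>> import torch

>>> def resilient_generate(model, *args, **kwargs):
...     """Try the default cache first; fall back to the offloaded cache on CUDA OOM."""
...     try:
...         return model.generate(*args, **kwargs)
...     except torch.cuda.OutOfMemoryError:
...         print("retrying with cache_implementation='offloaded'")
...         torch.cuda.empty_cache()
...         kwargs["cache_implementation"] = "offloaded"
...         return model.generate(*args, **kwargs)
```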



### Static Cache

Since the "DynamicCache" dynamically grows with each generation step, it prevents you from taking advantage of JIT optimizations. The [`~StaticCache`] pre-allocates
@@ -238,6 +238,24 @@ For more examples with Static Cache and JIT compilation, take a look at [StaticC
"Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of"
```
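
Because the static cache has a fixed shape, the forward pass can also be compiled. The following is a rough sketch of combining the static cache with `torch.compile` (the checkpoint and compile flags here are illustrative; see the guide linked above for a complete recipe):

```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
>>> inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)

>>> # compile the forward pass; the pre-allocated static cache keeps tensor shapes constant across decoding steps
>>> model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="static")
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
```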

Just as [`~OffloadedCache`] offloads a "DynamicCache", an offloaded variant of the static cache is also available. Simply
pass `cache_implementation="offloaded-static"` in the `generation_config` or directly to the `generate()` call.
This will use the [`~OffloadedStaticCache`] implementation instead.

```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
>>> inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)

>>> # simply pass cache_implementation="offloaded-static"
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="offloaded-static")
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
"Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of"
```

### Sliding Window Cache

As the name suggests, this cache type implements a sliding window over previous keys and values, retaining only the last `sliding_window` tokens. It should be used with models like Mistral that support sliding window attention. Additionally, similar to Static Cache, this one is JIT-friendly and can be used with the same compile techniques as Static Cache.
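
A minimal sketch of enabling it, assuming a Mistral checkpoint with sliding-window attention (the prompt and generation settings are illustrative):

```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, device_map="auto")
>>> inputs = tokenizer("Yesterday I was on a rock concert and", return_tensors="pt").to(model.device)

>>> # select the sliding window cache by name, just like the other cache implementations
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=30, cache_implementation="sliding_window")
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
```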
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1229,6 +1229,7 @@
"HybridCache",
"MambaCache",
"OffloadedCache",
"OffloadedStaticCache",
"QuantizedCache",
"QuantizedCacheConfig",
"QuantoQuantizedCache",
@@ -5955,6 +5956,7 @@
HybridCache,
MambaCache,
OffloadedCache,
OffloadedStaticCache,
QuantizedCache,
QuantizedCacheConfig,
QuantoQuantizedCache,
53 changes: 34 additions & 19 deletions src/transformers/cache_utils.py
@@ -1657,6 +1657,22 @@ class OffloadedStaticCache(StaticCache):
Static cache class to be used with `torch.compile(model)` that offloads to the CPU or
another device.
Args:
config (`PretrainedConfig`):
The configuration file defining the shape-related attributes required to initialize
the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
max_cache_len (`int`, *optional*):
The maximum sequence length with which the model will be used.
device (`Union[str, torch.device]`):
The device on which the cache should be initialized. Should be the same as the
layer device.
dtype (`torch.dtype`, *optional*):
The default `dtype` to use when initializing the cache.
offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`):
The device to offload to.
Attributes:
key_cache (`List[torch.Tensor]`):
Off-loaded key cache tensors. First one will be on device, whereas the others are
@@ -1674,6 +1690,24 @@ class OffloadedStaticCache(StaticCache):
The device used to offload to.
dtype (`torch.dtype`):
The `dtype` used to initialize the cache.
Example:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = OffloadedStaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> past_key_values = outputs.past_key_values # access cache filled with key/values from generation
```
"""

def __init__(
@@ -1685,25 +1719,6 @@ def __init__(
dtype: Optional[torch.dtype] = None,
offload_device: Union[str, torch.device] = torch.device("cpu"),
) -> None:
"""Create offloading static cache.
Args:
config (`PretrainedConfig):
The configuration file defining the shape-related attributes required to initialize
the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
max_cache_len (`int`, `optional`):
The maximum sequence length with which the model will be used.
device (`Union[str, torch.device]`):
The device on which the cache should be initialized. Should be the same as the
layer device.
dtype (`torch.dtype`, `optional`):
The default `dtype` to use when initializing the cache.
offload_device (`Union[str, torch.device]`, `optional`):
The device to offload to. Defaults to CPU.
"""

self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.device = torch.device(device)
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
@@ -72,6 +72,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class OffloadedStaticCache(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class QuantizedCache(metaclass=DummyObject):
_backends = ["torch"]

36 changes: 28 additions & 8 deletions tests/utils/test_cache_utils.py
@@ -373,8 +373,15 @@ def test_sink_cache_iterative_prompts(self):
self.assertTrue(decoded[0].endswith(last_output))

@require_torch_gpu
@parameterized.expand(["eager", "sdpa"])
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
@parameterized.expand(
[
("eager", "static"),
("sdpa", "static"),
("eager", "offloaded-static"),
("sdpa", "offloaded-static"),
]
)
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_implementation):
EXPECTED_GENERATION = [
"The best color is the one that complements the skin tone of the",
"We should not undermind the issues at hand.\nWe should not undermind the issues",
@@ -399,7 +406,7 @@ def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
self.assertListEqual(decoded, EXPECTED_GENERATION)

set_seed(0)
model.generation_config.cache_implementation = "static"
model.generation_config.cache_implementation = cache_implementation
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, eager"):
@@ -413,8 +420,15 @@ def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
self.assertListEqual(decoded, EXPECTED_GENERATION)

@require_torch_gpu
@parameterized.expand(["eager", "sdpa"])
def test_static_cache_greedy_decoding_pad_right(self, attn_implementation):
@parameterized.expand(
[
("eager", "static"),
("sdpa", "static"),
("eager", "offloaded-static"),
("sdpa", "offloaded-static"),
]
)
def test_static_cache_greedy_decoding_pad_right(self, attn_implementation, cache_implementation):
EXPECTED_GENERATION = [
"The best color isЋ the one that complements the skin tone of",
"We should not undermind the issues at hand.\nWe should not undermind the issues",
@@ -439,7 +453,7 @@ def test_static_cache_greedy_decoding_pad_right(self, attn_implementation):
self.assertListEqual(decoded, EXPECTED_GENERATION)

set_seed(0)
model.generation_config.cache_implementation = "static"
model.generation_config.cache_implementation = cache_implementation
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, eager"):
@@ -499,7 +513,13 @@ def test_dynamic_cache_extra_left_padding(self):
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
self.assertListEqual(decoded, EXPECTED_GENERATION)

def test_static_cache_extra_left_padding(self):
@parameterized.expand(
[
"static",
"offloaded-static",
]
)
def test_static_cache_extra_left_padding(self, cache_implementation):
"""Tests that adding extra left-padding does not affect the generation with the static cache"""
EXPECTED_GENERATION = [
"The best color is the one that complements the skin tone of the",
Expand All @@ -517,7 +537,7 @@ def test_static_cache_extra_left_padding(self):
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
).to(model.device)

model.generation_config.cache_implementation = "static"
model.generation_config.cache_implementation = cache_implementation

gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
