Generate: SinkCache can handle iterative prompts #27907

Merged · 9 commits · Dec 8, 2023
Changes from 5 commits
19 changes: 15 additions & 4 deletions src/transformers/cache_utils.py
@@ -38,6 +38,10 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states, if there is any."""
raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")


class DynamicCache(Cache):
"""
@@ -120,6 +124,10 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
return 0
return self.key_cache[layer_idx].shape[-2]

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
return None

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
@@ -209,8 +217,11 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
# Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
if len(self.key_cache) <= layer_idx:
return 0
cache_length = self.key_cache[layer_idx].shape[-2]
return min(cache_length, self.window_length - 1)
gante (Member Author) commented:

This line was a bit of a hack; get_max_length means we no longer need the hack :)

get_seq_length now always does what the function name and the docstring say it does.
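As a quick illustration of the resulting API surface (a minimal sketch, not part of the diff; it assumes a transformers install that includes this change and exports both cache classes at the top level):

from transformers import DynamicCache, SinkCache

# get_seq_length() now reports the true cached length (0 for a fresh cache),
# while the window cap is exposed separately through get_max_length().
sink = SinkCache(window_length=256, num_sink_tokens=4)
dynamic = DynamicCache()
print(sink.get_seq_length(0), sink.get_max_length())        # 0 256
print(dynamic.get_seq_length(0), dynamic.get_max_length())  # 0 None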

return self.key_cache[layer_idx].shape[-2]

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states."""
return self.window_length

def update(
self,
Expand Down Expand Up @@ -239,8 +250,8 @@ def update(
"""
# Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
# with partially rotated position embeddings, like Phi or Persimmon.
sin = cache_kwargs.get("sin")
cos = cache_kwargs.get("cos")
sin = cache_kwargs.get("sin")[: self.window_length]
gante (Member Author) commented:

Slicing here is needed if more than one token is fed at once, after the cache is full.

Collaborator commented:

Suggested change:
-    sin = cache_kwargs.get("sin")[: self.window_length]
+    sin = cache_kwargs.get("sin")[-self.window_length:]

Would that not make more sense? Since that's the side we slice the cache from, no?

Member commented:

I'll run some tests for this.

Member commented:

Updated my previous message with my findings.

gante (Member Author) commented on Dec 8, 2023:

@ArthurZucker @tomaarsen that's the beauty of sin and cos: the rerotation applied on the sink caches is based on the relative angles, and the relative angles are the same regardless of the side we slice 🙌

If you place a debugger, you can see that the sliced tensors are different, but the rerotation coefficients are exactly the same!
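A standalone sketch of that claim (toy sizes and standard RoPE angles theta = position * inv_freq are assumptions; this is not code from the PR): the front and back slices of cos/sin differ, yet the rerotation coefficients, which only depend on the relative shift, come out identical.

import torch

dim, seq_len, window_length, shift = 64, 40, 8, 3
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
angles = torch.outer(torch.arange(seq_len, dtype=torch.float64), inv_freq)  # [seq_len, dim // 2]
cos, sin = angles.cos(), angles.sin()

def rerotation(c, s):
    # cos/sin of the angle difference between positions i + shift and i
    return (c[shift:] * c[:-shift] + s[shift:] * s[:-shift],
            s[shift:] * c[:-shift] - c[shift:] * s[:-shift])

front = rerotation(cos[:window_length], sin[:window_length])
back = rerotation(cos[-window_length:], sin[-window_length:])
print(torch.allclose(front[0], back[0]), torch.allclose(front[1], back[1]))  # True True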

cos = cache_kwargs.get("cos")[: self.window_length]
partial_rotation_size = cache_kwargs.get("partial_rotation_size")
using_rope = cos is not None and sin is not None

12 changes: 12 additions & 0 deletions src/transformers/models/llama/modeling_llama.py

@@ -405,6 +405,12 @@ def forward(
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# If the cache has a fixed length and we were about to go beyond it, update the key/value length and the
# attention mask accordingly. `.update()` handles the cache cropping internally if needed.
kv_max_length = past_key_value.get_max_length()
if kv_max_length is not None and kv_seq_len > kv_max_length:
kv_seq_len = kv_max_length
attention_mask = attention_mask[:, :, :, -kv_seq_len:]
gante (Member Author) commented on Dec 8, 2023:

The attention mask must be sliced to match the length of key_states, which might have been sliced in .update() (for fixed-length caches).
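A shape-only sketch of the failure mode without the slice (toy numbers are assumptions, not code from the diff): the 4D mask is built for the pre-cropping kv length, while key_states comes back from .update() already cropped to the window, so the mask must keep only the most recent positions.

import torch

batch, heads, q_len, head_dim, window_length = 1, 8, 16, 64, 256
kv_seq_len = 300  # cached + new tokens, before the fixed-length cache crops

attention_mask = torch.zeros(batch, 1, q_len, kv_seq_len)        # built for 300 key positions
key_states = torch.randn(batch, heads, window_length, head_dim)  # .update() kept only 256

# mirror of the pattern added in this diff: clamp the length, keep the latest positions
kv_seq_len = min(kv_seq_len, window_length)
attention_mask = attention_mask[:, :, :, -kv_seq_len:]
assert attention_mask.shape[-1] == key_states.shape[-2]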

Collaborator commented:

Ouch, not a fan of that. We have something similar for the Mistral sliding window, but I'm not in favor of keeping this. It should either go in the attention mask converter, which should do the slicing, or in the cache_kwargs, as it's clearly sink-cache and window-cache specific 😉

gante (Member Author) commented:

The alternative would be to make prepare_inputs_for_generation prepare the sliced mask in advance (which is how the TF/JAX implementations do it). Going to check if it is feasible.


key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -511,6 +517,12 @@ def forward(
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# If the cache has a fixed length and we were about to go beyond it, update the key/value length and the
# attention mask accordingly. `.update()` handles the cache cropping internally if needed.
kv_max_length = past_key_value.get_max_length()
if kv_max_length is not None and kv_seq_len > kv_max_length:
kv_seq_len = kv_max_length
attention_mask = attention_mask[:, :, :, -kv_seq_len:]

# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
12 changes: 12 additions & 0 deletions src/transformers/models/mistral/modeling_mistral.py

@@ -275,6 +275,12 @@ def forward(
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# If the cache has a fixed length and we were about to go beyond it, update the key/value length and the
# attention mask accordingly. `.update()` handles the cache cropping internally if needed.
kv_max_length = past_key_value.get_max_length()
if kv_max_length is not None and kv_seq_len > kv_max_length:
kv_seq_len = kv_max_length
attention_mask = attention_mask[:, :, :, -kv_seq_len:]

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -408,6 +414,12 @@ def forward(

cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# If the cache has a fixed length and we were about to go beyond it, update the key/value length and the
# attention mask accordingly. `.update()` handles the cache cropping internally if needed.
kv_max_length = past_key_value.get_max_length()
if kv_max_length is not None and kv_seq_len > kv_max_length:
kv_seq_len = kv_max_length
attention_mask = attention_mask[:, :, :, -kv_seq_len:]

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
6 changes: 6 additions & 0 deletions src/transformers/models/persimmon/modeling_persimmon.py

@@ -318,6 +318,12 @@ def forward(
# Specific to RoPE models with partial rotation
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# If the cache has a fixed length and we were about to go beyond it, update the key/value length and the
# attention mask accordingly. `.update()` handles the cache cropping internally if needed.
kv_max_length = past_key_value.get_max_length()
if kv_max_length is not None and kv_seq_len > kv_max_length:
kv_seq_len = kv_max_length
attention_mask = attention_mask[:, :, :, -kv_seq_len:]

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

12 changes: 12 additions & 0 deletions src/transformers/models/phi/modeling_phi.py

@@ -357,6 +357,12 @@ def forward(
# Specific to RoPE models with partial rotation
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# If the cache has a fixed length and we were about to go beyond it, update the key/value length and the
# attention mask accordingly. `.update()` handles the cache cropping internally if needed.
kv_max_length = past_key_value.get_max_length()
if kv_max_length is not None and kv_seq_len > kv_max_length:
kv_seq_len = kv_max_length
attention_mask = attention_mask[:, :, :, -kv_seq_len:]

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

@@ -466,6 +472,12 @@ def forward(
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# If the cache has a fixed length and we were about to go beyond it, update the key/value length and the
# attention mask accordingly. `.update()` handles the cache cropping internally if needed.
kv_max_length = past_key_value.get_max_length()
if kv_max_length is not None and kv_seq_len > kv_max_length:
kv_seq_len = kv_max_length
attention_mask = attention_mask[:, :, :, -kv_seq_len:]

tgt_len = key_states.shape[2]

41 changes: 41 changes & 0 deletions tests/test_cache_utils.py

@@ -187,3 +187,44 @@ def test_sink_cache_hard(self):
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))

@require_auto_gptq
def test_sink_cache_iterative_prompts(self):
gante (Member Author) commented:

This test would fail on main.

"""Tests that SinkCache supports more than one new token at once, when shifting the cache"""
tokenizer = AutoTokenizer.from_pretrained("TheBloke/zephyr-7B-beta-GPTQ")
model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-beta-GPTQ", device_map="auto")
prompt = (
"Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences "
"and must-see attractions."
)

# Prepare generation settings
cache = SinkCache(window_length=256, num_sink_tokens=4)
input_ids = torch.tensor([], device=model.device, dtype=torch.int)
for _ in range(3):
# Tokenize the prompt with the correct chat template
chat = [{"role": "user", "content": prompt}]
tokenized_chat = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
model.device
)
input_ids = torch.cat((input_ids, tokenized_chat), dim=1)

# Perform the generation
gen_out = model.generate(
input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True
)
input_ids = gen_out

# We went well beyond the cache length
self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5)

# And it still produces coherent English (the repetition is due to the prompt being repeated 3 times)
decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
last_output = (
"<|assistant|>\nHawaii, the Aloha State post for a travel destination you've taken. Your post's and "
"must-see landmarks. Use a descriptive and engaging writing style, incorporating personal anecdotes and "
"recommendations for fellow travelers. Your post should be at least 800 words and include high-quality "
"images to enhance the reader's experience. Be sure to cover a variety of experiences, from cultural "
"immersion to outdoor adventures, and provide practical"
)
self.assertTrue(decoded[0].endswith(last_output))