Gemma 2: support assisted generation #32357

Merged 2 commits on Jul 31, 2024
3 changes: 3 additions & 0 deletions src/transformers/generation/candidate_generator.py
@@ -171,6 +171,9 @@ def __init__(
"Please pass in `min_length` into `.generate()` instead"
)

# We need to roll back the cache in assisted generation; only DynamicCache is supported
self.generation_config.cache_implementation = None

def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
Fetches the candidates to be tried for the current input.
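For context, here is a minimal sketch (not part of the PR; shapes and values are illustrative) of why only `DynamicCache` works here: assisted generation appends the draft tokens' key/value states to the cache and then has to discard the entries of rejected candidates. `DynamicCache` can be cropped in place, while static or hybrid caches pre-allocate fixed-size slots.

```python
# Illustrative sketch (assumed shapes, not the PR's code): assisted generation caches the
# draft tokens' key/value states, then rolls back the ones the target model rejects.
# DynamicCache supports this via crop(); fixed-size caches do not.
import torch
from transformers import DynamicCache

cache = DynamicCache()
# One layer, batch=1, heads=1, head_dim=4: pretend 10 tokens (prompt + draft) were cached.
keys = torch.randn(1, 1, 10, 4)
values = torch.randn(1, 1, 10, 4)
cache.update(keys, values, layer_idx=0)

# The target model accepted only 7 tokens: roll the cache back to that length.
cache.crop(7)
print(cache.get_seq_length())  # 7
```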
14 changes: 14 additions & 0 deletions src/transformers/generation/utils.py
@@ -1779,6 +1779,20 @@ def generate(
cache_name = "cache_params"
else:
cache_name = "past_key_values"

# TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
# which is only supported in dynamic caches atm
Review comment (Collaborator) on lines +1783 to +1784: Cool, this sounds challenging!

if (
assistant_model is not None
and generation_config.cache_implementation is not None
and self._supports_default_dynamic_cache()
):
logger.warning_once(
"An assistant model is provided, using a dynamic cache instead of a cache of type="
f"'{generation_config.cache_implementation}'."
)
generation_config.cache_implementation = None

if (model_kwargs.get(cache_name) is not None) and is_torchdynamo_compiling():
raise ValueError(
"Passing `past_key_values` is not supported when compiling `model.generate` with torch.compile -- you "
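A hedged usage sketch of the new behaviour (the checkpoints below are illustrative, not taken from the PR): passing an assistant model no longer conflicts with a configured `cache_implementation`; `generate()` emits the warning above and falls back to a dynamic cache.

```python
# Usage sketch under assumed checkpoints: assisted generation with Gemma 2.
# If cache_implementation were set to e.g. "static", generate() now warns and
# silently switches to a dynamic cache instead of failing.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b")
assistant = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")  # smaller draft model

inputs = tokenizer("The theory of relativity states that", return_tensors="pt")
out = model.generate(
    **inputs,
    assistant_model=assistant,  # triggers assisted generation
    max_new_tokens=20,
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```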
7 changes: 3 additions & 4 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -27,7 +27,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
- from ...cache_utils import Cache
+ from ...cache_utils import Cache, HybridCache
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -584,10 +584,9 @@ class Gemma2PreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
- _supports_cache_class = False
+ _supports_cache_class = True
_supports_quantized_cache = False
_supports_static_cache = True
- _is_stateful = True

def _init_weights(self, module):
std = self.config.initializer_range
@@ -832,7 +831,7 @@ def _update_causal_mask(
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
- if past_key_values is not None:
+ if isinstance(past_key_values, HybridCache):
target_length = past_key_values.get_max_length()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
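A brief sketch of what the flag and mask changes enable (the checkpoint name is illustrative, not from the PR): with `_supports_cache_class = True`, Gemma 2 accepts `Cache` objects such as `DynamicCache` as `past_key_values`, and `_update_causal_mask` only reads a fixed maximum length when the cache is actually a `HybridCache`.

```python
# Sketch under an assumed checkpoint: Gemma 2 generating with an explicit DynamicCache,
# the cache type assisted generation rolls back. With the mask change, the target length
# is derived from the attention mask here rather than from HybridCache's fixed size.
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_id = "google/gemma-2-2b"  # illustrative
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Speculative decoding works because", return_tensors="pt")
out = model.generate(**inputs, past_key_values=DynamicCache(), max_new_tokens=10)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```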