Skip to content

Commit

Permalink
Add gemma 2 (#31659)
Browse files Browse the repository at this point in the history
* inital commit

* Add doc

* protect?

* fixup stuffs

* update tests

* fix build documentation

* mmmmmmm config attributes

* style

* nit

* uodate

* nit

* Fix docs

* protect some stuff

---------

Co-authored-by: Lysandre <[email protected]>
  • Loading branch information
ArthurZucker and LysandreJik committed Jun 27, 2024
1 parent be50a03 commit 69b0f44
Show file tree
Hide file tree
Showing 24 changed files with 3,057 additions and 69 deletions.
1 change: 1 addition & 0 deletions docs/source/en/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Funnel Transformer](model_doc/funnel) ||||
| [Fuyu](model_doc/fuyu) ||||
| [Gemma](model_doc/gemma) ||||
| [Gemma2](model_doc/gemma2) ||||
| [GIT](model_doc/git) ||||
| [GLPN](model_doc/glpn) ||||
| [GPT Neo](model_doc/gpt_neo) ||||
Expand Down
58 changes: 58 additions & 0 deletions docs/source/en/model_doc/gemma2.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@

<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Gemma2

## Overview

The Gemma2 model was proposed in [Gemma2: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/Gemma2-open-models/) by Gemma2 Team, Google.
Gemma2 models are trained on 6T tokens, and released with 2 versions, 2b and 7b.

The abstract from the paper is the following:

*This work introduces Gemma2, a new family of open language models demonstrating strong performance across academic benchmarks for language understanding, reasoning, and safety. We release two sizes of models (2 billion and 7 billion parameters), and provide both pretrained and fine-tuned checkpoints. Gemma2 outperforms similarly sized open models on 11 out of 18 text-based tasks, and we present comprehensive evaluations of safety and responsibility aspects of the models, alongside a detailed description of our model development. We believe the responsible release of LLMs is critical for improving the safety of frontier models, and for enabling the next wave of LLM innovations*

Tips:

- The original checkpoints can be converted using the conversion script `src/transformers/models/Gemma2/convert_Gemma2_weights_to_hf.py`

This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ), [Pedro Cuenca](https://huggingface.co/pcuenq) and [Tom Arsen]().


## Gemma2Config

[[autodoc]] Gemma2Config

## Gemma2Model

[[autodoc]] Gemma2Model
- forward

## Gemma2ForCausalLM

[[autodoc]] Gemma2ForCausalLM
- forward

## Gemma2ForSequenceClassification

[[autodoc]] Gemma2ForSequenceClassification
- forward

## Gemma2ForTokenClassification

[[autodoc]] Gemma2ForTokenClassification
- forward
18 changes: 18 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@
],
"models.fuyu": ["FuyuConfig"],
"models.gemma": ["GemmaConfig"],
"models.gemma2": ["Gemma2Config"],
"models.git": [
"GitConfig",
"GitProcessor",
Expand Down Expand Up @@ -2181,6 +2182,15 @@
"GemmaPreTrainedModel",
]
)
_import_structure["models.gemma2"].extend(
[
"Gemma2ForCausalLM",
"Gemma2ForSequenceClassification",
"Gemma2ForTokenClassification",
"Gemma2Model",
"Gemma2PreTrainedModel",
]
)
_import_structure["models.git"].extend(
[
"GitForCausalLM",
Expand Down Expand Up @@ -5062,6 +5072,7 @@
)
from .models.fuyu import FuyuConfig
from .models.gemma import GemmaConfig
from .models.gemma2 import Gemma2Config
from .models.git import (
GitConfig,
GitProcessor,
Expand Down Expand Up @@ -6694,6 +6705,13 @@
GemmaModel,
GemmaPreTrainedModel,
)
from .models.gemma2 import (
Gemma2ForCausalLM,
Gemma2ForSequenceClassification,
Gemma2ForTokenClassification,
Gemma2Model,
Gemma2PreTrainedModel,
)
from .models.git import (
GitForCausalLM,
GitModel,
Expand Down
122 changes: 122 additions & 0 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,3 +970,125 @@ def get_max_length(self) -> Optional[int]:
# in theory there is no limit because the sliding window size is fixed
# no matter how long the sentence is
return None


class HybridCache(Cache):
def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
"sliding window attention, please check if there is a `sliding_window` field in the model "
"config and it's not set to None."
)
self.max_cache_len = max_cache_len
self.max_batch_size = max_batch_size
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)

self.dtype = dtype if dtype is not None else torch.float32
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
self.is_sliding = torch.tensor(
[i % 2 for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device
)
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
global_cache_shape = (max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
sliding_cache_shape = (
max_batch_size,
self.num_key_value_heads,
min(config.sliding_window, max_cache_len),
self.head_dim,
)
for i in range(config.num_hidden_layers):
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache.
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)

def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
if cache_position.shape[0] > max_cache_len:
k_out = key_states[:, :, -max_cache_len:, :]
v_out = value_states[:, :, -max_cache_len:, :]
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
# we should return the whole states instead of k_out, v_out to take the whole prompt
# into consideration when building kv cache instead of just throwing away tokens outside of the window
return key_states, value_states

slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
cache_position = cache_position.clamp(0, max_cache_len - 1)
to_shift = cache_position >= max_cache_len - 1
indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]

k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
# `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
return k_out, v_out

def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states

self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out
return k_out, v_out

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
sliding_window: Optional[int] = None,
) -> Tuple[torch.Tensor]:
cache_position = cache_kwargs.get("cache_position")
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
if sliding_window:
update_fn = self._sliding_update
else:
update_fn = self._static_update

return update_fn(
cache_position,
layer_idx,
key_states,
value_states,
k_out,
v_out,
k_out.shape[2],
)

def get_max_length(self) -> Optional[int]:
# in theory there is no limit because the sliding window size is fixed
# no matter how long the sentence is
return self.max_cache_len

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
return None

def reset(self):
"""Resets the cache values while preserving the objects"""
for layer_idx in range(len(self.key_cache)):
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
2 changes: 1 addition & 1 deletion src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def __init__(self, **kwargs):
# Cache implementation
self.cache_implementation = kwargs.pop("cache_implementation", None)
self.cache_config = kwargs.pop("cache_config", None)
if self.cache_implementation is not None:
if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG:
cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
if self.cache_config is None:
self.cache_config = cache_config_class()
Expand Down
17 changes: 11 additions & 6 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Cache,
DynamicCache,
HQQQuantizedCache,
HybridCache,
QuantizedCacheConfig,
QuantoQuantizedCache,
SlidingWindowCache,
Expand Down Expand Up @@ -112,7 +113,7 @@
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache}
NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache, "hybrid": HybridCache}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}


Expand Down Expand Up @@ -1395,10 +1396,12 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):

past_length = 0
if model_kwargs.get("past_key_values") is not None:
if isinstance(model_kwargs["past_key_values"], Cache):
past_length = model_kwargs["past_key_values"].get_seq_length()
else:
past_length = model_kwargs["past_key_values"][0][0].shape[2]
cache = model_kwargs["past_key_values"]
if not isinstance(cache, Cache):
past_length = cache[0][0].shape[2]
elif hasattr(cache, "get_seq_length"):
past_length = cache.get_seq_length()

if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
else:
Expand Down Expand Up @@ -1739,7 +1742,9 @@ def generate(
"issue: https://github.com/huggingface/transformers/issues/28981"
)
model_kwargs["past_key_values"] = self._get_cache(
generation_config.cache_implementation, batch_size, generation_config.max_length
generation_config.cache_implementation,
getattr(generation_config, "num_beams", 1) * batch_size,
generation_config.max_length,
)
elif generation_config.cache_implementation == "quantized":
if not self._supports_quantized_cache:
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
funnel,
fuyu,
gemma,
gemma2,
git,
glpn,
gpt2,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
("funnel", "FunnelConfig"),
("fuyu", "FuyuConfig"),
("gemma", "GemmaConfig"),
("gemma2", "Gemma2Config"),
("git", "GitConfig"),
("glpn", "GLPNConfig"),
("gpt-sw3", "GPT2Config"),
Expand Down Expand Up @@ -385,6 +386,7 @@
("funnel", "Funnel Transformer"),
("fuyu", "Fuyu"),
("gemma", "Gemma"),
("gemma2", "Gemma2"),
("git", "GIT"),
("glpn", "GLPN"),
("gpt-sw3", "GPT-Sw3"),
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
("fsmt", "FSMTModel"),
("funnel", ("FunnelModel", "FunnelBaseModel")),
("gemma", "GemmaModel"),
("gemma2", "Gemma2Model"),
("git", "GitModel"),
("glpn", "GLPNModel"),
("gpt-sw3", "GPT2Model"),
Expand Down Expand Up @@ -454,6 +455,7 @@
("falcon", "FalconForCausalLM"),
("fuyu", "FuyuForCausalLM"),
("gemma", "GemmaForCausalLM"),
("gemma2", "Gemma2ForCausalLM"),
("git", "GitForCausalLM"),
("gpt-sw3", "GPT2LMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
Expand Down Expand Up @@ -863,6 +865,7 @@
("fnet", "FNetForSequenceClassification"),
("funnel", "FunnelForSequenceClassification"),
("gemma", "GemmaForSequenceClassification"),
("gemma2", "Gemma2ForSequenceClassification"),
("gpt-sw3", "GPT2ForSequenceClassification"),
("gpt2", "GPT2ForSequenceClassification"),
("gpt_bigcode", "GPTBigCodeForSequenceClassification"),
Expand Down Expand Up @@ -1044,6 +1047,7 @@
("fnet", "FNetForTokenClassification"),
("funnel", "FunnelForTokenClassification"),
("gemma", "GemmaForTokenClassification"),
("gemma2", "Gemma2ForTokenClassification"),
("gpt-sw3", "GPT2ForTokenClassification"),
("gpt2", "GPT2ForTokenClassification"),
("gpt_bigcode", "GPTBigCodeForTokenClassification"),
Expand Down
7 changes: 7 additions & 0 deletions src/transformers/models/auto/tokenization_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,13 @@
"GemmaTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"gemma2",
(
"GemmaTokenizer" if is_sentencepiece_available() else None,
"GemmaTokenizerFast" if is_tokenizers_available() else None,
),
),
("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)),
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/gemma/diff_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.scaling = 1 / math.sqrt(config.head_dim)

if self.hidden_size % self.num_heads != 0:
raise ValueError(
Expand Down Expand Up @@ -305,7 +306,7 @@ def forward(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
Expand Down
Loading

0 comments on commit 69b0f44

Please sign in to comment.