Add gemma 2 #31659

Merged: 16 commits into main from add-gemma-2 on Jun 27, 2024
Conversation

ArthurZucker
Collaborator

What does this PR do?

Adds support for gemma2
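
For context, a minimal usage sketch of what this support enables once checkpoints are on the Hub; the model id below is an assumption for illustration, not something stated in this PR:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Hypothetical checkpoint id; the exact Hub ids are not given in this PR.
    model_id = "google/gemma-2-9b"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    inputs = tokenizer("The capital of France is", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))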

Member

@LysandreJik left a comment


Good! Only need to update to Gemma 2 in the integration tests

Comment on lines +50 to +55
class Gemma2ModelTester(GemmaModelTester):
config_class = Gemma2Config
model_class = Gemma2Model
for_causal_lm_class = Gemma2ForCausalLM
for_sequence_class = Gemma2ForSequenceClassification
for_token_class = Gemma2ForTokenClassification
Member


Ok!

@ArthurZucker marked this pull request as ready for review June 27, 2024 15:18
@LysandreJik merged commit 0cf60f1 into main Jun 27, 2024
8 of 25 checks passed
@LysandreJik deleted the add-gemma-2 branch June 27, 2024 15:36
LysandreJik added a commit that referenced this pull request Jun 27, 2024
* inital commit

* Add doc

* protect?

* fixup stuffs

* update tests

* fix build documentation

* mmmmmmm config attributes

* style

* nit

* uodate

* nit

* Fix docs

* protect some stuff

---------

Co-authored-by: Lysandre <[email protected]>
@turboderp
Contributor

Please excuse this if I'm just not reading the code correctly, but I'm struggling to understand the intended function of the hybrid cache. Here in modeling_gemma2.py, as well as in the two other attention functions, the past keys/values are updated like so:

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "sliding_window": self.sliding_window,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

While the function definition for HybridCache.update looks like:

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
        sliding_window: Optional[int] = None,
    ) -> Tuple[torch.Tensor]:
        cache_position = cache_kwargs.get("cache_position")
        self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
        self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]
        if sliding_window:
            update_fn = self._sliding_update
        else:
            update_fn = self._static_update
    ...

It doesn't read the sliding_window argument from the kwargs, so the default value is always used and the _sliding_update function is never selected, even on layers that use a sliding_window. Is this right?

@ArthurZucker
Collaborator Author

That is a good catch and not intended; we will update this and fix it. The issue is that we should either pass the sliding window directly, or get the sliding window from the cache kwargs. I think this stems from a desire to make it compile-compatible, plus a typo!
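
A minimal, self-contained sketch of the selection issue and one possible fix direction (reading the window from cache_kwargs); this is an illustration only, not the merged patch, and the helper names are made up:

    from typing import Any, Dict, Optional

    def choose_update_fn(
        cache_kwargs: Optional[Dict[str, Any]] = None,
        sliding_window: Optional[int] = None,
    ) -> str:
        # Mirrors the current logic: only the positional argument is checked,
        # so the default None always routes to the static update path.
        if sliding_window:
            return "_sliding_update"
        return "_static_update"

    def choose_update_fn_fixed(
        cache_kwargs: Optional[Dict[str, Any]] = None,
        sliding_window: Optional[int] = None,
    ) -> str:
        # Possible fix direction: fall back to the value the attention layer
        # already places in cache_kwargs.
        sliding_window = sliding_window or (cache_kwargs or {}).get("sliding_window")
        if sliding_window:
            return "_sliding_update"
        return "_static_update"

    # The attention layers only pass the window inside cache_kwargs:
    kwargs = {"sliding_window": 4096, "cache_position": None}
    print(choose_update_fn(kwargs))        # -> "_static_update" (sliding path never taken)
    print(choose_update_fn_fixed(kwargs))  # -> "_sliding_update"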

@ArthurZucker mentioned this pull request Jul 26, 2024