Add gemma 2 #31659
Conversation
Good! We only need to update to the 2 in the integration tests.
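For illustration, a rough sketch of the kind of integration-test update that comment is asking for; the checkpoint id `google/gemma-2-9b` and the test body are assumptions for the sketch, not lines from this PR's diff:

```python
# Hypothetical integration-test sketch (checkpoint id and test body are assumptions).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def test_gemma2_generation():
    model_id = "google/gemma-2-9b"  # point the test at a Gemma 2 checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    inputs = tokenizer("Hello, my name is", return_tensors="pt")
    output = model.generate(**inputs, max_new_tokens=20)
    # Generation should extend the prompt by up to 20 new tokens.
    assert output.shape[-1] > inputs["input_ids"].shape[-1]
```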
```python
class Gemma2ModelTester(GemmaModelTester):
    config_class = Gemma2Config
    model_class = Gemma2Model
    for_causal_lm_class = Gemma2ForCausalLM
    for_sequence_class = Gemma2ForSequenceClassification
    for_token_class = Gemma2ForTokenClassification
```
Ok!
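As context for the subclassing pattern above: the tester only swaps in the Gemma 2 classes, and the shared test machinery instantiates models from those class attributes. A tiny self-contained sketch of what that amounts to (the small config values here are illustrative assumptions, not the real test configuration):

```python
# Tiny-model sketch (config values are illustrative assumptions).
import torch
from transformers import Gemma2Config, Gemma2Model

config = Gemma2Config(
    vocab_size=99,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    intermediate_size=37,
    head_dim=8,
)
model = Gemma2Model(config).eval()
input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    out = model(input_ids)
print(out.last_hidden_state.shape)  # torch.Size([1, 8, 32])
```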
* inital commit
* Add doc
* protect?
* fixup stuffs
* update tests
* fix build documentation
* mmmmmmm config attributes
* style
* nit
* uodate
* nit
* Fix docs
* protect some stuff

Co-authored-by: Lysandre <[email protected]>
Please excuse this if I'm just not reading the code correctly, but I'm struggling to understand the intended function of the hybrid cache. Here in `modeling_gemma2.py` (and likewise in the two other attention functions), the past keys/values are updated like so:

```python
if past_key_value is not None:
    # sin and cos are specific to RoPE models; cache_position needed for the static cache
    cache_kwargs = {
        "sin": sin,
        "cos": cos,
        "sliding_window": self.sliding_window,
        "cache_position": cache_position,
    }
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
```

while the definition of `HybridCache.update` is:

```python
def update(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    cache_kwargs: Optional[Dict[str, Any]] = None,
    sliding_window: Optional[int] = None,
) -> Tuple[torch.Tensor]:
    cache_position = cache_kwargs.get("cache_position")
    self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
    self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
    k_out = self.key_cache[layer_idx]
    v_out = self.value_cache[layer_idx]
    if sliding_window:
        update_fn = self._sliding_update
    else:
        update_fn = self._static_update
    ...
```

It doesn't read the `sliding_window` entry from `cache_kwargs`, and the caller never passes the `sliding_window` argument directly, so `sliding_window` is always `None` here and the static update path is taken for every layer.
That is a good catch and not intended; we will update this and fix it. The typo is that we should either pass the sliding window directly, or get the sliding window from the cache kwargs. I think this stems from a desire to make it compile-compatible, and a typo!
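A minimal sketch of the fix described in that reply, i.e. falling back to the `sliding_window` entry in `cache_kwargs` when the explicit argument is not passed; the helper name `resolve_sliding_window` is hypothetical, used only to isolate the dispatch logic:

```python
from typing import Any, Dict, Optional

def resolve_sliding_window(
    cache_kwargs: Optional[Dict[str, Any]],
    sliding_window: Optional[int] = None,
) -> Optional[int]:
    """Explicit argument wins; otherwise fall back to what the attention layer put in cache_kwargs."""
    if sliding_window is not None:
        return sliding_window
    return (cache_kwargs or {}).get("sliding_window")

# With the caller shown above (which only sets "sliding_window" inside cache_kwargs),
# the fallback makes the sliding-update path reachable again:
assert resolve_sliding_window({"sliding_window": 4096}) == 4096   # sliding-attention layer
assert resolve_sliding_window({"sliding_window": None}) is None   # full-attention layer -> static update
assert resolve_sliding_window({}, sliding_window=2048) == 2048    # explicit argument still wins
```

Inside `update`, `if sliding_window:` would then select `self._sliding_update` for sliding layers and `self._static_update` for full-attention layers, which appears to be the intended behavior.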
What does this PR do?
Adds support for Gemma 2.
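For readers of the PR, a minimal usage sketch of the feature being added; the checkpoint id `google/gemma-2-9b` is an assumption about where the released weights live:

```python
# Minimal usage sketch (checkpoint id is an assumption).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b", torch_dtype=torch.bfloat16)

inputs = tokenizer("The capital of France is", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```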