Starcoder2 model #29120
Conversation
🔥 looks very good!
- Let's not support Mistral
- Let's try to take in the new API from "Fix static generation when compiling!" (#28937) and "[Core generation] Adds support for static KV cache" (#27931)
The rest is pretty much alright!
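For context, a minimal sketch of what the static KV cache API from #27931/#28937 looks like on the user side, assuming a released Starcoder2 checkpoint (the checkpoint id and exact settings here are illustrative, not part of this PR):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigcode/starcoder2-7b"  # hypothetical checkpoint id, for illustration only
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)

# Opt in to the pre-allocated static KV cache so every decoding step sees fixed
# tensor shapes, which is what makes the forward pass compilable.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```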
return self.weight * hidden_states.to(input_dtype)

# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Starcoder2
fix-copies will not let this pass; this should be copied from Mistral, since we changed Llama for the compiled static cache.
I would also rather we support the static cache, as the API got quite a lot cleaner.
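As a sketch of what fix-copies expects here (assuming the Mistral rotary embedding matches after the rename; the body is elided for brevity):

```python
import torch.nn as nn

# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Starcoder2
class Starcoder2RotaryEmbedding(nn.Module):
    # Body elided; make fix-copies keeps it in sync with the Mistral
    # implementation referenced in the header comment above.
    ...
```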
return torch.cat((-x2, x1), dim=-1)

# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
same here, Llama is different now; make fix-copies will help you fix this!
return hidden_states

class Starcoder2GatedMLP(nn.Module):
probably missing a Copied from mention here (Mistral)
It has small changes (bias + dropout I think)
Should we remove the copied mention from all the classes/methods where we added dropout?
# Copied from transformers.models.mistral.modeling_mistral.MistralAttention.forward
# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Starcoder2
# Copied from transformers.models.mistral.modeling_mistral.MistralModel.forward with MISTRAL->STARCODER2,Mistral->Starcoder2
yes, otherwise check-copies will fail 😉
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
Suggested change (delete, this is not used in Mistral anyway):
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

class Starcoder2Attention(nn.Module):
it would make sense IMO to follow the Llama implementation for the static cache (with the additional cache positions), but this can go in another PR, no worries 🤗
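A rough, hypothetical sketch of the cache-position idea (not the actual Llama code): the caller passes the absolute positions of the new tokens, and the pre-allocated cache is written in place at those slots, so the cache tensors keep a constant shape.

```python
import torch

def update_static_cache(
    key_cache: torch.Tensor,      # (batch, num_heads, max_cache_len, head_dim), pre-allocated
    value_cache: torch.Tensor,    # same shape as key_cache
    key_states: torch.Tensor,     # (batch, num_heads, q_len, head_dim) for the new tokens
    value_states: torch.Tensor,
    cache_position: torch.Tensor, # (q_len,) absolute positions of the new tokens
):
    # Write the new keys/values into their absolute slots; fixed cache shapes are
    # what make the forward pass friendly to torch.compile.
    key_cache.index_copy_(2, cache_position, key_states)
    value_cache.index_copy_(2, cache_position, value_states)
    return key_cache, value_cache

# Tiny usage example with dummy shapes (prefill of 3 tokens into an 8-slot cache).
bsz, heads, max_len, head_dim, q_len = 1, 2, 8, 4, 3
k_cache = torch.zeros(bsz, heads, max_len, head_dim)
v_cache = torch.zeros(bsz, heads, max_len, head_dim)
k_new = torch.randn(bsz, heads, q_len, head_dim)
v_new = torch.randn(bsz, heads, q_len, head_dim)
update_static_cache(k_cache, v_cache, k_new, v_new, torch.arange(q_len))
```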
self.self_attn = STARCODER2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](config)

self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type](
    config.hidden_size, eps=config.norm_epsilon
)
self.post_attention_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type](
    config.hidden_size, eps=config.norm_epsilon
)
this is not what we usually do in transformers. The attention is a specific case 😅
- are all of these used in the default Starcoder2?
- if not, then let's not support Mistral. Mistral is a different architecture
The reason attention is allowed is that it uses the same parameters -> same "Attention" with a different forward, whereas here it is really a different architecture, which goes against the transformers philosophy
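To illustrate the point (a hypothetical sketch, not a proposed diff, and assuming the default Starcoder2 uses LayerNorm with biased MLPs): the usual transformers pattern would be for the layer to build only what the default architecture uses, driven by plain config values, rather than dispatching through MLP/normalization registries.

```python
import torch.nn as nn

class Starcoder2DecoderLayerSketch(nn.Module):
    # Hypothetical sketch: no STARCODER2_MLP_CLASSES / STARCODER2_NORMALIZATION_CLASSES
    # registries, just the one normalization and MLP the default model needs.
    def __init__(self, hidden_size: int, intermediate_size: int, norm_epsilon: float, residual_dropout: float):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size, eps=norm_epsilon)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=norm_epsilon)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(intermediate_size, hidden_size, bias=True),
            nn.Dropout(residual_dropout),
        )
```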
if self._attn_implementation == "flash_attention_2":
    # 2d mask is passed through the layers
    attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._attn_implementation == "sdpa" and not output_attentions:
    # output_attentions=True can not be supported when using SDPA, and we fall back on
    # the manual implementation that requires a 4D causal mask in all cases.
    attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
        attention_mask,
        (batch_size, seq_length),
        inputs_embeds,
        past_key_values_length,
    )
else:
    # 4d mask is passed through the layers
    attention_mask = _prepare_4d_causal_attention_mask(
        attention_mask,
        (batch_size, seq_length),
        inputs_embeds,
        past_key_values_length,
        sliding_window=self.config.sliding_window,
    )
see the new Llama code for this, which was simplified. I'd rather we take it directly for the attention 😉
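For readers not following the Llama refactor, the gist is that the branches above collapse into one path that always builds a single additive 4D mask (shared by the SDPA and eager paths, while flash-attention keeps the 2D mask). A rough, self-contained sketch of such a mask builder, not the actual transformers helper:

```python
import torch

def build_causal_mask(batch_size, seq_length, past_length, dtype, device, padding_mask=None):
    # Hypothetical sketch of a single 4D additive mask path.
    target_length = past_length + seq_length
    min_value = torch.finfo(dtype).min
    # Start from a fully-blocked (seq_length, target_length) score bias...
    mask = torch.full((seq_length, target_length), min_value, dtype=dtype, device=device)
    # ...then zero out every position a query token may attend to
    # (all cached tokens plus itself and earlier tokens in the current chunk).
    allowed = torch.arange(target_length, device=device) <= (
        torch.arange(seq_length, device=device) + past_length
    ).unsqueeze(-1)
    mask = mask.masked_fill(allowed, 0.0)
    # Expand to (batch, 1, seq_length, target_length) and fold in padding, if any.
    mask = mask[None, None, :, :].expand(batch_size, 1, seq_length, target_length).clone()
    if padding_mask is not None:  # (batch, target_length), 1 = keep, 0 = pad
        mask = mask.masked_fill(padding_mask[:, None, None, :] == 0, min_value)
    return mask

# Example: 2 new tokens attending over 3 cached tokens plus themselves.
print(build_causal_mask(1, 2, 3, torch.float32, torch.device("cpu")))
```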
@unittest.skip("Starcoder2 buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self):
    pass
I might have missed it, but I have not seen where these complex-number buffers are?
I re-created a PR here since Joel is on vacation: #29215
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Closing as #29215 was merged and Starcoder2 is officially supported
The Starcoder2 model, adapted from Mistral. All changes are done through options, so Mistral itself is still supported (see the config sketch below). Main changes:
- Embedding and residual dropout
It does not support absolute embeddings, so it can't support Santacoder or Starcoder.
Todo:
- [Core generation] Adds support for static KV cache #27931
- [CLeanup] Revert SDPA attention changes that got in the static kv cache PR #29027 (and future changes from Feb. 19)
@younesbelkada
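A rough sketch of what "changes through options" looks like at the config level. The attribute names below reflect the switches described above (dropout, biased layers, LayerNorm epsilon, sliding window) but are illustrative and may not match the merged Starcoder2Config exactly; the tiny sizes are only so the sketch instantiates quickly.

```python
from transformers import Starcoder2Config, Starcoder2ForCausalLM

config = Starcoder2Config(
    hidden_size=256,
    intermediate_size=1024,
    num_hidden_layers=2,
    num_attention_heads=4,
    embedding_dropout=0.1,   # new: dropout on the input embeddings
    residual_dropout=0.1,    # new: dropout on the residual branches
    use_bias=True,           # Starcoder2 keeps biases, unlike Mistral
    norm_epsilon=1e-5,       # LayerNorm epsilon
    sliding_window=4096,     # kept from the Mistral-style attention
)
model = Starcoder2ForCausalLM(config)
print(sum(p.numel() for p in model.parameters()))
```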