Refactor StarCoder2 using modular (#34015)
* Create modular_starcoder2.py

* Update modular_starcoder2.py

* update

* finalize modular

* revert # no-unravel

* Add support

* style

* Update modular_model_converter.py

* update docstring
Cyrilvallez authored Nov 21, 2024
1 parent 1887159 commit 4e90b99
Showing 3 changed files with 643 additions and 66 deletions.
84 changes: 31 additions & 53 deletions src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -1,3 +1,9 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/starcoder2/modular_starcoder2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_starcoder2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
#
@@ -17,20 +23,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Starcoder2 model."""

import math
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
Expand All @@ -56,12 +60,10 @@


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "bigcode/starcoder2-7b"
_CONFIG_FOR_DOC = "Starcoder2Config"


# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Starcoder2
class Starcoder2RotaryEmbedding(nn.Module):
def __init__(
self,
@@ -149,15 +151,30 @@ def forward(self, x, position_ids):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
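
A minimal interface sketch for the rotary embedding above (assuming a transformers build that exposes Starcoder2RotaryEmbedding; the toy sizes are illustrative only). The module is built from the config, exactly as the model does further down with Starcoder2RotaryEmbedding(config=config), and its forward takes a tensor (read only for dtype/device) plus position ids and returns the cos/sin tables:

import torch
from transformers import Starcoder2Config
from transformers.models.starcoder2.modeling_starcoder2 import Starcoder2RotaryEmbedding

config = Starcoder2Config(hidden_size=64, num_attention_heads=4)  # toy config, illustrative only
rotary = Starcoder2RotaryEmbedding(config=config)
hidden = torch.randn(1, 6, config.hidden_size)    # only dtype/device are taken from this tensor
position_ids = torch.arange(6).unsqueeze(0)       # (batch, seq_len)
cos, sin = rotary(hidden, position_ids)
print(cos.shape, sin.shape)  # each (1, 6, head_dim), head_dim = hidden_size // num_attention_heads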


# Copied from transformers.models.llama.modeling_llama.rotate_half
class Starcoder2MLP(nn.Module):
def __init__(self, config: Starcoder2Config):
super().__init__()
embed_dim = config.hidden_size
self.c_fc = nn.Linear(embed_dim, config.intermediate_size, bias=config.use_bias)
self.c_proj = nn.Linear(config.intermediate_size, embed_dim, bias=config.use_bias)
self.act = ACT2FN[config.hidden_act]
self.residual_dropout = config.residual_dropout

def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.residual_dropout, training=self.training)
return hidden_states
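
A quick shape check for Starcoder2MLP as defined above (a sketch, assuming a transformers install that exposes it; the small sizes are arbitrary). c_fc expands hidden_size to intermediate_size, the activation is applied, and c_proj projects back, so the output shape matches the input:

import torch
from transformers import Starcoder2Config
from transformers.models.starcoder2.modeling_starcoder2 import Starcoder2MLP

config = Starcoder2Config(hidden_size=64, intermediate_size=256)  # toy config, illustrative only
mlp = Starcoder2MLP(config)
hidden = torch.randn(2, 8, config.hidden_size)  # (batch, seq_len, hidden_size)
print(mlp(hidden).shape)                        # expected: torch.Size([2, 8, 64])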


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
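
A small worked example of rotate_half (assuming it is importable from this module): the last dimension is split in half and the halves are swapped, with the second half negated:

import torch
from transformers.models.starcoder2.modeling_starcoder2 import rotate_half

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(rotate_half(x))  # expected: tensor([-3., -4.,  1.,  2.])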


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -185,24 +202,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return q_embed, k_embed
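
Since the body is collapsed in this diff: as in the Llama helper this function is copied from, the rotation is q_embed = (q * cos) + (rotate_half(q) * sin), and likewise for k, with cos/sin unsqueezed along unsqueeze_dim so they broadcast over the head dimension. A shape-only sketch (random cos/sin stand in for real rotary tables, so only the shapes are meaningful):

import torch
from transformers.models.starcoder2.modeling_starcoder2 import apply_rotary_pos_emb

batch, heads, seq, head_dim = 1, 4, 6, 16
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
cos = torch.randn(batch, seq, head_dim)  # normally produced by Starcoder2RotaryEmbedding
sin = torch.randn(batch, seq, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # default unsqueeze_dim=1 broadcasts over heads
print(q_rot.shape, k_rot.shape)  # expected: same shapes as q and k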


class Starcoder2MLP(nn.Module):
def __init__(self, config: Starcoder2Config):
super().__init__()
embed_dim = config.hidden_size
self.c_fc = nn.Linear(embed_dim, config.intermediate_size, bias=config.use_bias)
self.c_proj = nn.Linear(config.intermediate_size, embed_dim, bias=config.use_bias)
self.act = ACT2FN[config.hidden_act]
self.residual_dropout = config.residual_dropout

def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.residual_dropout, training=self.training)
return hidden_states


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -331,7 +330,6 @@ class Starcoder2FlashAttention2(Starcoder2Attention):
flash attention and deal with padding tokens in case the input contains any of them.
"""

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@@ -340,7 +338,6 @@ def __init__(self, *args, **kwargs):
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

# Ignore copy
def forward(
self,
hidden_states: torch.Tensor,
@@ -406,7 +403,7 @@ def forward(
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)

# Reashape to the expected shape for Flash Attention
# Reshape to the expected shape for Flash Attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
@@ -434,15 +431,13 @@ def forward(
return attn_output, attn_weights, past_key_value


# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Starcoder2
class Starcoder2SdpaAttention(Starcoder2Attention):
"""
Starcoder2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`Starcoder2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""

# Ignore copy
def forward(
self,
hidden_states: torch.Tensor,
@@ -552,7 +547,6 @@ def __init__(self, config: Starcoder2Config, layer_idx: int):
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)

# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer.forward
def forward(
self,
hidden_states: torch.Tensor,
@@ -642,7 +636,6 @@ def forward(
"The bare Starcoder2 Model outputting raw hidden-states without any specific head on top.",
STARCODER2_START_DOCSTRING,
)
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Starcoder2
class Starcoder2PreTrainedModel(PreTrainedModel):
config_class = Starcoder2Config
base_model_prefix = "model"
@@ -760,14 +753,15 @@ def __init__(self, config: Starcoder2Config):
self.vocab_size = config.vocab_size

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.embedding_dropout = config.embedding_dropout
self.layers = nn.ModuleList(
[Starcoder2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self._attn_implementation = config._attn_implementation
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
self.rotary_emb = Starcoder2RotaryEmbedding(config=config)

self.gradient_checkpointing = False
self.embedding_dropout = config.embedding_dropout
# Initialize weights and apply final processing
self.post_init()

@@ -904,7 +898,6 @@ def forward(
attentions=all_self_attns,
)

# Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
@@ -981,7 +974,6 @@ def _update_causal_mask(
return causal_mask

@staticmethod
# Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Starcoder2
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
Expand Down Expand Up @@ -1049,7 +1041,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask


# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM with QWEN2->STARCODER2,Qwen2->Starcoder2
class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]

@@ -1082,7 +1073,6 @@ def get_decoder(self):

@add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
# Ignore copy
def forward(
self,
input_ids: torch.LongTensor = None,
@@ -1097,6 +1087,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -1117,8 +1108,8 @@
```python
>>> from transformers import AutoTokenizer, Starcoder2ForCausalLM
>>> model = Starcoder2ForCausalLM.from_pretrained("bigcode/starcoder2-7b")
>>> tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-7b")
>>> model = Starcoder2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
@@ -1155,18 +1146,7 @@

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Ensure tensors are on the same device
shift_labels = shift_labels.to(shift_logits.device)
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits, shift_labels)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
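
For reference, the removed block above computed the standard shift-by-one causal cross-entropy, which the shared self.loss_function call is expected to reproduce (an equivalence sketch under that assumption, using dummy tensors):

import torch
import torch.nn.functional as F

vocab_size = 10
logits = torch.randn(2, 5, vocab_size)         # (batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (2, 5))  # next-token targets

shift_logits = logits[..., :-1, :].float().contiguous().view(-1, vocab_size)  # tokens < n predict n
shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
loss = F.cross_entropy(shift_logits, shift_labels)
print(loss.item())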

if not return_dict:
output = (logits,) + outputs[1:]
@@ -1196,7 +1176,6 @@ def forward(
""",
STARCODER2_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Starcoder2, LLAMA->STARCODER2
class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -1293,7 +1272,6 @@ def forward(
""",
STARCODER2_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Starcoder2, LLAMA->STARCODER2
class Starcoder2ForTokenClassification(Starcoder2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
