Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Starcoder 2 #502

Merged
merged 25 commits into from
Mar 3, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
85d9aa6
Add Starcoder2 model and update utils.py
Muhtasham Feb 22, 2024
973e4a7
Refactor model arguments and modules in starcoder2.py
Muhtasham Feb 22, 2024
6cd9870
Refactor FeedForward class to MLP in starcoder2.py
Muhtasham Feb 22, 2024
a96753c
Merge branch 'ml-explore:main' into add/sc2
Muhtasham Feb 28, 2024
73ad35f
Fix typo
Muhtasham Feb 28, 2024
3c0fbd5
pre-commit
Muhtasham Feb 28, 2024
761b616
Refactor starcoder2.py: Update model arguments and modules
Muhtasham Feb 29, 2024
9929751
Merge branch 'ml-explore:main' into add/sc2
Muhtasham Feb 29, 2024
8aee3b7
Fix LM head and MLP layers
Muhtasham Mar 1, 2024
a9ba4b3
Rename input layer norm
Muhtasham Mar 1, 2024
1366c03
Update bias in linear layers
Muhtasham Mar 1, 2024
446e7e9
Refactor token embeddings in Starcoder2Model
Muhtasham Mar 1, 2024
f72792c
Rename to standard HF attention layer name
Muhtasham Mar 1, 2024
a8ce255
Add LayerNorm
Muhtasham Mar 1, 2024
4c9aea6
Add transposed token embeddings (like in Gemma)
Muhtasham Mar 1, 2024
83d4cb5
Merge branch 'ml-explore:main' into add/sc2
Muhtasham Mar 1, 2024
c954406
Refactor MLP and TransformerBlock classes
Muhtasham Mar 1, 2024
512a542
Add tie_word_embeddings option to ModelArgs and update Model implemen…
Muhtasham Mar 1, 2024
3a81505
Add conditional check for tying word embeddings in Starcoder2Model
Muhtasham Mar 1, 2024
44d920a
Merge branch 'ml-explore:main' into add/sc2
Muhtasham Mar 1, 2024
ab562b1
Fix bias in lm_head linear layer
Muhtasham Mar 1, 2024
f268ab0
Remove unused LayerNorm in stablelm
Muhtasham Mar 1, 2024
fe6c52f
Update transformers dependency to use GitHub repository
Muhtasham Mar 1, 2024
b21a6bf
fix lm head bug, revert transformer req
awni Mar 1, 2024
fc31d7a
Update RoPE initialization in Attention class
Muhtasham Mar 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llms/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ You can convert models in the Python API with:
```python
from mlx_lm import convert

upload_repo = "mistralai/Mistral-7B-Instruct-v0.1"
upload_repo = "mlx-community/My-Mistral-7B-v0.1-4bit"
Muhtasham marked this conversation as resolved.
Show resolved Hide resolved

convert("mistralai/Mistral-7B-v0.1", quantize=True, upload_repo=upload_repo)
```
Expand Down
182 changes: 182 additions & 0 deletions llms/mlx_lm/models/starcoder2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs
from .layers import LayerNorm


@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
num_key_value_heads: int = None
max_position_embeddings: int = 16384
norm_eps: float = None
rms_norm_eps: float = 1e-5
norm_type: str = "layer_norm"
vocab_size: int = 49152
rope_theta: float = 100000

def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads

if self.norm_eps is None:
self.norm_eps = self.rms_norm_eps


class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args

dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads

self.repeats = self.n_heads // self.n_kv_heads

head_dim = args.hidden_size // args.num_attention_heads
self.scale = head_dim**-0.5

self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True)
self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)

def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape

queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

def repeat(a):
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
return a.reshape([B, self.n_heads, L, -1])

if self.repeats > 1:
keys, values = map(repeat, (keys, values))

if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)

scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)


class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.c_fc = nn.Linear(dim, hidden_dim, bias=True)
self.c_proj = nn.Linear(hidden_dim, dim, bias=True)

def __call__(self, x):
return self.c_proj(nn.gelu(self.c_fc(x)))


class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.n_heads = args.num_attention_heads

self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = LayerNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.args = args

def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
Muhtasham marked this conversation as resolved.
Show resolved Hide resolved
return out, cache


class Starcoder2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = LayerNorm(args.hidden_size, eps=args.rms_norm_eps)

def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)

mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)

if cache is None:
cache = [None] * len(self.layers)

for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])

return self.norm(h), cache


class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model = Starcoder2Model(args)

Muhtasham marked this conversation as resolved.
Show resolved Hide resolved
def __call__(
self,
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
out = out @ self.model.embed_tokens.weight.T
return out, cache

@property
def layers(self):
return self.model.layers
1 change: 1 addition & 0 deletions llms/mlx_lm/tuner/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def check_lora_layers(num_model):
"stablelm_epoch",
"qwen2",
"gemma",
"starcoder2",
]:
check_lora_layers(len(model.model.layers))

Expand Down