
Add Starcoder 2 #502

Merged
merged 25 commits into ml-explore:main on Mar 3, 2024

Conversation

Muhtasham
Contributor

Adding the new code models that just dropped and were merged into Transformers.

@awni
Member

awni commented Feb 28, 2024

This looks nice. Is it working?

By the way, there is an open issue (#500); I'm not sure if @lazarust made any progress. Maybe this is a good PR to collaborate on, though it looks like it is basically done?

@lazarust
Contributor

@awni I haven't had time to make progress yet, but this solution looks like it's headed in the right direction.

@Muhtasham
Contributor Author

Is it working?

@awni Not yet, HF is down 😞. But when trying to generate locally, despite the files being there and the import working, I get:

ValueError: Model type starcoder2 not supported. Error: No module named 'mlx_lm.models.starcoder2'

But when I run it in a Python shell I don't get that error, only the HF-down error. Is this behaviour intended?

@lazarust happy to collaborate. This is still a draft based on the original Mistral implementation; I think we need some more changes from huggingface/transformers#29120

@awni
Member

awni commented Feb 29, 2024

but when I do in python shell I do not get that error but hf down error, is this behaviour intended

No, it's not. Make sure you do not have another MLX LM installed (pip uninstall mlx-lm). Also make sure you have an editable install (pip install -e .).

@Muhtasham
Contributor Author

Quick question @awni: is this the MLX equivalent of gelu_pytorch_tanh?

    def __call__(self, x) -> mx.array:
        return self.w2(nn.gelu(self.w1(x)) * self.w3(x))

@angeloskath
Member

@Muhtasham the tanh approximation in PyTorch is not strictly the same as nn.gelu, but it agrees to several significant digits, so it should not cause any issues. See the plot at ml-explore/mlx#744 (comment) for the difference between PyTorch's tanh approximation and nn.gelu.
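
For anyone curious, a quick way to see the size of the difference is to implement the tanh approximation directly and compare it against nn.gelu (a standalone numerical check, not code from this PR):

import math

import mlx.core as mx
import mlx.nn as nn

def gelu_tanh(x):
    # PyTorch's gelu(approximate="tanh"): 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1 + mx.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))

x = mx.linspace(-6, 6, 1001)
print(mx.max(mx.abs(nn.gelu(x) - gelu_tanh(x))))  # maximum absolute difference over [-6, 6]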

@awni
Member

awni commented Mar 1, 2024

How is this going? Were you able to generate any sensible code yet?

@Muhtasham
Contributor Author

Having some trouble with:

raise ValueError(f"Received parameters not in model: {extras}.")

I went through the model config and nothing seems extra. Any hints, @awni?

Command:

python -m mlx_lm.convert \
    --hf-path bigcode/starcoder2-3b \
    -q \
    --upload-repo mlx-community/starcoder2-3b-4bit

@mzbac
Contributor

mzbac commented Mar 1, 2024

I suggest either referring to the transformers implementation if you are familiar with the codebase, or loading the model weights and checking the weight names, which will give you a hint about the model structure.

https://github.com/huggingface/transformers/blob/main/src/transformers/models/starcoder2/modeling_starcoder2.py#L1070
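
A quick way to list the checkpoint's parameter names from Python (a sketch; the shard filename below is hypothetical and depends on the downloaded snapshot):

import mlx.core as mx

# Load one safetensors shard of the downloaded checkpoint and print its parameter names.
weights = mx.load("model-00001-of-00002.safetensors")
for name, w in sorted(weights.items()):
    print(name, w.shape)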

@mzbac
Contributor

mzbac commented Mar 1, 2024

And please note that starcoder2 seems to have enabled the bias weight for all the linear layers, so you may need to enable it in nn.Linear.
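
For example, something along these lines (an illustrative sketch with placeholder sizes, not the PR's actual code):

import mlx.nn as nn

hidden_size = 3072         # placeholder; the real value comes from config.json
intermediate_size = 12288  # placeholder

# Starcoder2 uses a bias on its linear projections, so pass bias=True explicitly.
c_fc = nn.Linear(hidden_size, intermediate_size, bias=True)
c_proj = nn.Linear(intermediate_size, hidden_size, bias=True)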

@awni
Member

awni commented Mar 1, 2024

Thanks for the help here @mzbac !! @Muhtasham after those fixes is it working?

@mzbac
Contributor

mzbac commented Mar 2, 2024

Looking at the tokenizer config, it specifies tokens that allow the model to fill in missing code in the middle of a snippet, with a prompt like prompt = <fim_prefix>${textAboveCursor}<fim_suffix>${textBelowCursor}<fim_middle>.

I'm not sure I follow how you generate just the middle portion... I guess I should look at the paper for the details.

Yeah, basically it's just a different way to structure the prompt that allows the model to autocomplete the middle portion. A more practical example is GitHub Copilot, where the model fills in the content at the cursor position.
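
In other words, the prompt is just assembled from those special tokens around the code context, roughly like this (a sketch; the helper name is made up):

def build_fim_prompt(text_above_cursor, text_below_cursor):
    # Fill-in-the-middle: the model generates the code that belongs between
    # the prefix (above the cursor) and the suffix (below the cursor).
    return (
        "<fim_prefix>" + text_above_cursor
        + "<fim_suffix>" + text_below_cursor
        + "<fim_middle>"
    )

print(build_fim_prompt("def quicksort(arr):\n", "\nprint(quicksort([10, 5, 2, 3]))\n"))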

@awni
Member

awni commented Mar 2, 2024

Ok, I'm a little confused as to what to do with this. The model doesn't generate sensible outputs, so is the plan to land it so people can use LoRA / fine-tunes? Or do we need to fix the prompt? Or is there some other thing that I'm missing?

@mzbac
Contributor

mzbac commented Mar 2, 2024

Ok, I'm a little confused as to what to do with this. The model doesn't generate sensible outputs, so is the plan to land it so people can use LoRA / fine-tunes? Or do we need to fix the prompt? Or is there some other thing that I'm missing?

I will do some testing tonight. In terms of fine-tuning, my understanding is that a FIM model only differs in the training-data prompt format, so it's no different from normal fine-tuning. However, I have not done any FIM fine-tuning, so I may be wrong.

@mzbac
Contributor

mzbac commented Mar 2, 2024

I did some quick tests and it looks like there are issues in the mlx implementation. I couldn't figure out exactly where, but when you use the transformers example as shown below, you can see that the model generates correct FIM code. However, this doesn't work in the mlx implementation.

# pip install git+https://github.com/huggingface/transformers.git # TODO: merge PR to main
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigcode/starcoder2-3b"
device = "cuda" # for GPU usage or "cpu" for CPU usage

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

tokenizer.eos_token = "<file_sep>"
inputs = tokenizer.encode("<fim_prefix>\ndef quicksort(arr):\n<fim_suffix>\n<fim_middle>\n", return_tensors="pt").to(device)
outputs = model.generate(input_ids=inputs,
    temperature=0, 
    max_new_tokens=100,
)
print(tokenizer.decode(outputs[0]))

output:

<fim_prefix>
def quicksort(arr):
<fim_suffix>
<fim_middle>
        if len(arr) <= 1:
                return arr
        else:
                pivot = arr[0]
                less = [i for i in arr[1:] if i <= pivot]
                greater = [i for i in arr[1:] if i > pivot]
                return quicksort(less) + [pivot] + quicksort(greater)

print(quicksort([10, 5, 2, 3]))<file_sep>

@awni
Member

awni commented Mar 2, 2024

However, this doesn't work in the mlx implementation

You tried it with the correct prompt right? Do we add the FIM prompt by default or is that something you have to do manually at the moment?

@mzbac
Contributor

mzbac commented Mar 2, 2024

However, this doesn't work in the mlx implementation

You tried it with the correct prompt right? Do we add the FIM prompt by default or is that something you have to do manually at the moment?

I used python -m mlx_lm.generate --model bigcode/starcoder2-3b --prompt "<fim_prefix>\ndef quicksort(arr):\n<fim_suffix>\n<fim_middle>\n" --temp 0.0 --eos-token "<file_sep>" --ignore-chat-template; it should be the exact prompt used in the transformers implementation.

@mzbac
Contributor

mzbac commented Mar 3, 2024

Based on @Blaizzy's PR (#518), the instruction prompt failed because it was using traditional RoPE. Changing to traditional = False should fix it. However, I have tried his implementation and it still doesn't work with the FIM prompt on my local machine. I suspect there may be something wrong with the RoPE implementation or configuration.

Edit:
Just found that FIM works if I change the prompt to:
python -m mlx_lm.generate --model bigcode/starcoder2-3b --prompt "<file_sep>quick_sort.py<file_sep><fim_prefix>\ndef quicksort(arr):\n<fim_suffix>\n<fim_middle>" --temp 0.0 --eos-token "<file_sep>" --ignore-chat-template
Not sure why the transformers implementation doesn't need to be so strict with the prompt, though.
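
For reference, in MLX this boils down to the flag passed when constructing the rotary embedding (a minimal sketch; head_dim and base here are placeholders that would be read from the config in practice):

import mlx.nn as nn

head_dim = 128          # placeholder; hidden_size // num_attention_heads from config.json
rope_theta = 100000.0   # placeholder; rope_theta from config.json

# Starcoder2 needs the non-traditional RoPE variant in MLX.
rope = nn.RoPE(head_dim, traditional=False, base=rope_theta)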

@awni
Member

awni commented Mar 3, 2024

Ok just to make sure I understand - the only difference between this PR and #518 is that the RoPE traditional is False?

@awni
Member

awni commented Mar 3, 2024

If that's the case, let's simply update that here? It's a one line change. I think it makes more sense than starting on a separate PR?

@awni
Member

awni commented Mar 3, 2024

Can confirm, seems to work now! Thanks!

Member

@awni left a comment

Thanks for the addition and contributions everyone!

@awni merged commit 81e2a80 into ml-explore:main Mar 3, 2024
3 checks passed
@Blaizzy
Contributor

Blaizzy commented Mar 3, 2024

Based on @Blaizzy's PR (#518),
Can confirm, seems to work now! Thanks!
Thanks for the addition and contributions everyone!

Most welcome 😊!

@mzbac
Contributor

mzbac commented Mar 3, 2024

Ok just to make sure I understand - the only difference between this PR and #518 is that the RoPE traditional is False?

@awni, there are some minor model args that need to be cleaned up. For example, rms_norm_eps should be changed to norm_epsilon to match Starcoder2's configuration. Also, it would be more efficient to use repeat for the GQA heads instead of map-and-concatenate. I'm not sure if @Blaizzy would like to make a follow-up PR for it; otherwise, I don't mind creating a patch PR to clean it up.

@Blaizzy
Contributor

Blaizzy commented Mar 3, 2024

Ok just to make sure I understand - the only difference between this PR and #518 is that the RoPE traditional is False?

Didn't notice the last update on this PR early Saturday.

I'm happy I could help make it work for everyone 🚀!

@Blaizzy
Contributor

Blaizzy commented Mar 3, 2024

@mzbac I can do that :)

@Blaizzy
Contributor

Blaizzy commented Mar 3, 2024

Also, it would be more efficient to use repeat for the GQA heads instead of map-and-concatenate.

Could you elaborate on this?

@mzbac
Contributor

mzbac commented Mar 3, 2024

Also, it would be more efficient to use repeat for the GQA heads instead of map-and-concatenate.

Could you elaborate on this?

Yeah, in the current implementation here https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/starcoder2.py#L71-L76, it could be updated to something like https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py#L83-L85.
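
Roughly, the change is to expand the key/value heads to match the query heads with one repeat along the head axis instead of building a list of copies and concatenating (an illustrative sketch with placeholder shapes, not the exact code from either file):

import mlx.core as mx

B, L = 1, 16
n_heads, n_kv_heads, head_dim = 24, 2, 128   # placeholder Starcoder2-like sizes
repeats = n_heads // n_kv_heads

keys = mx.random.normal((B, n_kv_heads, L, head_dim))

# Map-and-concatenate version: stack `repeats` copies of each kv head, then reshape.
by_concat = mx.concatenate([mx.expand_dims(keys, 2)] * repeats, axis=2)
by_concat = by_concat.reshape(B, n_heads, L, head_dim)

# Repeat version: a single call that produces the same layout and is faster (see #443).
by_repeat = mx.repeat(keys, repeats, axis=1)

print(mx.array_equal(by_concat, by_repeat))  # True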

@Blaizzy
Contributor

Blaizzy commented Mar 3, 2024

I see, thanks for clarifying! Will do.

I thought so too, because I used the Llama repeat.

If I may ask, why was it done differently?

@mzbac
Contributor

mzbac commented Mar 3, 2024

Yeah, based on the previous PR, simply repeating is faster than concatenating. Just FYI: #443

@sislam-provenir

This feature came at the perfect time for me! 24 hours ago there was no support and now there is. Love open source! 🩶

@sislam-provenir

sislam-provenir commented Mar 3, 2024

So I'm trying to fine-tune StarCoder2 using QLoRA.

$ pwd
>>> /.../mlx-examples/llms

When I issue the fine-tuning command, I get an AttributeError:

$ python -m mlx_lm.lora \                                           
    --model mlx-community/starcoder2-3b-4bit \
    --train \
    --data $(realpath ../lora/data) \
    --iters 10


>>> None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
Loading pretrained model
Fetching 7 files: 100%|███████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 61422.86it/s]
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/sameenislam/source/sameenislam/mlx-examples/llms/mlx_lm/lora.py", line 246, in <module>
    run(args)
  File "/Users/sameenislam/source/sameenislam/mlx-examples/llms/mlx_lm/lora.py", line 172, in run
    linear_to_lora_layers(model, args.lora_layers)
  File "/Users/sameenislam/source/sameenislam/mlx-examples/llms/mlx_lm/tuner/utils.py", line 27, in linear_to_lora_layers
    if model.model_type in [
       ^^^^^^^^^^^^^^^^
  File "/Users/sameenislam/anaconda3/lib/python3.11/site-packages/mlx/nn/layers/base.py", line 137, in __getattr__
    super(Module, self).__getattr__(key, val)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'super' object has no attribute '__getattr__'. Did you mean: '__setattr__'?

Has anyone encountered this?

Also including this to show that inference is working:

$ python -m mlx_lm.generate --model mlx-community/starcoder2-3b-4bit --prompt "hello"

>>> None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
Fetching 7 files: 100%|██████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 145347.17it/s]
==========
Prompt: hello
_1(void) {
	printf("hello, world\n");
}

void helloworld_print_double_1(double a) {
	printf("%f\n", a);
}

double helloworld_square_1(double a) {
	return a * a;
}

double helloworld_square_2(double a) {
	return a * a;
}

double helloworld_square_
==========
Prompt: 16.348 tokens-per-sec
Generation: 40.031 tokens-per-sec

@mzbac
Contributor

mzbac commented Mar 3, 2024

So I'm trying to fine-tune StarCoder2 using QLoRA. [...] AttributeError: 'super' object has no attribute '__getattr__'. Did you mean: '__setattr__'? Has anyone encountered this?

There is a missing model_type in the starcoder2. You can try adding it to your local code as shown in this PR.
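
For anyone hitting the same error before that patch lands, the gist of the fix is to expose a model_type on the model so the LoRA tooling can read it (a sketch under the assumption that starcoder2.py follows the same ModelArgs/Model structure as the other mlx_lm models; the fields shown are illustrative):

from dataclasses import dataclass

import mlx.nn as nn

@dataclass
class ModelArgs:
    model_type: str = "starcoder2"  # previously missing; the LoRA tooling checks model.model_type
    hidden_size: int = 3072         # placeholder; the real values come from config.json

class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type  # the attribute that linear_to_lora_layers looks up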

@sislam-provenir

There is a missing model_type in the starcoder2. You can try adding it to your local code as shown in this PR.

Awesome, I've made the change locally from PR #522 and it's working like a charm!

$ python -m mlx_lm.lora \ 
    --model $(realpath mlx_model) \
    --train \
    --data $(realpath ../lora/data) \
    --iters 10

>>> None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
Loading pretrained model
Total parameters 602.983M
Trainable parameters 1.212M
Loading datasets
Training
Starting training..., iters: 10
Iter 1: Val loss 2.363, Val took 30.839s
Iter 10: Train loss 2.274, Learning Rate 1.000e-05, It/sec 0.490, Tokens/sec 196.036, Trained Tokens 3999
Saved final adapter weights to adapters.npz.

P.S. $(realpath mlx_model) is the model obtained with:

$ python -m mlx_lm.convert \                                        
    --hf-path bigcode/starcoder2-3b \
    -q

devonthomas35 pushed a commit to devonthomas35/mlx-examples that referenced this pull request Mar 11, 2024
* Add Starcoder2 model and update utils.py

* Refactor model arguments and modules in starcoder2.py

* Refactor FeedForward class to MLP in starcoder2.py

* Fix typo

* pre-commit

* Refactor starcoder2.py: Update model arguments and modules

* Fix LM head and MLP layers

* Rename input layer norm

* Update bias in linear layers

* Refactor token embeddings in Starcoder2Model

* Rename to standard HF attention layer name

* Add LayerNorm

* Add transposed token embeddings (like in Gemma)

* Refactor MLP and TransformerBlock classes

* Add tie_word_embeddings option to ModelArgs and update Model implementation

* Add conditional check for tying word embeddings in Starcoder2Model

* Fix bias in lm_head linear layer

* Remove unused LayerNorm in stablelm

* Update transformers dependency to use GitHub repository

* fix lm head bug, revert transformer req

* Update RoPE initialization in Attention class

---------

Co-authored-by: Awni Hannun <[email protected]>