This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Fix embeddings extraction for all models #291

Merged

Conversation

skirodev
Contributor

@skirodev skirodev commented Jun 2, 2023

Apply the fixed code for embeddings extraction to all models to avoid assertion errors. #288

@LLukas22
Contributor

LLukas22 commented Jun 2, 2023

Correct me if I'm wrong: until now we passed the output of the embedding layer as the model embedding?

@philpax
Collaborator

philpax commented Jun 2, 2023

Hmm yeah, I think this is correct (the embeddings should come from just before the LM head is evaluated to coalesce the LM output into logits). I'm on holiday so I can't check, but I assume embd is the input embeddings, not the output.

Happy to merge if you can validate my understanding.

@LLukas22
Contributor

LLukas22 commented Jun 2, 2023

In my opinion this is also correct, but I'll test it when I get home. If it works, I'll merge it.

@skirodev
Contributor Author

skirodev commented Jun 2, 2023

@LLukas22 Yes, exactly: we obtain the vector corresponding to the last token as the embedding for the whole sentence. The length of this embedding is a fixed value equal to n_embd. For the LLaMA 7B model it is 4096, while for OpenAI's text-embedding-ada-002 model it is 1536.
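The extraction described above can be sketched as follows. This is a minimal, hypothetical Rust illustration (the function name and layout are assumptions, not the actual rustformers code): the model produces one hidden-state row of length n_embd per token, and the sentence embedding is simply the last row, i.e. the state that would be fed into the LM head.

```rust
// Hypothetical sketch: hidden states stored row-major as one flat buffer,
// n_tokens rows of n_embd floats each.
fn last_token_embedding(hidden_states: &[f32], n_embd: usize) -> &[f32] {
    let n_tokens = hidden_states.len() / n_embd;
    assert!(n_tokens > 0, "no tokens were evaluated");
    // The sentence embedding is the hidden-state row of the final token.
    &hidden_states[(n_tokens - 1) * n_embd..]
}

fn main() {
    // Two tokens, n_embd = 3 (toy values; LLaMA 7B uses n_embd = 4096).
    let hidden = vec![0.1f32, 0.2, 0.3, 0.4, 0.5, 0.6];
    let emb = last_token_embedding(&hidden, 3);
    assert_eq!(emb, &[0.4, 0.5, 0.6]);
    println!("{:?}", emb);
}
```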

@LLukas22
Contributor

LLukas22 commented Jun 2, 2023

Good catch. When I get home from work I'll test it against the rustformers HF models, but I'm 99% sure it will work.

@Jeadie

Jeadie commented Jun 2, 2023

Confirmed that this is a fix for #288

@LLukas22
Contributor

LLukas22 commented Jun 2, 2023

This works, but for gpt-j and gpt-neox based models I'm getting a process didn't exit successfully: target\release\examples\embeddings.exe gptj C:\Users\Lu.Kreuss\Downloads\gpt-j-6b-q4_0-ggjt.bin -r EleutherAI/gpt-j-6b (exit code: 0xc0000005, STATUS_ACCESS_VIOLATION) error. Tested with rustformers/gpt-j-ggml and rustformers/redpajama-ggml.

I'm also getting very poor embeddings this way, compared to some BERT models I have lying around. Maybe we should perform some sort of pooling over the embeddings of all tokens. The SGPT paper uses weighted mean pooling, where more recent tokens have a stronger impact on the produced embedding than older tokens. Maybe this would improve the quality of the embeddings?
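The weighted mean pooling idea mentioned above could be sketched like this. This is a hypothetical Rust illustration, not code from this repository; it assumes SGPT-style weights where token i (1-indexed) gets weight i divided by the sum of all positions, so later tokens dominate the pooled vector.

```rust
// Hypothetical sketch of position-weighted mean pooling (SGPT-style):
// w_i = i / (1 + 2 + ... + n_tokens), applied per hidden-state row.
fn weighted_mean_pool(hidden_states: &[f32], n_embd: usize) -> Vec<f32> {
    let n_tokens = hidden_states.len() / n_embd;
    let total: f32 = (1..=n_tokens).map(|i| i as f32).sum();
    let mut pooled = vec![0.0f32; n_embd];
    for (i, token) in hidden_states.chunks(n_embd).enumerate() {
        let w = (i as f32 + 1.0) / total; // weight grows with position
        for (p, x) in pooled.iter_mut().zip(token) {
            *p += w * x;
        }
    }
    pooled
}

fn main() {
    // Two tokens, n_embd = 2: weights are 1/3 and 2/3.
    let pooled = weighted_mean_pool(&[3.0, 0.0, 0.0, 3.0], 2);
    assert!((pooled[0] - 1.0).abs() < 1e-5);
    assert!((pooled[1] - 2.0).abs() < 1e-5);
    println!("{:?}", pooled);
}
```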

@skirodev
Contributor Author

skirodev commented Jun 3, 2023

> This works, but for gpt-j and gpt-neox based models I'm getting a process didn't exit successfully: target\release\examples\embeddings.exe gptj C:\Users\Lu.Kreuss\Downloads\gpt-j-6b-q4_0-ggjt.bin -r EleutherAI/gpt-j-6b (exit code: 0xc0000005, STATUS_ACCESS_VIOLATION) error. Tested with rustformers/gpt-j-ggml and rustformers/redpajama-ggml.

This is because the value of "context_size" set in the model parameters was too large, so I rewrote the embeddings example.

> I'm also getting very poor embeddings this way, compared to some BERT models I have lying around. Maybe we should perform some sort of pooling over the embeddings of all tokens. The SGPT paper uses weighted mean pooling, where more recent tokens have a stronger impact on the produced embedding than older tokens. Maybe this would improve the quality of the embeddings?

This might be a viable approach, but I am not sure whether it is necessary to use a model with such a large number of parameters to generate embeddings.

@LLukas22
Contributor

LLukas22 commented Jun 3, 2023

> This is because the value of "context_size" set in the model parameters was too large, so I rewrote the embeddings example.

o_O I feel kinda stupid for missing this. Good job, now everything works as expected 👍

> This might be a viable approach, but I am not sure whether it is necessary to use a model with such a large number of parameters to generate embeddings.

It makes sense if you don't want to load a second, smaller model to perform the embedding task. But I'm going to create another issue for this; it's not part of this PR.

@LLukas22 LLukas22 merged commit e52a102 into rustformers:main Jun 3, 2023
@hhamud hhamud mentioned this pull request Aug 7, 2023