Fix embeddings extraction for all models #291
Conversation
Correct me if I'm wrong: until now we passed the output of the embedding layer as the model embedding? |
Hmm yeah, I think this is correct (the embeddings should come from just before the LM head is evaluated to coalesce the LM output into logits). I'm on holiday so I can't check, but I assume so. Happy to merge if you can validate my understanding. |
In my opinion this is also correct, but I'll test it when I get home. If it works, I'll merge it. |
@LLukas22 Yes, exactly, we obtain the vector corresponding to the last token as the embedding for the whole sentence. The length of this embedding is a fixed value equal to n_embd. For the LLaMA 7B model, its length is 4096, while for OpenAI's |
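For readers following along, here is a minimal sketch of what "take the last token's vector" means in practice. This is plain Rust with hypothetical names, not the actual llm crate API: the final hidden states form an n_tokens × n_embd matrix (row-major), and the sentence embedding is the last row.

```rust
/// Minimal sketch (hypothetical, not the llm crate API): given the final
/// hidden states laid out row-major as [n_tokens * n_embd], the sentence
/// embedding is the row belonging to the last token.
fn last_token_embedding(hidden: &[f32], n_tokens: usize, n_embd: usize) -> Vec<f32> {
    assert_eq!(hidden.len(), n_tokens * n_embd);
    let start = (n_tokens - 1) * n_embd;
    // e.g. n_embd = 4096 for LLaMA 7B, so this returns a 4096-dim vector
    hidden[start..start + n_embd].to_vec()
}
```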
Good catch. When I get home from work I'll test it against the rustformers HF models, but I'm 99% sure it will work. |
Confirmed that this is a fix for #288 |
This works, but for gpt-j and gpt-neox based models I'm getting a I'm also getting very poor embeddings this way, compared to some BERT models I have lying around. Maybe we should perform some sort of pooling on the embeddings of all tokens. The SGPT paper uses weighted mean pooling, where more recent tokens have a stronger impact on the produced embedding than older tokens (see the sketch below). Maybe this would improve the quality of the embeddings? |
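For reference, a hedged sketch of the SGPT-style position-weighted mean pooling mentioned above (same hypothetical row-major layout as the earlier sketch; an illustration, not code from this PR). Token i (1-indexed) gets weight i / (1 + 2 + ... + n_tokens), so later tokens dominate and the weights sum to 1:

```rust
/// SGPT-style weighted mean pooling sketch: pool all token vectors,
/// weighting token i (1-indexed) by i / (1 + 2 + ... + n_tokens).
fn weighted_mean_pooling(hidden: &[f32], n_tokens: usize, n_embd: usize) -> Vec<f32> {
    assert_eq!(hidden.len(), n_tokens * n_embd);
    let weight_sum = (n_tokens * (n_tokens + 1) / 2) as f32;
    let mut pooled = vec![0.0f32; n_embd];
    for (i, token) in hidden.chunks_exact(n_embd).enumerate() {
        let w = (i + 1) as f32 / weight_sum; // later tokens weigh more
        for (p, h) in pooled.iter_mut().zip(token) {
            *p += w * *h;
        }
    }
    pooled
}
```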
This is because the value of "context_size" set in the model parameters was too large, so I rewrote the embeddings example. |
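To make that concrete, here is a self-contained sketch with a hypothetical ModelParameters struct (the field name comes from this thread; nothing here is verified against the llm crate). The point of the fix is simply to request a context no larger than the prompt needs:

```rust
// Hypothetical stand-in for the real model parameters; only the field
// discussed in this thread is modeled.
struct ModelParameters {
    /// Number of tokens the context buffers are sized for.
    context_size: usize,
}

fn main() {
    // An embeddings example feeds a single short prompt, so a modest
    // context suffices; per this thread, an oversized value caused the
    // failures seen with some models.
    let params = ModelParameters { context_size: 2048 };
    println!("context_size = {}", params.context_size);
}
```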
This might be a viable approach, but I am not sure whether a model with such a large number of parameters is necessary just to generate embeddings. |
o_O I feel kinda stupid for missing this. Good job, now everything works as expected 👍 |
It makes sense if you don't want to load a second, smaller model to perform the embedding task. But I'm going to create another issue for this; it's not part of this PR. |
Apply the fixed code for embeddings extraction to all models to avoid assertion errors. #288