Generation speed #15
Yes, I benchmarked it, and it is 14x slower than the Hugging Face implementation. However, the previous version (without the KV cache) is slow too. I think it is because the arrays are padded to a fixed length in my implementation. I will implement a new version without padding soon.
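For anyone curious, here is a minimal sketch of what that kind of fixed-length padding might look like in JAX (`MAX_LEN`, `pad_to_max`, and the shapes are illustrative assumptions, not code from this repo). The upside is that the jitted model always sees static shapes; the downside is that every step computes over all `MAX_LEN` positions even for short prompts:

```python
import jax.numpy as jnp

MAX_LEN = 2048  # assumed fixed padded length

def pad_to_max(tokens, pad_id=0):
    # Runs outside jit, where the dynamic length n is allowed.
    n = tokens.shape[0]
    padded = jnp.full((MAX_LEN,), pad_id, dtype=tokens.dtype).at[:n].set(tokens)
    mask = jnp.arange(MAX_LEN) < n  # True for real tokens, False for padding
    return padded, mask
```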
How could it be slow in a jitted function? I really want to know. I also read the scaling transformer inference paper from Google, but the implementation is rather hard to understand.
Same question! My JAX implementation of LLaMA is slower than PyTorch. Do you have any way to speed it up?
I gave up. I tried lots of things, but none of them worked except directly calling a CUDA function.
@sh0416 The Hugging Face implementation is not meant to be compiled. See huggingface/transformers#24587 (comment)
I mean the jitted implementation without padding in JAX.
@sh0416 Yes, they just don't jit it.
Do you think the PyTorch implementation would be faster if jit compilation were applied to it?
I don't know much about jit compilation in PyTorch. I am going to implement a new version of generation without padding. It will not be jitted, or only parts of it will be.
FYI, I did it. I don't know how you implemented it, but the core idea of my implementation is to create multiple jitted functions, one per sequence length that is a multiple of 64. It is an improvement, but still not as fast as the PyTorch version. Also, some weird implementation details lie in PyTorch: it does not OOM at batch sizes and sequence lengths that, in theory, should run out of memory.
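A minimal sketch of that bucketing idea, with a trivial stand-in for the real forward pass (all names here are illustrative): rounding each length up to the next multiple of 64 means `jax.jit` caches one compiled executable per bucket instead of recompiling for every distinct sequence length.

```python
import math
import jax
import jax.numpy as jnp

BUCKET = 64

@jax.jit
def forward(tokens, mask):
    # Stand-in for the real model forward pass; jit specializes it
    # once per distinct padded length it sees.
    return jnp.where(mask, tokens, 0).sum()

def bucketed_forward(tokens):
    n = tokens.shape[0]
    padded_len = BUCKET * math.ceil(n / BUCKET)  # round up to a bucket
    padded = jnp.zeros((padded_len,), tokens.dtype).at[:n].set(tokens)
    mask = jnp.arange(padded_len) < n  # True for real tokens
    return forward(padded, mask)  # retraces only on a new padded_len
```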
I am wondering whether handling the KV cache makes your code slow compared to Hugging Face generate.
I also implemented it similarly, but updating the KV cache is slow, so the generation speed is not as good as I expected.
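For reference, a minimal sketch of the kind of in-place cache update that tends to matter here (shapes and names are assumptions): writing into a preallocated buffer with `jax.lax.dynamic_update_slice` instead of concatenating, and donating the cache arguments so XLA can reuse their memory.

```python
from functools import partial
import jax

# Donating the cache buffers lets XLA update them in place on GPU/TPU
# instead of allocating a fresh copy of the whole cache every step.
@partial(jax.jit, donate_argnums=(0, 1))
def update_cache(k_cache, v_cache, k_new, v_new, pos):
    # k_cache, v_cache: (max_len, n_heads, head_dim)
    # k_new, v_new:     (1, n_heads, head_dim), written at position pos
    k_cache = jax.lax.dynamic_update_slice(k_cache, k_new, (pos, 0, 0))
    v_cache = jax.lax.dynamic_update_slice(v_cache, v_new, (pos, 0, 0))
    return k_cache, v_cache
```

Without donation, each step can allocate and copy a new cache buffer, which is one common reason the cache update ends up dominating generation time.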