
Generation speed #15

Open
sh0416 opened this issue Oct 18, 2023 · 10 comments


sh0416 commented Oct 18, 2023

I am wondering whether handling the KV cache makes your code slow compared to Hugging Face's generate.

I also implemented it similarly, but updating the KV cache is slow, so the generation speed is not as good as I expected.

ayaka14732 (Owner) commented

Yes, I benchmarked it and it is 14x slower than the Hugging Face implementation. However, the previous version (without the KV cache) is slow too. I think it is because the arrays are padded to a certain length in my implementation. I will implement a new version without padding soon.
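
For context, this is roughly the padded-cache pattern in question (a minimal sketch with made-up shapes, not the exact code in this repository): the K/V buffers are pre-allocated to a fixed maximum length so that shapes stay static under `jax.jit`, but attention then has to run over the full padded length at every decoding step.

```python
import jax
import jax.numpy as jnp

batch, max_len, n_heads, d_head = 1, 2048, 32, 128  # made-up sizes

def init_cache():
    # Pre-allocate one layer's K/V cache at the full max_len so the shape
    # never changes between decoding steps (a requirement for jit).
    shape = (batch, max_len, n_heads, d_head)
    return jnp.zeros(shape), jnp.zeros(shape)

@jax.jit
def update_cache(k_cache, v_cache, k_new, v_new, pos):
    # Write the current step's K/V (shape (batch, 1, n_heads, d_head))
    # at position `pos`; the buffer shape stays fixed, so this is
    # jit-compatible, but later attention still scans all max_len slots.
    k_cache = jax.lax.dynamic_update_slice(k_cache, k_new, (0, pos, 0, 0))
    v_cache = jax.lax.dynamic_update_slice(v_cache, v_new, (0, pos, 0, 0))
    return k_cache, v_cache
```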


sh0416 commented Oct 18, 2023

How could that be done in a jitted function?

I really want to know.

I also read the Scaling Transformer Inference paper from Google, but the implementation is rather hard to understand.

tututu05 commented

Same question! My JAX implementation of LLaMA is slower than the PyTorch one. Do you have any way to speed it up?


sh0416 commented Oct 23, 2023

I gave up. I tried lots of things, but none of them worked except directly calling a CUDA function.

ayaka14732 (Owner) commented

> How could that be done in a jitted function?

@sh0416 The Hugging Face implementation is not meant to be compiled. See huggingface/transformers#24587 (comment)


sh0416 commented Oct 23, 2023

I mean the jitted implementation without padding in JAX.

ayaka14732 (Owner) commented

@sh0416 Yes, they just don't jit it.


sh0416 commented Oct 23, 2023

Do you think the PyTorch implementation would be faster if JIT compilation were applied to it?

ayaka14732 (Owner) commented

I don't know much about JIT compilation in PyTorch. I am going to implement a new version of generation without padding. It will not be jitted, or only parts of it will be.


sh0416 commented Oct 23, 2023

FYI, I did it. I don't know how you implemented it, but the core idea of my implementation is to create multiple jitted functions for sequence lengths that are multiples of 64. It improved things, but it is still not as fast as the PyTorch version.
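
A rough sketch of that bucketing idea (a hypothetical illustration, not the exact implementation): every prompt is right-padded to the next multiple of 64, so `jax.jit` only ever compiles one executable per bucket size rather than one per prompt length.

```python
import jax
import jax.numpy as jnp

BUCKET = 64  # assumed bucket size, matching the multiples of 64 above

@jax.jit
def forward(tokens):
    # Stand-in for the model's forward pass. jax.jit compiles a separate
    # executable for each distinct input shape, so padding every prompt
    # to a bucket boundary bounds the number of compilations.
    return tokens.sum(axis=-1)

def run(tokens):
    # tokens: (batch, seq_len); pad the sequence axis up to the bucket size.
    n = tokens.shape[-1]
    padded_len = -(-n // BUCKET) * BUCKET  # round up to a multiple of 64
    tokens = jnp.pad(tokens, ((0, 0), (0, padded_len - n)))
    return forward(tokens)
```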

Also, there are some odd implementation details in PyTorch: it does not run out of memory at batch sizes and sequence lengths where, in theory, it should.
link: https://kipp.ly/transformer-inference-arithmetic/
