Problems sharding Llama-70B on TPU v3-32 #22

Open

divyapatel4 opened this issue Jan 5, 2024 · 1 comment

Comments

divyapatel4 commented Jan 5, 2024

Hi Ayaka,
We are currently using your Llama-70B implementation for generation on Cloud TPUs and have run into a few challenges we hope you can help with. When converting the model to the JAX format on the Cloud TPU hosts, the conversion process ran out of memory. We managed to complete it by adding 400 GB of swap on an attached SSD, and we now attach a disk containing the pre-converted model to all hosts of the TPU v3-32 slice in read-only mode.
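For what it's worth, one way to avoid the swap workaround is to convert the checkpoint one shard at a time, so only a single shard is ever resident in RAM. A rough sketch, assuming the original weights are stored as Hugging Face sharded pytorch_model-*.bin files; the paths and output layout below are illustrative, not the repository's actual conversion script:

```python
# Hypothetical shard-by-shard conversion sketch (NOT the repo's convert script).
# Assumes Hugging Face sharded pytorch_model-*.bin inputs; paths are illustrative.
import gc
import glob
import os

import numpy as np
import torch

SRC_DIR = "llama-70b-hf"   # assumed input directory
DST_DIR = "llama-70b-np"   # assumed output directory for .npy arrays
os.makedirs(DST_DIR, exist_ok=True)

for shard_path in sorted(glob.glob(f"{SRC_DIR}/pytorch_model-*.bin")):
    # Load one checkpoint shard at a time on CPU so peak RAM stays near the
    # size of a single shard instead of the whole ~140 GB model.
    shard = torch.load(shard_path, map_location="cpu")
    for name, tensor in shard.items():
        # Convert to float16 NumPy and write each tensor to its own file;
        # the JAX side can later load these without touching PyTorch.
        np.save(f"{DST_DIR}/{name}.npy", tensor.to(torch.float16).numpy())
    # Free the shard before loading the next one.
    del shard
    gc.collect()
```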

When we tried to shard the 70B model across the TPUs, we ran out of TPU HBM. We also noticed that with smaller models such as Llama-13B, redundant responses were generated by all four hosts in the TPU v3-32 slice. We would greatly appreciate any guidance on generating with Llama-70B on TPU v3-32, or on alternative methods for generation on a single-host TPU v3-8.
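For context, a minimal sketch of laying out parameters across all 32 cores of a v3-32 with jax.sharding, and of gating output on the process index so only one host prints. The mesh shape, axis names, and example matrix are assumptions for illustration, not the repository's actual partitioning:

```python
# Multi-host sharding sketch for a TPU v3-32 (4 hosts x 8 cores), assuming
# a recent jax 0.4.x release; not the repository's actual scheme.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Every host in the slice runs this same script; this call lets the hosts
# discover each other so jax.devices() returns all 32 cores.
jax.distributed.initialize()

devices = np.array(jax.devices()).reshape(4, 8)   # 4 hosts x 8 cores
mesh = Mesh(devices, axis_names=("dp", "mp"))

# Example: shard a weight matrix over the model-parallel axis so each core
# holds 1/8 of the columns. For a 70B model, every large parameter needs a
# sharding like this, because a fully replicated copy alone would exceed
# the 16 GB of HBM available per v3 core.
global_w = np.zeros((8192, 8192), dtype=np.float16)
w = jax.make_array_from_callback(
    global_w.shape,
    NamedSharding(mesh, P(None, "mp")),
    lambda index: global_w[index],
)

# All four hosts execute the same generation loop, so each host prints its
# own copy of the output; gating side effects on the process index removes
# the "redundant responses" seen with the smaller models.
if jax.process_index() == 0:
    print("global shape:", w.shape, "sharding:", w.sharding)
```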

We would like to express our gratitude for your exceptional repository. It has significantly accelerated our research. The speed of generation on these TPUs using your implementation, compared to GPUs, is remarkable! We plan to acknowledge your valuable contribution in our upcoming paper. Thank you once again for your outstanding work.

ayaka14732 (Owner) commented

Hi @divyapatel4, I am busy with other matters in January, so I may have little time to look into this issue. Have you tried the new Llama JAX implementation in the Hugging Face transformers library, and does that work for you?

Thank you for supporting this library!
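A minimal sketch of the Hugging Face path mentioned above, assuming a recent transformers release that ships FlaxLlamaForCausalLM; the 7B checkpoint is used only for illustration, since an unsharded 70B would not fit in a single host's HBM:

```python
# Sketch of generation with the transformers Flax Llama implementation,
# assuming FlaxLlamaForCausalLM is available; checkpoint name is illustrative.
from transformers import AutoTokenizer, FlaxLlamaForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
# from_pt=True converts the published PyTorch weights to Flax on the fly,
# since Flax weights may not be available for this checkpoint.
model = FlaxLlamaForCausalLM.from_pretrained(model_id, from_pt=True)

inputs = tokenizer("The capital of France is", return_tensors="np")
out = model.generate(inputs["input_ids"], max_length=32)
print(tokenizer.batch_decode(out.sequences, skip_special_tokens=True))
```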
