How to infer with GPT-J on TPU_driver0.2 or nightly? #256

Closed
mosmos6 opened this issue Mar 17, 2023 · 1 comment

Comments

mosmos6 commented Mar 17, 2023

Following issue 252 and this, the GPT-J model is effectively no longer usable on Colab TPU with TPU_driver0.1. However, by default GPT-J crashes on the other drivers, including 0.2 and nightly.
Is there any way to use the 0.2 or nightly driver? If not, GPT-J inference on TPU is effectively over.

@mosmos6 mosmos6 changed the title How to infer GPT-J on TPU_driver0.2 or nightly? How to infer with GPT-J on TPU_driver0.2 or nightly? Mar 17, 2023
mosmos6 commented Apr 13, 2023

I resolved this on my own, so I'm sharing the modified mesh-transformer-jax with everyone.

Background;
In early March 2023, Google removed TPU_driver0.1 from Colab. The original GPT-J strictly requires JAX 0.2.12, so it could no longer be used for inference on Colab, because TPU_driver0.2 needs a newer JAX.
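
For readers unfamiliar with how a Colab notebook picks a TPU driver, here is a minimal sketch. It assumes the `requestversion` endpoint on port 8475 that Colab TPU notebooks of that period used. The function names are hypothetical, and the exact driver version string is not given in this thread, so it is left as a parameter.

```python
import os


def driver_request_url(colab_tpu_addr: str, driver_version: str) -> str:
    """Build the URL that asks the Colab TPU host to load a driver version.

    colab_tpu_addr is the value of COLAB_TPU_ADDR ("host:port"); only the
    host part is used. Port 8475 and the path are assumptions based on the
    Colab TPU setup of that period.
    """
    host = colab_tpu_addr.split(":")[0]
    return f"http://{host}:8475/requestversion/{driver_version}"


def request_driver(driver_version: str) -> str:
    """Send the version request. Only works inside a Colab TPU runtime."""
    import requests  # imported lazily so the URL helper has no dependency

    url = driver_request_url(os.environ["COLAB_TPU_ADDR"], driver_version)
    requests.post(url)
    return url
```

Inside a notebook you would call `request_driver(...)` with the driver string used in the shared notebook before initializing JAX.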

Takeaway;
I modified the mesh_transformer folder and the Colab demo, and updated the requirements, so you can run inference with it on Colab again. You can continue to use the same (slim) weights as before.

How;
I uploaded the relevant file and folder here. Extract the mesh_transformer folder and the requirement.txt file, use them to replace the originals in your own repo, and run inference with GPT-J inference on TPU_driver0.2.ipynb.
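
The replacement step can be sketched as follows. This is only an illustration: the function name and both directory arguments are hypothetical, and the file name `requirement.txt` is taken from the post as written, so adjust it if your copy is named differently.

```python
import shutil
from pathlib import Path


def replace_with_modified(extracted: Path, repo: Path,
                          folder: str = "mesh_transformer",
                          req_file: str = "requirement.txt") -> None:
    """Overwrite the repo's folder and requirements file with the modified ones."""
    target_dir = repo / folder
    if target_dir.exists():
        shutil.rmtree(target_dir)  # remove the original folder first
    shutil.copytree(extracted / folder, target_dir)
    shutil.copy2(extracted / req_file, repo / req_file)
```

Removing the old folder before copying avoids leaving stale original modules next to the modified ones.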

Important notes;

  1. Sorry, you'll need a Colab Pro or Pro+ subscription, because this requires the high-memory TPU runtime.

  2. I have not tested it for fine-tuning on a TPU VM yet, so it may cause errors during the process. I'm planning to cover that next month. Until then, you may need to modify xmap further yourself or downgrade to jax 0.2.18 or 0.2.20.

  3. You can also run GPT-J inference with device_serve.py on a TPU VM, but you can't use the original file. If anyone wants the modified version, please request it by opening an issue.

[Image: Screenshot 2023-04-13 122413]
