Following issue #252 and this one, the GPT-J model has effectively become unavailable on Colab TPU now that TPU_driver0.1 is gone. By default, GPT-J crashes on the other drivers, including 0.2 and nightly.
Is there any way to use the 0.2 or nightly driver? Otherwise, GPT-J inference on TPU is finished.
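For context, Colab TPU notebooks (including the stock mesh-transformer-jax demo) typically select a TPU driver by POSTing to the runtime's `requestversion` endpoint. A minimal sketch, assuming the standard `COLAB_TPU_ADDR` environment variable and the usual port 8475; the driver name strings are examples, and which versions Google actually serves can change:

```python
def driver_url(tpu_addr: str, driver: str) -> str:
    """Build the requestversion URL for a given TPU driver name.

    tpu_addr is the value of COLAB_TPU_ADDR, e.g. "10.0.0.1:8470";
    the requestversion service listens on port 8475 of the same host.
    """
    host = tpu_addr.split(":")[0]
    return f"http://{host}:8475/requestversion/{driver}"

# On a Colab TPU runtime you would then do, e.g.:
#   import os, requests
#   requests.post(driver_url(os.environ["COLAB_TPU_ADDR"], "tpu_driver0.2"))
# or pass "tpu_driver_nightly" for the nightly driver.
```

Whether a given driver actually works still depends on the JAX version installed in the runtime, which is the crux of this issue.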
mosmos6 changed the title from "How to infer GPT-J on TPU_driver0.2 or nightly?" to "How to infer with GPT-J on TPU_driver0.2 or nightly?" on Mar 17, 2023
I resolved this on my own, so I'm sharing the modified mesh-transformer-jax with everyone.
Background:
In early March 2023, Google removed TPU_driver0.1 from Colab. The original GPT-J strictly requires JAX 0.2.12, while TPU_driver0.2 needs a newer JAX, so the model could no longer be used for inference on Colab.
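The conflict can be checked at runtime before loading the model. A minimal sketch: the 0.2.12 pin comes from the original requirements, while the parsing helper and the 0.3.x example are purely illustrative:

```python
def version_tuple(v: str) -> tuple:
    """Parse a dotted version string like '0.2.12' into a comparable tuple."""
    return tuple(int(part) for part in v.split("."))

PINNED = "0.2.12"  # the JAX version the original GPT-J code requires

def newer_than_pin(installed: str) -> bool:
    """True if the installed JAX is newer than the original pin --
    the combination that TPU_driver0.2 forces on Colab."""
    return version_tuple(installed) > version_tuple(PINNED)

# On Colab you would check the live runtime, e.g.:
#   import jax
#   newer_than_pin(jax.__version__)
```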
Takeaway:
I modified the mesh_transformer folder and the Colab demo, together with updated requirements, so you can now run inference on Colab. You can continue to use the same (slim) weights as before.
How:
I uploaded the relevant file and folder here. Extract the mesh_transformer folder and the requirements.txt file, replace the originals in your own repo, and use GPT-J inference on TPU_driver0.2.ipynb to run inference.
Important notes:
Sorry, you'll need a Colab Pro or Pro+ subscription, because this requires a high-memory TPU runtime.
I have not yet tested fine-tuning on a TPU VM; it may raise errors partway through. I plan to cover it next month. Until then, you may need to modify the xmap calls yourself or downgrade to JAX 0.2.18 or 0.2.20.
You can also run GPT-J inference with device_serve.py on a TPU VM, but you can't use the original file. If anyone wants it, please open an issue to request it.