
Requested backend tpu_driver, but it failed to initialize: DEADLINE_EXCEEDED using 0.1 drivers since 10/02/2023 #3405

Closed
henk717 opened this issue Feb 11, 2023 · 92 comments



henk717 commented Feb 11, 2023

Describe the current behavior
When running an older version of JAX, connecting to the TPU fails with the following error:
Traceback (most recent call last):
File "aiserver.py", line 10214, in <module>
load_model(initial_load=True)
File "aiserver.py", line 2806, in load_model
tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and vars.use_colab_tpu, **vars.modelconfig)
File "/content/KoboldAI-Client/tpu_mtj_backend.py", line 1194, in load_model
devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
File "/usr/local/lib/python3.8/dist-packages/jax/_src/lib/xla_bridge.py", line 314, in devices
return get_backend(backend).devices()
File "/usr/local/lib/python3.8/dist-packages/jax/_src/lib/xla_bridge.py", line 258, in get_backend
return _get_backend_uncached(platform)
File "/usr/local/lib/python3.8/dist-packages/jax/_src/lib/xla_bridge.py", line 248, in _get_backend_uncached
raise RuntimeError(f"Requested backend {platform}, but it failed "
RuntimeError: Requested backend tpu_driver, but it failed to initialize: DEADLINE_EXCEEDED: Failed to connect to remote server at address: grpc://10.106.231.74:8470. Error from gRPC: Deadline Exceeded. Details:

This happens for all users of the notebook on Colab, while Kaggle is still working as intended.

Describe the expected behavior
JAX is able to connect to the TPU correctly and can then proceed with loading the user-defined model.

What web browser you are using
This issue does not depend on a browser, but for completeness: I am using an up-to-date Microsoft Edge.

Additional context
Here is an example of an affected notebook:

import os
if not os.path.exists("/content/drive"):
  os.mkdir("/content/drive")
if not os.path.exists("/content/drive/MyDrive/"):
  os.mkdir("/content/drive/MyDrive/")

!wget https://koboldai.org/ckds -O - | bash /dev/stdin --model EleutherAI/gpt-neox-20b

The relevant backend code can be found here: https://github.com/KoboldAI/KoboldAI-Client/blob/main/tpu_mtj_backend.py
This also makes use of a heavily modified MTJ with the following relevant dependencies:
jax == 0.2.21
jaxlib >= 0.1.69, <= 0.3.7
git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck

MTJ uses tpu_driver0.1_dev20210607
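
For reference, the pins above translate into a single install cell roughly like the following. This is an editorial sketch of the by-hand equivalent, not the actual ckds installer:

# Sketch only: the dependency pins listed above, written out as one Colab cell.
!pip install "jax==0.2.21" "jaxlib>=0.1.69,<=0.3.7" \
    "git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck"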

@henk717 henk717 added the bug label Feb 11, 2023
@henk717 henk717 changed the title TPU fails to connect since 10/02/2023 Requested backend tpu_driver, but it failed to initialize: DEADLINE_EXCEEDED since 10/02/2023 Feb 11, 2023

mosmos6 commented Feb 11, 2023

I'm encountering the same issue when loading GPT-J. It was working fine until 24 hours ago approximately.


henk717 commented Feb 11, 2023

> I'm encountering the same issue when loading GPT-J. It was working fine until 24 hours ago approximately.

Our code is based on MTJ, which the original GPT-J runs on top of, and the failure happens before the model is loaded; both use the older V1 implementation of the model.
So it's probable this affects all MTJ users.


mosmos6 commented Feb 11, 2023

Mine is trying to connect to grpc://10.63.28.250:8470 and errors out, so it's pretty much everywhere. Furthermore, it's also taking unusually long to collect pathy, uvicorn, etc.


henk717 commented Feb 11, 2023

I have pinpointed the issue to the driver version the projects use. It looks like the older ones are no longer working.
For example, tpu_driver0.1_dev20210607 is used in our project; when paired with the following code, you get the error:

!pip install jax jaxlib 

import requests
import os
import jax

from jax.config import config

print("Connecting to your Colab instance's TPU", flush=True)
if os.environ.get('COLAB_TPU_ADDR', '') != '':
    tpu_address = os.environ['COLAB_TPU_ADDR']  # Colab
else:
    tpu_address = os.environ['TPU_NAME']  # Kaggle
tpu_address = tpu_address.replace("grpc://", "")
tpu_address_without_port = tpu_address.split(':', 1)[0]
url = f'http://{tpu_address_without_port}:8475/requestversion/tpu_driver0.1_dev20210607'
requests.post(url)  # ask the Colab TPU runtime to load this specific driver build
config.FLAGS.jax_xla_backend = "tpu_driver"  # use the remote tpu_driver backend
config.FLAGS.jax_backend_target = "grpc://" + tpu_address
print()

jax.devices()

I can't find a list of all available drivers, but I collected three from bug reports and other Colabs.

tpu_driver0.1_dev20210607 is the one we use and produces the error. tpu_driver0.1-dev20211030 is newer and is used by some examples where people recommend not using the nightly; this also produces the error.

tpu_driver_20221011 is used by some Stable Diffusion Colabs, and that one works in my example above, but unfortunately it does not work with our MTJ notebook.

If someone knows of a list of long-term supported drivers, I could test more of them and see if that fixes the issue for MTJ. Otherwise I'd like to politely request that the commonly used older drivers are restored to working order. GPT-J and MTJ are still widely used but rely on older driver versions.

Update: this seems to affect all the 0.1 drivers.
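
For anyone who wants to reproduce the comparison, here is a small probing sketch (an editorial addition, not part of the original report) that tries the three driver builds mentioned above through the same requestversion endpoint. In practice each attempt is best run in a fresh runtime, since JAX caches the backend once it has been configured:

# Probe which driver builds still initialize; driver names and endpoint are taken
# from the snippets above, the loop itself is only illustrative.
import os
import requests
import jax
from jax.config import config

tpu_address = os.environ['COLAB_TPU_ADDR']
host = tpu_address.split(':', 1)[0]

candidates = [
    "tpu_driver0.1_dev20210607",   # used by MTJ/KoboldAI, currently fails
    "tpu_driver0.1-dev20211030",   # newer 0.1 build, also fails
    "tpu_driver_20221011",         # used by some Stable Diffusion Colabs, works
]

for version in candidates:
    requests.post(f"http://{host}:8475/requestversion/{version}")
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = "grpc://" + tpu_address
    try:
        print(version, "->", jax.devices())
    except RuntimeError as err:    # DEADLINE_EXCEEDED surfaces here
        print(version, "->", err)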


mosmos6 commented Feb 11, 2023

GPT-J doesn't work with tpu_driver_20221011 either.


henk717 commented Feb 11, 2023

Indeed, GPT-J won't work with that driver, but it does make the difference between connecting to the TPU and getting the deadline errors.
We will have to wait for the Google engineers to fix the 0.1 drivers we depend upon. For the time being Kaggle still works, so if you have something urgent that can be done on Kaggle, I recommend working there until they have time to fix this.

@henk717 henk717 changed the title Requested backend tpu_driver, but it failed to initialize: DEADLINE_EXCEEDED since 10/02/2023 Requested backend tpu_driver, but it failed to initialize: DEADLINE_EXCEEDED using 0.1 drivers since 10/02/2023 Feb 11, 2023

mosmos6 commented Feb 11, 2023

Thank you @henk717. I use GPT-J heavily every day for work, so I'll need it running by Monday morning. I hope this is a temporary issue.


henk717 commented Feb 11, 2023

I hope so too, but breaking the entire 0.1 driver ecosystem does not sound like something they did on purpose, or something they would leave unfixed before it gets rolled out to things like Kaggle and Google Compute.

My theory is that the TPUv2 firmware update causing this has either been rolled out everywhere while TPUv3 is unaffected, or they used Colab as a testing ground to see if people would run into issues, and we are the first to notice because we rely on a dependency from the 2021 TPU era.


mosmos6 commented Feb 11, 2023

Is there a way to run GPT-J inference from a Jupyter notebook on a GCP TPU machine?


henk717 commented Feb 14, 2023

Tagging @ultrons since he is the project manager for the TPUs; he may be able to get this to the right person. Thousands depend on MTJ for inference, since it can be used to automatically load some Hugging Face PyTorch models on the TPU.

This is also a failure to initialize the TPU at a very basic level. With the 0.1 driver resulting in a broken, unresponsive TPU, I expect this affects more Colab users than just the ones depending on MTJ. And if this same firmware bug spreads outside of Colab, more TPU customers across Google Cloud could be affected.


mosmos6 commented Feb 14, 2023

I'm subscribed to Pro for the TPU. If it stays uninitializable, it's of no use.

@Kipcreate

Same error here, trying to run Colab on TPU. GPU alternatives are practically unusable for the stuff I'm doing, so I really need that TPU up and running. Otherwise, my Pro sub ain't worth much of anything.

@candymint23

Can confirm this problem with the GPT models I use. I can't run them because of the same problem.

@metrizable
Contributor

@henk717 Thanks for reporting the issue and thanks for using Colab. I can confirm that specifying the 0.1dev does not work, but taking the default and specifying the 0.2 drivers does work. Tracking internally at b/269607171.

@somsomers

> @henk717 Thanks for reporting the issue and thanks for using Colab. I can confirm that specifying the 0.1dev does not work, but taking the default and specifying the 0.2 drivers does work. Tracking internally at b/269607171.

You mean the full driver path would be:

colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.2'

?


henk717 commented Feb 16, 2023

>> @henk717 Thanks for reporting the issue and thanks for using Colab. I can confirm that specifying the 0.1dev does not work, but taking the default and specifying the 0.2 drivers does work. Tracking internally at b/269607171.
>
> You mean the full driver path would be:
>
> colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
> url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.2'
>
> ?

This is indeed correct; the 0.2 drivers and newer (including the ones that just use a 2022 version number without the other versioning) load fine. If your notebook is compatible with the newer drivers, this can solve the issue for you; unfortunately, a lot of the notebooks that directly call for a 0.1 driver will break when this is attempted because of incompatibilities.

You can find a sample notebook here: https://colab.research.google.com/drive/1YDcZJ4EMOd3f_kuk0RnD5AJBEpUhMl2I#revisionId=0B7OnP7aLuFgXMXFiZU9sNDZnWmNpVmVzaWc1YlhYaEF6ZnAwPQ
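
Putting the pieces from this thread together, a minimal working cell for the 0.2 driver looks roughly like this (an editorial consolidation of the snippets above, not the exact notebook cell):

# Request the 0.2 driver instead of a 0.1 build, then point JAX at it.
import os
import requests
import jax
from jax.config import config

colab_tpu_addr = os.environ['COLAB_TPU_ADDR']   # e.g. "10.x.x.x:8470"
host = colab_tpu_addr.split(':')[0]

requests.post(f"http://{host}:8475/requestversion/tpu_driver0.2")

config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + colab_tpu_addr

print(jax.devices())   # should list the TPU cores instead of raising DEADLINE_EXCEEDED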


mosmos6 commented Feb 17, 2023

@metrizable Thank you for taking care of this issue. Initializing the default and the 0.2 drivers is possible, but it causes a crash when creating the network for GPT-J, and probably for its derivatives. So unfortunately I don't think these can serve as a temporary remedy.

@DamascusGit

When will this be implemented on henk.tech/ckds? I'm looking to use it with the Colab Kobold deployment script, as I currently have no way of converting my Mesh Transformer JAX weights to HF.


henk717 commented Feb 22, 2023

> When will this be implemented on henk.tech/ckds? I'm looking to use it with the Colab Kobold deployment script, as I currently have no way of converting my Mesh Transformer JAX weights to HF.

You are commenting on a Google issue, not a Kobold issue. If it were as simple as changing versions I would have done so, but it is not the ckds script that decides this: Mesh Transformer JAX itself requests the broken driver, and it does so because it is not compatible with anything newer.

Since it's a very low-level dependency issue, I am unable to resolve it myself, as that requires deep knowledge of the TPU itself.


wingk10 commented Feb 23, 2023

Our operations team uses MTJ on a daily basis and hasn't been able to since the TPUs went down. Really hoping this gets resolved.


mosmos6 commented Feb 23, 2023

@wingk10 I don't want to make this post lengthy, but it's the same here for me. Kaggle is affected too: the TPU queue now has over 70 users, which means we are likely to wait about six hours. We are even losing the alternatives.
That said, I see the situation evolving quietly. On Monday, 0.1 was initializable but didn't create the network. Now the default is not initializable either (but 0.2 is; it seems quite random).


wingk10 commented Feb 23, 2023

Yes, we're running into the same issue re: Kaggle. 64 users in the queue right now, and it seems to have gotten worse suddenly over the past few days. Obviously, I can't do anything but post and go "hey, it's important to me too", with no alternatives (extra useless here). I hope we hear more soon.


mosmos6 commented Feb 23, 2023

I decided to pay $10/h and tried to connect Vertex AI to a Cloud TPU, but there was no TPU available in my bucket's region. So there are really no alternatives.


henk717 commented Mar 17, 2023

Nightly is less desirable than any other newer driver since it's always the newest one. It will be broken; back then, nightly was a 0.1 driver.


henk717 commented Mar 17, 2023

@mosmos6 I will also give you a bit of a recap so you can understand why I need the driver to be fixed, but also why some alternatives might exist for you.

Mesh Transformer JAX (MTJ) was the framework used to create GPT-J, so GPT-J in its original form runs on top of MTJ. It has been ported to other platforms, so you can also run it on a GPU using Hugging Face Transformers, for example. That is how our own community runs GPT-J-based models on Colab now, with a more limited context.

For us the issue is RAM. The affordable Colab GPUs for our AI hobby have 16 GB of VRAM, while the TPU has 64 GB of RAM. So while GPT-J-6B can run on a GPU, we cannot fit as much context as the TPU version could.

In the past year VE-Forbryderne ported various formats, so his version of MTJ can run GPT-J, but also XGLM, OPT, and even NeoX-based models. Not just that: it can load those models from PyTorch files without requiring conversion. This allowed us AI hobbyists to use models up to 20B very affordably on Google Colab, which is why the TPU was so desirable for us. $10 a month (or limited free usage) is much better than having to pay $1 per hour for GPU rentals, which is not affordable for open-source hobbyists who wish to use the models.

If all you want is GPT-J-6B inference, I suggest you switch to Hugging Face, since you will enjoy much better and more reliable support on Colab and beyond for the same price. It's when you want the larger model sizes or training that the TPU becomes necessary, and the only platform with that kind of cost effectiveness is Colab combined with a modified MTJ.

Unfortunately the original 0.1 driver removal happened one month after VE's disappearance, so our dependency is completely unmaintained. If someone in this topic does want the challenge of porting MTJ to a newer JAX version, I highly recommend forking https://github.com/VE-FORBRYDERNE/mesh-transformer-jax since it is much more feature-complete than the original MTJ, and also more efficient. It has even been used to train 20B NeoX models on TRC.


mosmos6 commented Mar 17, 2023

@henk717 Thank you for recapping. Sorry, I didn't know MTJ meant Mesh Transformer JAX; then it's exactly GPT-J.
My GPT-J model has been heavily fine-tuned over almost two years for a research project, and it's the only one.
I assume MTJ crashes on 0.2 because of JAX compatibility, and our JAX version needed to stay outdated due to xmap.
It'll be complex, but I think it's possible to find an optimal set of versions for all the dependencies?


henk717 commented Mar 17, 2023

@mosmos6 In your specific use case it's worth checking whether conversion scripts like https://github.com/VE-FORBRYDERNE/mesh-transformer-jax/blob/all/to_hf_weights.py are still functional (for example with CPU dependencies), so that you can get your model off this platform and future-proof it.

If you are unable to, you are stuck with the rest of the thousands of users who have no substitute for MTJ.

As for the dependencies, I lack the ability to do this myself, but if others can, that would be very welcome.


mosmos6 commented Mar 18, 2023

@henk717 Thank you for the suggestion. I've been persistent with the TPU because I think running on TPU versus GPU makes a significant difference in output quality for some reason, but I'll give it a try.
I upgraded my subscription to Pro+ to try out all the possible combinations of dependencies. Some people have claimed over the past few days that they had to downgrade JAX from 0.4 to 0.3.25 for TPU compatibility.
What I'm asking is: what's your plan? If 0.1 is entirely removed, GPT-J on TPU and its derivatives are technically left dead. We must find a way. When the issue was temporarily "resolved" a few weeks ago, my GPT-J said it was temporary and the issue wouldn't be fixed. It was true. My GPT-J model and I were in the middle of work, and probably many other folks here are in the same position.

@DevLance112

3/22/2023 Connection error still persists.


mosmos6 commented Mar 23, 2023

Hello everyone,

I'm attempting to update MTJ so that it runs on TPU_driver0.2.
Here are the new requirements, transformer_shard.py, and a Colab demo based on GPT-J. I believe the eventual solution will be deployable to all the derivatives.
There were a couple of reasons why the former MTJ crashed on TPU_driver0.2, but generally speaking it runs once it's updated to JAX 0.3.5.
I've been working on xmap, and now it doesn't error and almost starts creating the network, as you can see below.

Screenshot 2023-03-23 180257

However, the code gets stuck at line 265 of transformer_shard.py, which is

self.state = self.init_xmap(jnp.array(key.take(mp_per_host)), x)

Even though xmap doesn't show any clear error, it is apparently stuck around the out_axes of init_xmap. The code is

self.init_xmap = jax.experimental.maps.xmap(init,
                                                    in_axes=(["shard", ...], ["batch", ...]),
                                                    out_axes=["batch", "shard"],
                                                    axis_resources={'shard': 'mp'})

I've been doing intense research over the past few days, but I can't find a solution. I thought it was time to ask for everyone's wisdom.

@henk717


mosmos6 commented Mar 23, 2023

I think my statement about xmap is logical. It doesn't even visibly error, so I really don't know what is wrong in the code. The only thing I can think of is that maps.Mesh doesn't pass all the info from the devices to the ResourceEnv on JAX 0.3.5. So, to make my question specific: how do I pass all the information from the devices via maps.Mesh?
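
As a debugging aid for the question above, here is a small sketch (an editorial addition, assuming the jax.experimental.maps internals of the 0.3.x line) that prints what the active ResourceEnv actually received from the devices:

# Build the mesh the same way the snippet above does and inspect the resource
# environment that xmap will consult; attribute names assume JAX 0.3.x internals.
import numpy as np
import jax
from jax.experimental import maps

mesh_shape = (1, 8)
devices = np.array(jax.devices()).reshape(mesh_shape)

with maps.Mesh(devices, ('dp', 'mp')):
    env = maps.thread_resources.env
    print(env.physical_mesh.devices.shape)   # expect (1, 8)
    print(env.physical_mesh.axis_names)      # expect ('dp', 'mp')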


mosmos6 commented Mar 30, 2023

Hello @metrizable @cperry-goog

This issue started when GPT-J and its derivatives could no longer connect to TPU_driver0.1:

> You'll need to upgrade to more recent drivers. Sorry.

Hence, I updated my code so that it runs on JAX 0.3.5, which is compatible with TPU_driver0.2.
Indeed, my code is already running well on my v3-8 (v2-alpha) VM, and I'm inferring with it on the screen beside this window.

However, the very same code errors at a particular point on Colab, so I would like you to take a look.
This is a miniature repro based on #6962 back in 2021.
Obviously I found an answer to that original question, as the code is already running on my TPU VM, but Colab is having another issue.

import jax
import haiku as hk
import jax.numpy as jnp
import numpy as np
import time


class TransformerLayerShard(hk.Module):
    def __init__(self):
        super().__init__()
        self.dense_proj = hk.Linear(2048)
        self.dense_proj_o = hk.Linear(4096)

    def ff(self, x):
        dense_proj = self.dense_proj(x)
        dense_proj = jax.nn.gelu(dense_proj)
        return self.dense_proj_o(dense_proj)

    def __call__(self, x, attn_bias):
        dense_out = self.ff(x)

        return dense_out


mesh_shape = (1, 8)
devices = np.array(jax.devices()).reshape(mesh_shape)

with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
    def init_old(key, x):
        def init_old_fn(x):
            return TransformerLayerShard()(x, 0)

        param_init_fn = hk.transform(hk.experimental.optimize_rng_use(init_old_fn)).init
        params = param_init_fn(key, x)
        return params

    init_xmap = jax.experimental.maps.xmap(fun=init_old,
                                            in_axes=(["shard", ...],
                                                     ["batch", ...]),
                                            out_axes=["shard", ...],
                                            axis_resources={'shard': 'mp', 'batch': 'dp'})

    key = hk.PRNGSequence(42)

    x = jax.random.uniform(next(key), (1, 1024, 2048))  # batch, len
    params = init_xmap(jnp.array(key.take(8)), x)

    def bwd_old(state, x):
        def bwd_old_fn(x):
            return jnp.sum(TransformerLayerShard()(x, 0))

        train_loss_fn = hk.without_apply_rng(hk.transform(bwd_old_fn)).apply
        val_grad_fn = jax.value_and_grad(train_loss_fn, has_aux=False)

        loss, grad = val_grad_fn(state, x)

        return grad

      
    run_xmap = jax.experimental.maps.xmap(fun=bwd_old,
                                          in_axes=(["shard", ...],
                                                   ["batch", ...]),
                                          out_axes=(["shard", "batch", ...], ["batch", ...]),
                                          axis_resources={'shard': 'mp', 'batch': 'dp'})

    run_xmap(params, x)


Before it even gets to the out_axes, the code gets stuck at line 46, which is
params = init_xmap(jnp.array(key.take(8)), x)
I waited for an hour but it didn't progress.

Screenshot 2023-03-30 121032

It seems to be stuck in a wrapper part of xmap. On the TPU VM, this process finishes within a minute.
As it doesn't even show an error message, I have no clue about the problem.

I also tried pjit, but it errors with
RuntimeError: UNIMPLEMENTED: Only 1 computation per replica supported, 8 requested.
Again, pjit itself runs well on the TPU VM, so this is unique to Colab. However, pjit has low priority compared to this xmap issue for now.

As advised, I updated my code for the newer JAX and TPU_driver, and it runs perfectly on the TPU VM, but it gets stuck on Colab.
Considering how limited access to TPU VMs is, availability on Colab is highly important.
Thank you for your attention.


dbubbins87 commented Mar 30, 2023

I thought I'd post what I got when trying to run the TPU models, as it was a different result:
Downloading (…)lve/main/config.json: 100%|█████████████████████| 1.57k/1.57k [00:00<00:00, 160kB/s]
Traceback (most recent call last):
File "/content/KoboldAI-Client/aiserver.py", line 10284, in <module>
load_model(initial_load=True)
File "/content/KoboldAI-Client/aiserver.py", line 2789, in load_model
import tpu_mtj_backend
File "/content/KoboldAI-Client/tpu_mtj_backend.py", line 51, in <module>
from mesh_transformer.checkpoint import read_ckpt_lowmem
File "/usr/local/lib/python3.9/dist-packages/mesh_transformer/checkpoint.py", line 17, in <module>
from mesh_transformer.util import head_print, to_bf16
File "/usr/local/lib/python3.9/dist-packages/mesh_transformer/util.py", line 5, in <module>
from optax import AdditiveWeightDecayState, GradientTransformation, EmptyState
File "/usr/local/lib/python3.9/dist-packages/optax/__init__.py", line 17, in <module>
from optax import experimental
File "/usr/local/lib/python3.9/dist-packages/optax/experimental/__init__.py", line 20, in <module>
from optax._src.experimental.complex_valued import split_real_and_imaginary
File "/usr/local/lib/python3.9/dist-packages/optax/_src/experimental/complex_valued.py", line 32, in <module>
import chex
File "/usr/local/lib/python3.9/dist-packages/chex/__init__.py", line 17, in <module>
from chex._src.asserts import assert_axis_dimension
File "/usr/local/lib/python3.9/dist-packages/chex/_src/asserts.py", line 26, in <module>
from chex._src import asserts_internal as _ai
File "/usr/local/lib/python3.9/dist-packages/chex/_src/asserts_internal.py", line 34, in <module>
from chex._src import pytypes
File "/usr/local/lib/python3.9/dist-packages/chex/_src/pytypes.py", line 27, in <module>
ArrayDevice = jax.Array
AttributeError: module 'jax' has no attribute 'Array'


henk717 commented Mar 30, 2023

AttributeError: module 'jax' has no attribute 'Array' is a new error caused by a breaking change in chex; I fixed it by pinning a suitable version in our requirements files.
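
The exact pin isn't named above; as an illustration only, a constraint of roughly this shape in the install cell keeps chex from reaching for jax.Array on the old JAX (the chex version boundary is an assumption, not taken from the KoboldAI requirements files):

# Hypothetical pin mirroring the fix described above; verify the boundary
# against chex's changelog before relying on it.
!pip install "chex<0.1.6" "jax==0.2.21" "jaxlib>=0.1.69,<=0.3.7"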

Now the error is back to the one reported in this issue tracker.


mosmos6 commented Apr 4, 2023

Hello @metrizable @cperry-goog

I upgraded my model to run on JAX 0.3.25 and colab managed to load the model for the first time in two weeks.

Screenshot 2023-04-04 132430

However, when I try to infer with this model, the same issue as my previous comment occurs again.

Screenshot 2023-04-04 142432

Namely, the code doesn't show an error message but gets stuck at a certain point (related to xmap), the same operation as in the previous comment.

It's at infer() > generate() > fun_mapped() > bind() > map_bind() > process() > process_call() > xmap_impl() > wrapper() > call().

This makes no sense, because the same code runs well on a TPU VM v3-8 (v2-alpha) and the same operation was processed fine when the model was loaded. I would like you to take a look.

At this moment, the very same code cannot initialize TPU_driver0.2 (RuntimeError: Backend 'tpu_driver' failed to initialize: DEADLINE_EXCEEDED: Failed to connect to remote server at address: grpc://10.110.14.98:8470. Error from gRPC: Deadline Exceeded. Details: ), and the JAX 0.3.5 version of my model has an illogical dependency incompatibility, which never happened last week.

Thank you for your attention.


mosmos6 commented Apr 13, 2023

Hello @metrizable @cperry-goog

I resolved it, and the discussed model now runs on Colab with TPU_driver0.2.
Please accept my apology for tagging you too often.
Thank you for the great products.

Screenshot 2023-04-13 122413


henk717 commented Apr 13, 2023

@mosmos6 Can you share your changes? There is still an entire ecosystem broken.


mosmos6 commented Apr 13, 2023

@henk717 Of course. Give me some minutes. I'm on my way to set up a repository now, as changes happened in multiple files. My test code needs cleanup after a month of experiments.


henk717 commented Apr 13, 2023

For us the challenge will be getting this one running: https://github.com/VE-FORBRYDERNE/mesh-transformer-jax/tree/ck. It is a heavily modified version that has a lot more additions and enhancements, but the developer went missing.

@somsomers

> Hello @metrizable @cperry-goog
>
> I resolved it, and the discussed model now runs on Colab with TPU_driver0.2. Please accept my apology for tagging you too often. Thank you for the great products.
>
> Screenshot 2023-04-13 122413

Could you please share the working Colab notebook if you have one?


mosmos6 commented Apr 13, 2023

@henk717 From a casual look, I suppose you only need to update lines 383-419 of transformer_shard.py if you use the new Colab demo to infer, followed by updating for the breaking changes in JAX.


mosmos6 commented Apr 13, 2023

@somsomers Yes. Please let me clean up the mess before sharing.

@somsomers

> @somsomers Yes. Please let me clean up the mess before sharing.

Thank you.


mosmos6 commented Apr 13, 2023

Hello,

First of all, I must apologize: this works only on the high-memory TPU runtime, so you'll need a Pro or Pro+ subscription to Colab.
However, I modified the discussed model (GPT-J for me) so that it runs with TPU_driver0.2 on Colab.
Because it is not exactly a Colab matter, I posted it here:
(kingoflolz/mesh-transformer-jax#256 (comment))
You can continue to use the same (slim) weights as before.
I believe this can be deployed to all the derivatives.

@henk717 @somsomers


mosmos6 commented Apr 19, 2023

@henk717

> For us the challenge will be getting this one running: https://github.com/VE-FORBRYDERNE/mesh-transformer-jax/tree/ck

I fixed your AI. She's waiting for you to pick her up in the garage. (https://github.com/mosmos6/Large-MTJ)

Same as my GPT-J, it's adapted to JAX 0.3.25, so it runs on Colab with TPU_driver0.2. Basically it should now be immune to JAX upgrades, except for breaking changes. Sorry for the dorky name; I didn't know her name.
The changes were small but many. I modified:
requirements.txt (kept the original one as _original)
slim_model.py
mesh_transformer/train_actor.py
device_sample.py
device_serve.py
device_train.py
mesh_transformer/transformer_shard.py
mesh_transformer/checkpoint.py
The new Colab demo is Large_MTJ_inference_on_TPU_driver0_2.ipynb (kept the original one as _original).
The rest remains the same.

I tested this only with my slim weights for GPT-J. If you run into an error with other types of weights, please post an issue.

Important notes:

  1. Sorry, you'll need a Pro or Pro+ subscription to Colab because this requires the high-memory TPU runtime. read_ckpt_lowmem hangs forever when inferring with the model in the current Colab environment, so I had to revive read_ckpt; however, this loads the model 10 times faster than the low-memory version. It also stopped showing the total parameter count for some reason, but I don't think that matters.

  2. I have not checked it for fine-tuning on a TPU VM yet, so that can still cause errors during the process. I'm planning to cover it next month. Until then, you may have to add further modifications to xmap yourself or downgrade to JAX 0.2.18 or 0.2.20.

  3. Please let me know if you don't know how to use the new Colab demo.

Enjoy

Screenshot 2023-04-19 120424


henk717 commented Apr 19, 2023

@mosmos6 I tried applying the modifications on my test account here, but the end result is gibberish.

To test, you can take this notebook and replace the version field with https://github.com/henk7171/koboldai.
There is a lot more going on in the tpu_mtj_backend.py file, including the automatic conversion of Hugging Face models; that part ran RAM-efficiently but did not produce a working end result.


mosmos6 commented Apr 19, 2023

@henk717

I saw your tpu_mtj_backend.py, but as I wrote above, you can't use read_ckpt_lowmem on Colab anymore.
In this file you also need to update the xmap out_axes in some functions. Also, as I wrote in my Colab, jax.tools.colab_tpu must be installed before installing JAX when it's v0.3.25, or it leads to misconfiguration. Finally, you need to update maps.ResourceEnv because it needs loops in the newer version.
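
Reading that as "run the Colab TPU setup hook before pulling in the newer jax", a minimal sketch of the ordering could look like this (an editorial interpretation; whether setup_tpu() is exactly the step meant here is an assumption):

# Editorial sketch of the ordering described above.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()    # configure the Colab TPU backend first

# then, in a separate cell, switch to the newer jax/jaxlib, e.g.:
# !pip install "jax==0.3.25" "jaxlib==0.3.25"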


mosmos6 commented Apr 28, 2023

Due to the Python upgrade on Colab (3.9 -> 3.10), I further modified my two modified MTJ models; the requirements.txt and util.py of each are updated.
Now these models work with versions of optax later than 0.0.9. They are immune to JAX upgrades and optax upgrades.


dbubbins87 commented May 5, 2023

I'm not sure what happened: it started working again for a week, but just when I tried to use it tonight, some of the models ended up with the error again.


henk717 commented May 5, 2023

If you are a Kobold user, it's because we implemented 2.0 support. TPUs have always been a bit unreliable, and usually running the notebook again is enough.

Are there people left who still depend on 0.1? Otherwise it no longer makes sense to keep this open.

@sagelywizard
Member

This issue is obsolete because the TPU runtimes are deprecated and were removed.
