TPU training colab is not working for me #224

Closed
ViktorM opened this issue Sep 14, 2022 · 4 comments

ViktorM commented Sep 14, 2022

The TPU training colab stopped working for me after one of the updates in recent months. It stalls at the training cell: the cell shows that some work is going on, but it never finishes, and no plots or trained policies are produced.

btaba added the "good first issue" label Jan 13, 2023

erwincoumans commented

Hi, I tried this TPU training colab today, and it fails with this error:


AttributeError                            Traceback (most recent call last)
[<ipython-input-3-9ce5fdb19302>](https://localhost:8080/#) in <module>
     14 
     15 try:
---> 16   import brax
     17 except ImportError:
     18   get_ipython().system('pip install git+https://github.com/google/brax.git@main')

2 frames
[/usr/local/lib/python3.9/dist-packages/brax/jumpy.py](https://localhost:8080/#) in <module>
    504 
    505 
--> 506 def where(condition: jax.typing.ArrayLike, x: jax.typing.ArrayLike,
    507           y: jax.typing.ArrayLike) -> ndarray:
    508   """Return elements chosen from `x` or `y` depending on `condition`."""

AttributeError: module 'jax' has no attribute 'typing'

btaba commented Mar 20, 2023

Hi @erwincoumans! I'm not able to reproduce the issue. Which jax version are you using?
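
For anyone hitting the same traceback, one way to answer the version question from inside the notebook is sketched below. It uses only the standard library. The helper name `describe_module` is made up for illustration and is not part of jax or brax, and it assumes the pip distribution name matches the import name (true for `jax`).

```python
import importlib
import importlib.metadata


def describe_module(name: str, submodule: str) -> str:
    """Report the installed version of `name` and whether `name.submodule` imports."""
    try:
        # Looks up the installed distribution's metadata without importing it.
        version = importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError:
        return f"{name}: not installed"
    try:
        # A missing submodule raises ModuleNotFoundError, a subclass of ImportError.
        importlib.import_module(f"{name}.{submodule}")
    except ImportError:
        return f"{name} {version}: no '{submodule}' submodule"
    return f"{name} {version}: '{submodule}' available"


# The traceback above complains about `jax.typing`, so check exactly that.
print(describe_module("jax", "typing"))
```

Running this in a fresh cell before the `import brax` cell should show whether the preinstalled jax is too old to provide `jax.typing`.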

erwincoumans commented Mar 21, 2023

I'm just running the public colab, following the links on the GitHub front page. Did you try training it using a public TPU runtime? The other public colab (training using PyTorch) is also still broken; see the other issue.

https://colab.research.google.com/github/google/brax/blob/main/notebooks/training.ipynb

Just tried it again, here is the output:


---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
[<ipython-input-1-9ce5fdb19302>](https://localhost:8080/#) in <module>
     15 try:
---> 16   import brax
     17 except ImportError:

ModuleNotFoundError: No module named 'brax'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
3 frames
[/usr/local/lib/python3.9/dist-packages/brax/jumpy.py](https://localhost:8080/#) in <module>
    504 
    505 
--> 506 def where(condition: jax.typing.ArrayLike, x: jax.typing.ArrayLike,
    507           y: jax.typing.ArrayLike) -> ndarray:
    508   """Return elements chosen from `x` or `y` depending on `condition`."""

AttributeError: module 'jax' has no attribute 'typing'

btaba commented Mar 21, 2023

OK, thanks for the pointer! It turns out that jax>=0.4.6 is incompatible with public colab TPU runtimes (see https://stackoverflow.com/a/75734517). We're pinning the jax/jaxlib versions to >=0.4.6 now, so it's best to run in a GPU runtime until the colab issue is fixed.
I've confirmed that training works on GPU in a public colab runtime.
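
After switching the runtime type, a quick sanity check of which backend jax actually picked up might look like the sketch below. `jax.default_backend()` is a real jax call that returns the platform name as a string; `current_backend` is a hypothetical wrapper added here so the snippet also runs where jax is not installed. On a misconfigured TPU runtime the call may raise instead of returning.

```python
from typing import Optional


def current_backend() -> Optional[str]:
    """Return the jax platform name (e.g. 'cpu', 'gpu', 'tpu'), or None if jax is missing."""
    try:
        import jax
    except ImportError:
        return None
    # Initializes the backend on first use and reports its platform name.
    return jax.default_backend()


backend = current_backend()
if backend is None:
    print("jax is not installed in this runtime")
else:
    print(f"jax is running on: {backend}")
```

If this prints `cpu` in a runtime that was meant to be GPU, the accelerator was probably not attached.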
