[BUG] XLA is incompatible with jax 0.4.29 #313

Open
3 tasks done
pseudo-rnd-thoughts opened this issue Aug 12, 2024 · 2 comments · May be fixed by #314

@pseudo-rnd-thoughts

Describe the bug

`jax.interpreters.xla.backend_specific_translations` is deprecated in JAX v0.4.29; see the changelog:
https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-29-june-10-2024

As a result, calling `env.xla()` fails with `AttributeError: jax.interpreters.xla.backend_specific_translations is deprecated. Register custom primitives via jax.interpreters.mlir instead.`
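The failure can be detected without envpool by probing for the legacy attribute. The following is a minimal sketch using only the standard library; the helper name `has_legacy_xla_translations` is hypothetical and not part of envpool or JAX. It returns `False` when accessing the attribute raises `AttributeError` (as it does on JAX 0.4.29) and also when JAX is not installed.

```python
import importlib


def has_legacy_xla_translations() -> bool:
    """Hypothetical helper: True if this JAX still exposes the legacy table."""
    try:
        xla = importlib.import_module("jax.interpreters.xla")
    except ImportError:
        # JAX (or this submodule) is not importable at all.
        return False
    # On JAX 0.4.29, looking up this name raises AttributeError from the
    # module-level deprecation hook; hasattr() reports that as False.
    return hasattr(xla, "backend_specific_translations")


if has_legacy_xla_translations():
    print("legacy xla.backend_specific_translations path is available")
else:
    print("register custom primitives via jax.interpreters.mlir instead")
```

A check like this could let code keep the old registration path on older JAX releases and switch to MLIR lowering rules otherwise.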

To Reproduce

```python
import envpool

env = envpool.make("Breakout-v5", env_type="gym")
handle, recv, send, step = env.xla()
```

Running this fails with:

```
    handle, recv, send, step = env.xla()
  File "/lib/python3.10/site-packages/envpool/python/lax.py", line 30, in xla
    _handle, _recv, _send = make_xla(self)
  File "/lib/python3.10/site-packages/envpool/python/xla_template.py", line 124, in make_xla
    methods.append(_make_xla_function(obj, handle, name, specs, capsules))
  File "/lib/python3.10/site-packages/envpool/python/xla_template.py", line 94, in _make_xla_function
    xla.backend_specific_translations["cpu"][prim] = partial(
  File "/lib/python3.10/site-packages/jax/_src/deprecations.py", line 52, in getattr
    raise AttributeError(message)
AttributeError: jax.interpreters.xla.backend_specific_translations is deprecated. Register custom primitives via jax.interpreters.mlir instead.
```

Expected behavior

`env.xla()` returns the XLA handle and the `recv`, `send`, and `step` functions without raising an error.

System info

```python
import envpool, jax, numpy, sys
print(envpool.__version__, numpy.__version__, sys.version, sys.platform)
# 0.8.4 1.26.4 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0] linux
print(jax.__version__)
# 0.4.29
```

Reason and Possible fixes

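As the error message suggests, the custom primitives should be registered via `jax.interpreters.mlir` instead: the `xla.backend_specific_translations["cpu"][prim] = partial(...)` assignment in `_make_xla_function` (`envpool/python/xla_template.py`) needs to be replaced by a lowering rule registered through `mlir.register_lowering`.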
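For reference, the general pattern the error message points to looks roughly like the minimal sketch below. It defines a toy `double` primitive (hypothetical, not one of envpool's primitives) and registers a CPU lowering with `mlir.register_lowering`, the MLIR counterpart of assigning into `xla.backend_specific_translations["cpu"]`. The lowering is built from a Python function with `mlir.lower_fun`; envpool's real primitives call into its C++ handles through XLA custom calls, which this sketch does not cover. The `Primitive` import is guarded because newer JAX releases expose it as `jax.extend.core.Primitive`.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import mlir

try:  # newer JAX exposes Primitive under jax.extend.core
    from jax.extend.core import Primitive
except ImportError:  # older JAX, e.g. 0.4.29
    from jax.core import Primitive

# Hypothetical toy primitive standing in for envpool's recv/send/step primitives.
double_p = Primitive("double")


def double(x):
    """Bind the toy primitive to an array argument."""
    return double_p.bind(x)


# Eager (non-jit) evaluation rule.
double_p.def_impl(lambda x: x * 2)
# Shape/dtype inference: the output has the same abstract value as the input.
double_p.def_abstract_eval(lambda x: x)
# Lowering rule used under jax.jit on CPU; this takes the place of the old
# xla.backend_specific_translations["cpu"][prim] = ... assignment.
mlir.register_lowering(
    double_p,
    mlir.lower_fun(lambda x: x * 2, multiple_results=False),
    platform="cpu",
)

x = jnp.arange(3.0)
print(double(x))           # eager path via def_impl
print(jax.jit(double)(x))  # jit path via the registered MLIR lowering
```

Passing `platform="cpu"` keeps the sketch CPU-only, matching the `"cpu"` key in the original code; omitting `platform` would register the rule for every backend.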

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
@pseudo-rnd-thoughts
Author

Closed by accident

@Trinkle23897
Collaborator

Sorry about that, I'll try to fix the CI if I have time.
