Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add warning for JAX versions on import of Dynamics #232

Merged
merged 10 commits into from
Jun 17, 2023
4 changes: 4 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,7 @@
nbsphinx_execute = 'always'
nbsphinx_widgets_path = ''
exclude_patterns = ['_build', '**.ipynb_checkpoints']

# this is tied to the temporary restriction to JAX versions <=0.4.6. See issue #190
import os
os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"
36 changes: 18 additions & 18 deletions qiskit_dynamics/dispatch/backends/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,28 @@

try:
import jax
from jax.interpreters.xla import DeviceArray
from jax import Array
from jax.core import Tracer
from jax.interpreters.ad import JVPTracer
from jax.interpreters.partial_eval import JaxprTracer

JAX_TYPES = (DeviceArray, Tracer, JaxprTracer, JVPTracer)
# warning based on JAX version
from packaging import version
import warnings

try:
# This class was introduced in 0.4.0
from jax import Array
if version.parse(jax.__version__) >= version.parse("0.4.4"):
import os

JAX_TYPES += (Array,)
except ImportError:
pass

try:
# This class is not in older versions of Jax
from jax.interpreters.partial_eval import DynamicJaxprTracer
if (
version.parse(jax.__version__) > version.parse("0.4.6")
or os.environ.get("JAX_JIT_PJIT_API_MERGE", None) != "0"
):
warnings.warn(
"The functionality in the perturbation module of Qiskit Dynamics requires a JAX "
"version <= 0.4.6, due to a bug in JAX versions > 0.4.6. For versions 0.4.4, "
"0.4.5, and 0.4.6, using the perturbation module functionality requires setting "
"os.environ['JAX_JIT_PJIT_API_MERGE'] = '0' before importing JAX or Dynamics."
)

JAX_TYPES += (DynamicJaxprTracer,)
except ImportError:
pass
JAX_TYPES = (Array, Tracer)

from ..dispatch import Dispatch
import numpy as np
Expand All @@ -53,7 +53,7 @@
def _jax_asarray(array, dtype=None, order=None):
"""Wrapper for jax.numpy.asarray"""
if (
isinstance(array, DeviceArray)
isinstance(array, JAX_TYPES)
and order is None
and (dtype is None or dtype == array.dtype)
):
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
"sympy>=1.12"
]

jax_extras = ['jax>=0.2.26, <= 0.4.6',
'jaxlib>=0.1.75, <= 0.4.6']
jax_extras = ['jax>=0.4.0, <= 0.4.6',
'jaxlib>=0.4.0, <= 0.4.6']

PACKAGES = setuptools.find_packages(exclude=['test*'])

Expand Down