Skip to content

Commit

Permalink
Add warning for JAX versions on import of Dynamics (qiskit-community#232
Browse files Browse the repository at this point in the history
)
  • Loading branch information
DanPuzzuoli committed Jun 20, 2023
1 parent 79977af commit 360ab09
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 20 deletions.
4 changes: 4 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,7 @@
nbsphinx_execute = 'always'
nbsphinx_widgets_path = ''
exclude_patterns = ['_build', '**.ipynb_checkpoints']

# this is tied to the temporary restriction to JAX versions <=0.4.6. See issue #190
import os
os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"
36 changes: 18 additions & 18 deletions qiskit_dynamics/dispatch/backends/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,28 @@

try:
import jax
from jax.interpreters.xla import DeviceArray
from jax import Array
from jax.core import Tracer
from jax.interpreters.ad import JVPTracer
from jax.interpreters.partial_eval import JaxprTracer

JAX_TYPES = (DeviceArray, Tracer, JaxprTracer, JVPTracer)
# warning based on JAX version
from packaging import version
import warnings

try:
# This class was introduced in 0.4.0
from jax import Array
if version.parse(jax.__version__) >= version.parse("0.4.4"):
import os

JAX_TYPES += (Array,)
except ImportError:
pass

try:
# This class is not in older versions of Jax
from jax.interpreters.partial_eval import DynamicJaxprTracer
if (
version.parse(jax.__version__) > version.parse("0.4.6")
or os.environ.get("JAX_JIT_PJIT_API_MERGE", None) != "0"
):
warnings.warn(
"The functionality in the perturbation module of Qiskit Dynamics requires a JAX "
"version <= 0.4.6, due to a bug in JAX versions > 0.4.6. For versions 0.4.4, "
"0.4.5, and 0.4.6, using the perturbation module functionality requires setting "
"os.environ['JAX_JIT_PJIT_API_MERGE'] = '0' before importing JAX or Dynamics."
)

JAX_TYPES += (DynamicJaxprTracer,)
except ImportError:
pass
JAX_TYPES = (Array, Tracer)

from ..dispatch import Dispatch
import numpy as np
Expand All @@ -53,7 +53,7 @@
def _jax_asarray(array, dtype=None, order=None):
"""Wrapper for jax.numpy.asarray"""
if (
isinstance(array, DeviceArray)
isinstance(array, JAX_TYPES)
and order is None
and (dtype is None or dtype == array.dtype)
):
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
"sympy>=1.12"
]

jax_extras = ['jax>=0.2.26, <= 0.4.6',
'jaxlib>=0.1.75, <= 0.4.6']
jax_extras = ['jax>=0.4.0, <= 0.4.6',
'jaxlib>=0.4.0, <= 0.4.6']

PACKAGES = setuptools.find_packages(exclude=['test*'])

Expand Down

0 comments on commit 360ab09

Please sign in to comment.