Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add warning for JAX versions on import of Dynamics #232

Merged
merged 10 commits into from
Jun 17, 2023
29 changes: 16 additions & 13 deletions qiskit_dynamics/dispatch/backends/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,29 @@

try:
import jax
from jax.interpreters.xla import DeviceArray
from jax import Array
from jax.core import Tracer
from jax.interpreters.ad import JVPTracer
from jax.interpreters.partial_eval import JaxprTracer

JAX_TYPES = (DeviceArray, Tracer, JaxprTracer, JVPTracer)
# warning based on JAX version
from packaging import version
import warnings

try:
# This class was introduced in 0.4.0
from jax import Array
if version.parse(jax.__version__) >= version.parse("0.4.4"):
warnings.warn(
"The functionality in the perturbation module of Qiskit Dynamics requires a JAX "
"version <= 0.4.6, due to a bug in JAX versions > 0.4.6. For versions 0.4.4, 0.4.5, "
"and 0.4.6, using the perturbation module functionality requires setting "
"os.environ['JAX_JIT_PJIT_API_MERGE'] = '0' before importing JAX."
wshanks marked this conversation as resolved.
Show resolved Hide resolved
)

JAX_TYPES += (Array,)
except ImportError:
pass
JAX_TYPES = (Array, Tracer)

# in versions <= 0.4.10
try:
# This class is not in older versions of Jax
from jax.interpreters.partial_eval import DynamicJaxprTracer
# pylint: disable=ungrouped-imports
from jax.interpreters.xla import DeviceArray

JAX_TYPES += (DynamicJaxprTracer,)
JAX_TYPES += (DeviceArray,)
wshanks marked this conversation as resolved.
Show resolved Hide resolved
except ImportError:
pass

Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
"multiset>=3.0.1",
]

jax_extras = ['jax>=0.2.26, <= 0.4.6',
'jaxlib>=0.1.75, <= 0.4.6']
jax_extras = ['jax>=0.4.0, <= 0.4.6',
'jaxlib>=0.4.0, <= 0.4.6']

PACKAGES = setuptools.find_packages(exclude=['test*'])

Expand Down
4 changes: 2 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ commands = black {posargs} qiskit_dynamics test
usedevelop = False
deps =
-r{toxinidir}/requirements-dev.txt
jax<=0.4.6
jaxlib<=0.4.6
jax<=0.4.3
wshanks marked this conversation as resolved.
Show resolved Hide resolved
jaxlib<=0.4.3
diffrax
commands =
sphinx-build -j auto -b html -W {posargs} docs/ docs/_build/html
Expand Down