Skip to content

Commit

Permalink
setting jax version warning to only trigger if the version is between…
Browse files Browse the repository at this point in the history
… 0.4.4 and 0.4.6 and the os flag hasn't been set
  • Loading branch information
DanPuzzuoli committed Jun 16, 2023
1 parent a483084 commit 9b4f195
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
14 changes: 8 additions & 6 deletions qiskit_dynamics/dispatch/backends/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
import warnings

if version.parse(jax.__version__) >= version.parse("0.4.4"):
warnings.warn(
"The functionality in the perturbation module of Qiskit Dynamics requires a JAX "
"version <= 0.4.6, due to a bug in JAX versions > 0.4.6. For versions 0.4.4, 0.4.5, "
"and 0.4.6, using the perturbation module functionality requires setting "
"os.environ['JAX_JIT_PJIT_API_MERGE'] = '0' before importing JAX."
)
import os
if version.parse(jax.__version__) > version.parse("0.4.6") or os.environ.get('JAX_JIT_PJIT_API_MERGE', None) != '0':
warnings.warn(
"The functionality in the perturbation module of Qiskit Dynamics requires a JAX "
"version <= 0.4.6, due to a bug in JAX versions > 0.4.6. For versions 0.4.4, "
"0.4.5, and 0.4.6, using the perturbation module functionality requires setting "
"os.environ['JAX_JIT_PJIT_API_MERGE'] = '0' before importing JAX or Dynamics."
)

JAX_TYPES = (Array, Tracer)

Expand Down
4 changes: 2 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ commands = black {posargs} qiskit_dynamics test
usedevelop = False
deps =
-r{toxinidir}/requirements-dev.txt
jax<=0.4.3
jaxlib<=0.4.3
jax<=0.4.6
jaxlib<=0.4.6
diffrax
commands =
sphinx-build -j auto -b html -W {posargs} docs/ docs/_build/html
Expand Down

0 comments on commit 9b4f195

Please sign in to comment.