From 969211193bbac6e3db80faebc329a569902da365 Mon Sep 17 00:00:00 2001 From: DanPuzzuoli Date: Thu, 12 Oct 2023 16:34:12 -0700 Subject: [PATCH] removing restrictions on JAX/Diffrax versions in tox/setup --- docs/conf.py | 6 +----- qiskit_dynamics/perturbation/dyson_magnus.py | 6 ++++-- setup.py | 3 +-- tox.ini | 18 ++++++++---------- 4 files changed, 14 insertions(+), 19 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 809be684a..5a9f23f3d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -81,8 +81,4 @@ # nbsphinx_execute = os.getenv('QISKIT_DOCS_BUILD_TUTORIALS', 'never') nbsphinx_execute = 'always' nbsphinx_widgets_path = '' -exclude_patterns = ['_build', '**.ipynb_checkpoints'] - -# this is tied to the temporary restriction to JAX versions <=0.4.6. See issue #190 -import os -os.environ["JAX_JIT_PJIT_API_MERGE"] = "0" \ No newline at end of file +exclude_patterns = ['_build', '**.ipynb_checkpoints'] \ No newline at end of file diff --git a/qiskit_dynamics/perturbation/dyson_magnus.py b/qiskit_dynamics/perturbation/dyson_magnus.py index 948088fdc..082d3ada1 100644 --- a/qiskit_dynamics/perturbation/dyson_magnus.py +++ b/qiskit_dynamics/perturbation/dyson_magnus.py @@ -473,7 +473,10 @@ def _setup_dyson_rhs_jax( custom_matmul = _CustomMatmul(lmult_rule, index_offset=1, backend="jax") """ - ##################################################################################################OLD version + ################################################################################################## + #Old version - may want to consider keeping this or moving into a test that detects when JAX no + # longer has an issue with this + perturbations_evaluation_order = jnp.array(perturbations_evaluation_order, dtype=int) new_list = [generator] + perturbations @@ -496,7 +499,6 @@ def multiple_eval(t): def dyson_rhs(t, y): return custom_matmul(multiple_eval(t), y) - return dyson_rhs diff --git a/setup.py b/setup.py index 937f9d879..a8e1fcd68 100644 --- a/setup.py +++ b/setup.py @@ -25,8 +25,7 @@ "arraylias" ] -jax_extras = ['jax>=0.4.0, <= 0.4.6', - 'jaxlib>=0.4.0, <= 0.4.6'] +jax_extras = ['jax', 'jaxlib'] PACKAGES = setuptools.find_packages(exclude=['test*']) diff --git a/tox.ini b/tox.ini index 1efa34cc5..d1f480f9f 100644 --- a/tox.ini +++ b/tox.ini @@ -15,18 +15,16 @@ commands = stestr run {posargs} [testenv:jax] deps = -r{toxinidir}/requirements-dev.txt - jax<=0.4.6 - jaxlib<=0.4.6 - equinox<=0.10.3 - diffrax<=0.3.1 + jax + jaxlib + diffrax [testenv:lint] deps = -r{toxinidir}/requirements-dev.txt - jax<=0.4.6 - jaxlib<=0.4.6 - equinox<=0.10.3 - diffrax<=0.3.1 + jax + jaxlib + diffrax commands = black --check {posargs} qiskit_dynamics test pylint -rn -j 0 --rcfile={toxinidir}/.pylintrc qiskit_dynamics/ test/ @@ -41,8 +39,8 @@ commands = black {posargs} qiskit_dynamics test usedevelop = False deps = -r{toxinidir}/requirements-dev.txt - jax<=0.4.6 - jaxlib<=0.4.6 + jax + jaxlib diffrax commands = sphinx-build -j auto -b html -W {posargs} docs/ docs/_build/html