Skip to content

Commit

Permalink
removing restrictions on JAX/Diffrax versions in tox/setup
Browse files Browse the repository at this point in the history
  • Loading branch information
DanPuzzuoli committed Oct 12, 2023
1 parent 8c4a0a3 commit 9692111
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 19 deletions.
6 changes: 1 addition & 5 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,4 @@
# nbsphinx_execute = os.getenv('QISKIT_DOCS_BUILD_TUTORIALS', 'never')
nbsphinx_execute = 'always'
nbsphinx_widgets_path = ''
exclude_patterns = ['_build', '**.ipynb_checkpoints']

# this is tied to the temporary restriction to JAX versions <=0.4.6. See issue #190
import os
os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"
exclude_patterns = ['_build', '**.ipynb_checkpoints']
6 changes: 4 additions & 2 deletions qiskit_dynamics/perturbation/dyson_magnus.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,10 @@ def _setup_dyson_rhs_jax(
custom_matmul = _CustomMatmul(lmult_rule, index_offset=1, backend="jax")

"""
##################################################################################################OLD version
##################################################################################################
#Old version - may want to consider keeping this or moving into a test that detects when JAX no
# longer has an issue with this
perturbations_evaluation_order = jnp.array(perturbations_evaluation_order, dtype=int)
new_list = [generator] + perturbations
Expand All @@ -496,7 +499,6 @@ def multiple_eval(t):
def dyson_rhs(t, y):
return custom_matmul(multiple_eval(t), y)


return dyson_rhs


Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
"arraylias"
]

jax_extras = ['jax>=0.4.0, <= 0.4.6',
'jaxlib>=0.4.0, <= 0.4.6']
jax_extras = ['jax', 'jaxlib']

PACKAGES = setuptools.find_packages(exclude=['test*'])

Expand Down
18 changes: 8 additions & 10 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,16 @@ commands = stestr run {posargs}
[testenv:jax]
deps =
-r{toxinidir}/requirements-dev.txt
jax<=0.4.6
jaxlib<=0.4.6
equinox<=0.10.3
diffrax<=0.3.1
jax
jaxlib
diffrax

[testenv:lint]
deps =
-r{toxinidir}/requirements-dev.txt
jax<=0.4.6
jaxlib<=0.4.6
equinox<=0.10.3
diffrax<=0.3.1
jax
jaxlib
diffrax
commands =
black --check {posargs} qiskit_dynamics test
pylint -rn -j 0 --rcfile={toxinidir}/.pylintrc qiskit_dynamics/ test/
Expand All @@ -41,8 +39,8 @@ commands = black {posargs} qiskit_dynamics test
usedevelop = False
deps =
-r{toxinidir}/requirements-dev.txt
jax<=0.4.6
jaxlib<=0.4.6
jax
jaxlib
diffrax
commands =
sphinx-build -j auto -b html -W {posargs} docs/ docs/_build/html
Expand Down

0 comments on commit 9692111

Please sign in to comment.