Skip to content

Commit

Permalink
Remove bound on JAX/Diffrax versions (#266)
Browse files Browse the repository at this point in the history
Co-authored-by: Kento Ueda <[email protected]>
  • Loading branch information
DanPuzzuoli and to24toro authored Oct 30, 2023
1 parent 4b67efb commit 594481c
Show file tree
Hide file tree
Showing 11 changed files with 78 additions and 51 deletions.
6 changes: 1 addition & 5 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,4 @@
# nbsphinx_execute = os.getenv('QISKIT_DOCS_BUILD_TUTORIALS', 'never')
nbsphinx_execute = 'always'
nbsphinx_widgets_path = ''
exclude_patterns = ['_build', '**.ipynb_checkpoints']

# this is tied to the temporary restriction to JAX versions <=0.4.6. See issue #190
import os
os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"
exclude_patterns = ['_build', '**.ipynb_checkpoints']
18 changes: 0 additions & 18 deletions qiskit_dynamics/dispatch/backends/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,6 @@
from jax import Array
from jax.core import Tracer

# warning based on JAX version
from packaging import version
import warnings

if version.parse(jax.__version__) >= version.parse("0.4.4"):
import os

if (
version.parse(jax.__version__) > version.parse("0.4.6")
or os.environ.get("JAX_JIT_PJIT_API_MERGE", None) != "0"
):
warnings.warn(
"The functionality in the perturbation module of Qiskit Dynamics requires a JAX "
"version <= 0.4.6, due to a bug in JAX versions > 0.4.6. For versions 0.4.4, "
"0.4.5, and 0.4.6, using the perturbation module functionality requires setting "
"os.environ['JAX_JIT_PJIT_API_MERGE'] = '0' before importing JAX or Dynamics."
)

JAX_TYPES = (Array, Tracer)

from ..dispatch import Dispatch
Expand Down
12 changes: 5 additions & 7 deletions qiskit_dynamics/perturbation/dyson_magnus.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@

try:
import jax.numpy as jnp
from jax.lax import scan, switch
from jax.lax import scan
from jax import vmap
except ImportError:
pass
Expand Down Expand Up @@ -472,17 +472,15 @@ def _setup_dyson_rhs_jax(

custom_matmul = _CustomMatmul(lmult_rule, index_offset=1, backend="jax")

perturbations_evaluation_order = jnp.array(perturbations_evaluation_order, dtype=int)
perturbations_evaluation_order = np.array(perturbations_evaluation_order, dtype=int)

new_list = [generator] + perturbations

def single_eval(idx, t):
return switch(idx, new_list, t)

multiple_eval = vmap(single_eval, in_axes=(0, None))
def multiple_eval(t):
return jnp.array([new_list[idx](t) for idx in perturbations_evaluation_order])

def dyson_rhs(t, y):
return custom_matmul(multiple_eval(perturbations_evaluation_order, t), y)
return custom_matmul(multiple_eval(t), y)

return dyson_rhs

Expand Down
2 changes: 1 addition & 1 deletion qiskit_dynamics/solvers/lanczos.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def lanczos_basis(A: Union[csr_matrix, np.ndarray], y0: np.ndarray, k_dim: int):

q_basis[[0], :] = y0.T
projection = A @ y0
alpha[0] = y0.conj().T @ projection
alpha[0] = np.sum(y0.conj() * projection)
projection = projection - alpha[0] * y0
beta[0] = np.linalg.norm(projection)

Expand Down
10 changes: 10 additions & 0 deletions releasenotes/notes/update-jax-a50ce1b7d6b47219.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
issues:
- |
A JAX warning about casting complex values to real is raised when computing gradients of
simulations in Qiskit Dynamics. Note that this warning does not appear to signify any error
in numerical computation, and can be safely ignored.
upgrade:
- |
The upper bound on JAX and Diffrax in the last version of Qiskit Dynamics has been removed.
Users should try to use the latest version of JAX.
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
"arraylias"
]

jax_extras = ['jax>=0.4.0, <= 0.4.6',
'jaxlib>=0.4.0, <= 0.4.6']
jax_extras = ['jax', 'jaxlib']

PACKAGES = setuptools.find_packages(exclude=['test*'])

Expand Down
5 changes: 0 additions & 5 deletions test/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,3 @@
"""
Qiskit Dynamics tests
"""

# temporarily disable a change in JAX 0.4.4 that introduced a bug. Must be run before importing JAX
import os

os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"
3 changes: 2 additions & 1 deletion test/dynamics/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,8 @@ def jit_grad_wrap(self, func_to_test: Callable) -> Callable:
Args:
func_to_test: The function whose gradient will be graded.
Returns:
JIT-compiled gradient of function."""
JIT-compiled gradient of function.
"""
wf = wrap(lambda f: jit(grad(f)), decorator=True)
f = lambda *args: np.sum(func_to_test(*args)).real
return wf(f)
48 changes: 48 additions & 0 deletions test/dynamics/perturbation/test_dyson_magnus.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,3 +718,51 @@ def assertMultRulesEqual(self, rule1, rule2):
for sub_rule1, sub_rule2 in zip(rule1, rule2):
self.assertAllClose(sub_rule1[0], sub_rule2[0])
self.assertAllClose(sub_rule1[1], sub_rule2[1])


class TestWorkaround(QiskitDynamicsTestCase):
"""Test whether workaround in dyson_magnus._setup_dyson_rhs_jax is no longer required.
The workaround was introduced in the same commit as this test class to avoid an error being
raised by a non-trivial combination of JAX transformations. The test in this class has been
set up to expect the original minimal reproduction of the issue to fail. Once it no longer
fails, the changes made to _setup_dyson_rhs_jax in this commit should be reverted.
See https://github.com/google/jax/discussions/9951#discussioncomment-2385157 for discussion of
issue.
"""

def test_minimal_example(self):
"""Test minimal reproduction of issue."""

with self.assertRaises(Exception):
import jax.numpy as jnp
from jax import grad, vmap
from jax.lax import switch
from jax.experimental.ode import odeint

# pylint: disable=unused-argument
def A0(t):
return 2.0

# pylint: disable=unused-argument
def A1(a, t):
return a**2

y0 = np.random.rand(2)
T = np.pi * 1.232

def test_func(a):
eval_list = [A0, lambda t: A1(a, t)]

def single_eval(idx, t):
return switch(idx, eval_list, t)

multiple_eval = vmap(single_eval, in_axes=(0, None))
idx_list = jnp.array([0, 1])
rhs = lambda y, t: multiple_eval(idx_list, t) * y

out = odeint(rhs, y0=y0, t=jnp.array([0, T], dtype=float), atol=1e-13, rtol=1e-13)
return out

jit(grad(lambda a: test_func(a)[-1][1].real))(1.0)
4 changes: 2 additions & 2 deletions test/dynamics/solvers/test_solver_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,8 +957,8 @@ def test_two_channel_SuperOp_simulation(self, model):
schedules=sched,
signals=signals,
test_tol=1e-8,
atol=1e-11,
rtol=1e-11,
atol=1e-12,
rtol=1e-12,
)

def test_4_channel_schedule(self):
Expand Down
18 changes: 8 additions & 10 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,16 @@ commands = stestr run {posargs}
[testenv:jax]
deps =
-r{toxinidir}/requirements-dev.txt
jax<=0.4.6
jaxlib<=0.4.6
equinox<=0.10.3
diffrax<=0.3.1
jax
jaxlib
diffrax

[testenv:lint]
deps =
-r{toxinidir}/requirements-dev.txt
jax<=0.4.6
jaxlib<=0.4.6
equinox<=0.10.3
diffrax<=0.3.1
jax
jaxlib
diffrax
commands =
black --check {posargs} qiskit_dynamics test
pylint -rn -j 0 --rcfile={toxinidir}/.pylintrc qiskit_dynamics/ test/
Expand All @@ -41,8 +39,8 @@ commands = black {posargs} qiskit_dynamics test
usedevelop = False
deps =
-r{toxinidir}/requirements-dev.txt
jax<=0.4.6
jaxlib<=0.4.6
jax
jaxlib
diffrax
commands =
sphinx-build -j auto -W -T --keep-going {posargs} docs/ docs/_build/html
Expand Down

0 comments on commit 594481c

Please sign in to comment.