Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove bound on JAX/Diffrax versions #266

Merged
merged 16 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,4 @@
# nbsphinx_execute = os.getenv('QISKIT_DOCS_BUILD_TUTORIALS', 'never')
nbsphinx_execute = 'always'
nbsphinx_widgets_path = ''
exclude_patterns = ['_build', '**.ipynb_checkpoints']

# this is tied to the temporary restriction to JAX versions <=0.4.6. See issue #190
import os
os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"
exclude_patterns = ['_build', '**.ipynb_checkpoints']
18 changes: 0 additions & 18 deletions qiskit_dynamics/dispatch/backends/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,6 @@
from jax import Array
from jax.core import Tracer

# warning based on JAX version
from packaging import version
import warnings

if version.parse(jax.__version__) >= version.parse("0.4.4"):
import os

if (
version.parse(jax.__version__) > version.parse("0.4.6")
or os.environ.get("JAX_JIT_PJIT_API_MERGE", None) != "0"
):
warnings.warn(
"The functionality in the perturbation module of Qiskit Dynamics requires a JAX "
"version <= 0.4.6, due to a bug in JAX versions > 0.4.6. For versions 0.4.4, "
"0.4.5, and 0.4.6, using the perturbation module functionality requires setting "
"os.environ['JAX_JIT_PJIT_API_MERGE'] = '0' before importing JAX or Dynamics."
)

JAX_TYPES = (Array, Tracer)

from ..dispatch import Dispatch
Expand Down
12 changes: 5 additions & 7 deletions qiskit_dynamics/perturbation/dyson_magnus.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@

try:
import jax.numpy as jnp
from jax.lax import scan, switch
from jax.lax import scan
from jax import vmap
except ImportError:
pass
Expand Down Expand Up @@ -472,17 +472,15 @@ def _setup_dyson_rhs_jax(

custom_matmul = _CustomMatmul(lmult_rule, index_offset=1, backend="jax")

perturbations_evaluation_order = jnp.array(perturbations_evaluation_order, dtype=int)
perturbations_evaluation_order = np.array(perturbations_evaluation_order, dtype=int)

new_list = [generator] + perturbations

def single_eval(idx, t):
return switch(idx, new_list, t)

multiple_eval = vmap(single_eval, in_axes=(0, None))
def multiple_eval(t):
return jnp.array([new_list[idx](t) for idx in perturbations_evaluation_order])

def dyson_rhs(t, y):
return custom_matmul(multiple_eval(perturbations_evaluation_order, t), y)
return custom_matmul(multiple_eval(t), y)

return dyson_rhs

Expand Down
2 changes: 1 addition & 1 deletion qiskit_dynamics/solvers/lanczos.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def lanczos_basis(A: Union[csr_matrix, np.ndarray], y0: np.ndarray, k_dim: int):

q_basis[[0], :] = y0.T
projection = A @ y0
alpha[0] = y0.conj().T @ projection
alpha[0] = np.sum(y0.conj() * projection)
projection = projection - alpha[0] * y0
beta[0] = np.linalg.norm(projection)

Expand Down
10 changes: 10 additions & 0 deletions releasenotes/notes/update-jax-a50ce1b7d6b47219.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
issues:
- |
A JAX warning about casting complex values to real is raised when when computing gradients of
DanPuzzuoli marked this conversation as resolved.
Show resolved Hide resolved
simulations in Qiskit Dynamics. Note that this warning does not appear to signify any error
in numerical computation, and can be safely ignored.
upgrade:
- |
The upper bound on JAX and Diffrax in the last version of Qiskit Dynamics has been removed.
Users should try to use the latest version of JAX.
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
"arraylias"
]

jax_extras = ['jax>=0.4.0, <= 0.4.6',
'jaxlib>=0.4.0, <= 0.4.6']
jax_extras = ['jax', 'jaxlib']

PACKAGES = setuptools.find_packages(exclude=['test*'])

Expand Down
5 changes: 0 additions & 5 deletions test/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,3 @@
"""
Qiskit Dynamics tests
"""

# temporarily disable a change in JAX 0.4.4 that introduced a bug. Must be run before importing JAX
import os

os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"
3 changes: 2 additions & 1 deletion test/dynamics/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,8 @@ def jit_grad_wrap(self, func_to_test: Callable) -> Callable:
Args:
func_to_test: The function whose gradient will be graded.
Returns:
JIT-compiled gradient of function."""
JIT-compiled gradient of function.
"""
wf = wrap(lambda f: jit(grad(f)), decorator=True)
f = lambda *args: np.sum(func_to_test(*args)).real
return wf(f)
48 changes: 48 additions & 0 deletions test/dynamics/perturbation/test_dyson_magnus.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,3 +718,51 @@ def assertMultRulesEqual(self, rule1, rule2):
for sub_rule1, sub_rule2 in zip(rule1, rule2):
self.assertAllClose(sub_rule1[0], sub_rule2[0])
self.assertAllClose(sub_rule1[1], sub_rule2[1])


class TestWorkaround(QiskitDynamicsTestCase):
"""Test whether workaround in dyson_magnus._setup_dyson_rhs_jax is no longer required.

The workaround was introduced in the same commit as this test class to avoid an error being
raised by a non-trivial combination of JAX transformations. The test in this class has been
set up to expect the original minimal reproduction of the issue to fail. Once it no longer
fails, the changes made to _setup_dyson_rhs_jax in this commit should be reverted.

See https://github.com/google/jax/discussions/9951#discussioncomment-2385157 for discussion of
issue.
"""

def test_minimal_example(self):
"""Test minimal reproduction of issue."""

with self.assertRaises(Exception):
import jax.numpy as jnp
from jax import grad, vmap
from jax.lax import switch
from jax.experimental.ode import odeint

# pylint: disable=unused-argument
def A0(t):
return 2.0

# pylint: disable=unused-argument
def A1(a, t):
return a**2

y0 = np.random.rand(2)
T = np.pi * 1.232

def test_func(a):
eval_list = [A0, lambda t: A1(a, t)]

def single_eval(idx, t):
return switch(idx, eval_list, t)

multiple_eval = vmap(single_eval, in_axes=(0, None))
idx_list = jnp.array([0, 1])
rhs = lambda y, t: multiple_eval(idx_list, t) * y

out = odeint(rhs, y0=y0, t=jnp.array([0, T], dtype=float), atol=1e-13, rtol=1e-13)
return out

jit(grad(lambda a: test_func(a)[-1][1].real))(1.0)
4 changes: 2 additions & 2 deletions test/dynamics/solvers/test_solver_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,8 +957,8 @@ def test_two_channel_SuperOp_simulation(self, model):
schedules=sched,
signals=signals,
test_tol=1e-8,
atol=1e-11,
rtol=1e-11,
atol=1e-12,
rtol=1e-12,
)

def test_4_channel_schedule(self):
Expand Down
18 changes: 8 additions & 10 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,16 @@ commands = stestr run {posargs}
[testenv:jax]
deps =
-r{toxinidir}/requirements-dev.txt
jax<=0.4.6
jaxlib<=0.4.6
equinox<=0.10.3
diffrax<=0.3.1
jax
jaxlib
diffrax

[testenv:lint]
deps =
-r{toxinidir}/requirements-dev.txt
jax<=0.4.6
jaxlib<=0.4.6
equinox<=0.10.3
diffrax<=0.3.1
jax
jaxlib
diffrax
commands =
black --check {posargs} qiskit_dynamics test
pylint -rn -j 0 --rcfile={toxinidir}/.pylintrc qiskit_dynamics/ test/
Expand All @@ -41,8 +39,8 @@ commands = black {posargs} qiskit_dynamics test
usedevelop = False
deps =
-r{toxinidir}/requirements-dev.txt
jax<=0.4.6
jaxlib<=0.4.6
jax
jaxlib
diffrax
commands =
sphinx-build -j auto -W -T --keep-going {posargs} docs/ docs/_build/html
Expand Down
Loading