diff --git a/docs/conf.py b/docs/conf.py index 28280d731..a96bf7418 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -75,8 +75,4 @@ # nbsphinx_execute = os.getenv('QISKIT_DOCS_BUILD_TUTORIALS', 'never') nbsphinx_execute = 'always' nbsphinx_widgets_path = '' -exclude_patterns = ['_build', '**.ipynb_checkpoints'] - -# this is tied to the temporary restriction to JAX versions <=0.4.6. See issue #190 -import os -os.environ["JAX_JIT_PJIT_API_MERGE"] = "0" \ No newline at end of file +exclude_patterns = ['_build', '**.ipynb_checkpoints'] \ No newline at end of file diff --git a/qiskit_dynamics/dispatch/backends/jax.py b/qiskit_dynamics/dispatch/backends/jax.py index ed104b9ef..0d5d55936 100644 --- a/qiskit_dynamics/dispatch/backends/jax.py +++ b/qiskit_dynamics/dispatch/backends/jax.py @@ -20,24 +20,6 @@ from jax import Array from jax.core import Tracer - # warning based on JAX version - from packaging import version - import warnings - - if version.parse(jax.__version__) >= version.parse("0.4.4"): - import os - - if ( - version.parse(jax.__version__) > version.parse("0.4.6") - or os.environ.get("JAX_JIT_PJIT_API_MERGE", None) != "0" - ): - warnings.warn( - "The functionality in the perturbation module of Qiskit Dynamics requires a JAX " - "version <= 0.4.6, due to a bug in JAX versions > 0.4.6. For versions 0.4.4, " - "0.4.5, and 0.4.6, using the perturbation module functionality requires setting " - "os.environ['JAX_JIT_PJIT_API_MERGE'] = '0' before importing JAX or Dynamics." - ) - JAX_TYPES = (Array, Tracer) from ..dispatch import Dispatch diff --git a/qiskit_dynamics/perturbation/dyson_magnus.py b/qiskit_dynamics/perturbation/dyson_magnus.py index 7642bcd25..741c76fc6 100644 --- a/qiskit_dynamics/perturbation/dyson_magnus.py +++ b/qiskit_dynamics/perturbation/dyson_magnus.py @@ -57,7 +57,7 @@ try: import jax.numpy as jnp - from jax.lax import scan, switch + from jax.lax import scan from jax import vmap except ImportError: pass @@ -472,17 +472,15 @@ def _setup_dyson_rhs_jax( custom_matmul = _CustomMatmul(lmult_rule, index_offset=1, backend="jax") - perturbations_evaluation_order = jnp.array(perturbations_evaluation_order, dtype=int) + perturbations_evaluation_order = np.array(perturbations_evaluation_order, dtype=int) new_list = [generator] + perturbations - def single_eval(idx, t): - return switch(idx, new_list, t) - - multiple_eval = vmap(single_eval, in_axes=(0, None)) + def multiple_eval(t): + return jnp.array([new_list[idx](t) for idx in perturbations_evaluation_order]) def dyson_rhs(t, y): - return custom_matmul(multiple_eval(perturbations_evaluation_order, t), y) + return custom_matmul(multiple_eval(t), y) return dyson_rhs diff --git a/qiskit_dynamics/solvers/lanczos.py b/qiskit_dynamics/solvers/lanczos.py index 55012c729..f10364895 100644 --- a/qiskit_dynamics/solvers/lanczos.py +++ b/qiskit_dynamics/solvers/lanczos.py @@ -57,7 +57,7 @@ def lanczos_basis(A: Union[csr_matrix, np.ndarray], y0: np.ndarray, k_dim: int): q_basis[[0], :] = y0.T projection = A @ y0 - alpha[0] = y0.conj().T @ projection + alpha[0] = np.sum(y0.conj() * projection) projection = projection - alpha[0] * y0 beta[0] = np.linalg.norm(projection) diff --git a/releasenotes/notes/update-jax-a50ce1b7d6b47219.yaml b/releasenotes/notes/update-jax-a50ce1b7d6b47219.yaml new file mode 100644 index 000000000..cf80785d4 --- /dev/null +++ b/releasenotes/notes/update-jax-a50ce1b7d6b47219.yaml @@ -0,0 +1,10 @@ +--- +issues: + - | + A JAX warning about casting complex values to real is raised when computing gradients of + simulations in Qiskit Dynamics. Note that this warning does not appear to signify any error + in numerical computation, and can be safely ignored. +upgrade: + - | + The upper bound on JAX and Diffrax in the last version of Qiskit Dynamics has been removed. + Users should try to use the latest version of JAX. \ No newline at end of file diff --git a/setup.py b/setup.py index 937f9d879..a8e1fcd68 100644 --- a/setup.py +++ b/setup.py @@ -25,8 +25,7 @@ "arraylias" ] -jax_extras = ['jax>=0.4.0, <= 0.4.6', - 'jaxlib>=0.4.0, <= 0.4.6'] +jax_extras = ['jax', 'jaxlib'] PACKAGES = setuptools.find_packages(exclude=['test*']) diff --git a/test/dynamics/__init__.py b/test/dynamics/__init__.py index af0a00fef..be625db2a 100644 --- a/test/dynamics/__init__.py +++ b/test/dynamics/__init__.py @@ -13,8 +13,3 @@ """ Qiskit Dynamics tests """ - -# temporarily disable a change in JAX 0.4.4 that introduced a bug. Must be run before importing JAX -import os - -os.environ["JAX_JIT_PJIT_API_MERGE"] = "0" diff --git a/test/dynamics/common.py b/test/dynamics/common.py index a369bb6e2..f3d8d343e 100644 --- a/test/dynamics/common.py +++ b/test/dynamics/common.py @@ -286,7 +286,8 @@ def jit_grad_wrap(self, func_to_test: Callable) -> Callable: Args: func_to_test: The function whose gradient will be graded. Returns: - JIT-compiled gradient of function.""" + JIT-compiled gradient of function. + """ wf = wrap(lambda f: jit(grad(f)), decorator=True) f = lambda *args: np.sum(func_to_test(*args)).real return wf(f) diff --git a/test/dynamics/perturbation/test_dyson_magnus.py b/test/dynamics/perturbation/test_dyson_magnus.py index 3d812cba4..e63cea459 100644 --- a/test/dynamics/perturbation/test_dyson_magnus.py +++ b/test/dynamics/perturbation/test_dyson_magnus.py @@ -718,3 +718,51 @@ def assertMultRulesEqual(self, rule1, rule2): for sub_rule1, sub_rule2 in zip(rule1, rule2): self.assertAllClose(sub_rule1[0], sub_rule2[0]) self.assertAllClose(sub_rule1[1], sub_rule2[1]) + + +class TestWorkaround(QiskitDynamicsTestCase): + """Test whether workaround in dyson_magnus._setup_dyson_rhs_jax is no longer required. + + The workaround was introduced in the same commit as this test class to avoid an error being + raised by a non-trivial combination of JAX transformations. The test in this class has been + set up to expect the original minimal reproduction of the issue to fail. Once it no longer + fails, the changes made to _setup_dyson_rhs_jax in this commit should be reverted. + + See https://github.com/google/jax/discussions/9951#discussioncomment-2385157 for discussion of + issue. + """ + + def test_minimal_example(self): + """Test minimal reproduction of issue.""" + + with self.assertRaises(Exception): + import jax.numpy as jnp + from jax import grad, vmap + from jax.lax import switch + from jax.experimental.ode import odeint + + # pylint: disable=unused-argument + def A0(t): + return 2.0 + + # pylint: disable=unused-argument + def A1(a, t): + return a**2 + + y0 = np.random.rand(2) + T = np.pi * 1.232 + + def test_func(a): + eval_list = [A0, lambda t: A1(a, t)] + + def single_eval(idx, t): + return switch(idx, eval_list, t) + + multiple_eval = vmap(single_eval, in_axes=(0, None)) + idx_list = jnp.array([0, 1]) + rhs = lambda y, t: multiple_eval(idx_list, t) * y + + out = odeint(rhs, y0=y0, t=jnp.array([0, T], dtype=float), atol=1e-13, rtol=1e-13) + return out + + jit(grad(lambda a: test_func(a)[-1][1].real))(1.0) diff --git a/test/dynamics/solvers/test_solver_classes.py b/test/dynamics/solvers/test_solver_classes.py index f45507a6e..492d3cbdb 100644 --- a/test/dynamics/solvers/test_solver_classes.py +++ b/test/dynamics/solvers/test_solver_classes.py @@ -957,8 +957,8 @@ def test_two_channel_SuperOp_simulation(self, model): schedules=sched, signals=signals, test_tol=1e-8, - atol=1e-11, - rtol=1e-11, + atol=1e-12, + rtol=1e-12, ) def test_4_channel_schedule(self): diff --git a/tox.ini b/tox.ini index c134ac558..b81ffb106 100644 --- a/tox.ini +++ b/tox.ini @@ -15,18 +15,16 @@ commands = stestr run {posargs} [testenv:jax] deps = -r{toxinidir}/requirements-dev.txt - jax<=0.4.6 - jaxlib<=0.4.6 - equinox<=0.10.3 - diffrax<=0.3.1 + jax + jaxlib + diffrax [testenv:lint] deps = -r{toxinidir}/requirements-dev.txt - jax<=0.4.6 - jaxlib<=0.4.6 - equinox<=0.10.3 - diffrax<=0.3.1 + jax + jaxlib + diffrax commands = black --check {posargs} qiskit_dynamics test pylint -rn -j 0 --rcfile={toxinidir}/.pylintrc qiskit_dynamics/ test/ @@ -41,8 +39,8 @@ commands = black {posargs} qiskit_dynamics test usedevelop = False deps = -r{toxinidir}/requirements-dev.txt - jax<=0.4.6 - jaxlib<=0.4.6 + jax + jaxlib diffrax commands = sphinx-build -j auto -W -T --keep-going {posargs} docs/ docs/_build/html