diff --git a/qiskit_dynamics/solvers/diffrax_solver.py b/qiskit_dynamics/solvers/diffrax_solver.py new file mode 100644 index 000000000..fca57e97d --- /dev/null +++ b/qiskit_dynamics/solvers/diffrax_solver.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- + +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +# pylint: disable=invalid-name + +""" +Wrapper for diffrax solvers +""" + +from typing import Callable, Optional, Union, Tuple, List +from scipy.integrate._ivp.ivp import OdeResult +from qiskit import QiskitError + +from qiskit_dynamics.dispatch import requires_backend +from qiskit_dynamics.array import Array, wrap + +from .solver_utils import merge_t_args + + +try: + from diffrax import ODETerm, SaveAt + from diffrax import diffeqsolve as _diffeqsolve + + from diffrax.solver import AbstractSolver # pylint: disable=unused-import + import jax.numpy as jnp +except ImportError as err: + pass + + +@requires_backend("jax") +def diffrax_solver( + rhs: Callable, + t_span: Array, + y0: Array, + method: "AbstractSolver", + t_eval: Optional[Union[Tuple, List, Array]] = None, + **kwargs, +): + """Routine for calling ``diffrax.diffeqsolve`` + + Args: + rhs: Callable of the form :math:`f(t, y)`. + t_span: Interval to solve over. + y0: Initial state. + method: Which diffeq solving method to use. + t_eval: Optional list of time points at which to return the solution. + **kwargs: Optional arguments to be passed to ``diffeqsolve``. + + Returns: + OdeResult: Results object. + + Raises: + QiskitError: Passing both `SaveAt` argument and `t_eval` argument. + """ + + t_list = merge_t_args(t_span, t_eval) + + # convert rhs and y0 to real + rhs = real_rhs(rhs) + y0 = c2r(y0) + + term = ODETerm(lambda t, y, _: Array(rhs(t.real, y), dtype=float).data) + + diffeqsolve = wrap(_diffeqsolve) + + if "saveat" in kwargs and t_eval is not None: + raise QiskitError( + """Only one of t_eval or saveat can be passed when using + a diffrax solver, but both were specified.""" + ) + + if t_eval is not None: + kwargs["saveat"] = SaveAt(ts=t_eval) + + results = diffeqsolve( + term, + solver=method, + t0=t_list[0], + t1=t_list[-1], + dt0=None, + y0=Array(y0, dtype=float), + **kwargs, + ) + + sol_dict = vars(results) + ys = sol_dict.pop("ys") + + ys = jnp.swapaxes(r2c(jnp.swapaxes(ys, 0, 1)), 0, 1) + + results_out = OdeResult(t=t_eval, y=Array(ys, backend="jax", dtype=complex), **sol_dict) + + return results_out + + +def real_rhs(rhs): + """Convert complex RHS to real RHS function""" + + def _real_rhs(t, y): + return c2r(rhs(t, r2c(y))) + + return _real_rhs + + +def c2r(arr): + """Convert complex array to a real array""" + return jnp.concatenate([jnp.real(Array(arr).data), jnp.imag(Array(arr).data)]) + + +def r2c(arr): + """Convert a real array to a complex array""" + size = arr.shape[0] // 2 + return arr[:size] + 1j * arr[size:] diff --git a/qiskit_dynamics/solvers/solver_functions.py b/qiskit_dynamics/solvers/solver_functions.py index 653eaf5b1..bea45b87f 100644 --- a/qiskit_dynamics/solvers/solver_functions.py +++ b/qiskit_dynamics/solvers/solver_functions.py @@ -43,6 +43,14 @@ ) from .scipy_solve_ivp import scipy_solve_ivp, SOLVE_IVP_METHODS from .jax_odeint import jax_odeint +from .diffrax_solver import diffrax_solver + +try: + from diffrax.solver import AbstractSolver + + diffrax_installed = True +except ImportError: + diffrax_installed = False ODE_METHODS = ( ["RK45", "RK23", "BDF", "DOP853", "Radau", "LSODA"] # scipy solvers @@ -56,7 +64,7 @@ def solve_ode( rhs: Union[Callable, BaseGeneratorModel], t_span: Array, y0: Array, - method: Optional[Union[str, OdeSolver]] = "DOP853", + method: Optional[Union[str, OdeSolver, "AbstractSolver"]] = "DOP853", t_eval: Optional[Union[Tuple, List, Array]] = None, **kwargs, ): @@ -106,7 +114,8 @@ def solve_ode( """ if method not in ODE_METHODS and not ( - isinstance(method, type) and issubclass(method, OdeSolver) + (isinstance(method, type) and (issubclass(method, OdeSolver))) + or (diffrax_installed and isinstance(method, AbstractSolver)) ): raise QiskitError("Method " + str(method) + " not supported by solve_ode.") @@ -122,6 +131,8 @@ def solve_ode( # solve the problem using specified method if method in SOLVE_IVP_METHODS or (isinstance(method, type) and issubclass(method, OdeSolver)): results = scipy_solve_ivp(solver_rhs, t_span, y0, method, t_eval=t_eval, **kwargs) + elif diffrax_installed and isinstance(method, AbstractSolver): + results = diffrax_solver(solver_rhs, t_span, y0, method=method, t_eval=t_eval, **kwargs) elif isinstance(method, str) and method == "RK4": results = RK4_solver(solver_rhs, t_span, y0, t_eval=t_eval, **kwargs) elif isinstance(method, str) and method == "jax_RK4": @@ -144,7 +155,7 @@ def solve_lmde( generator: Union[Callable, BaseGeneratorModel], t_span: Array, y0: Array, - method: Optional[Union[str, OdeSolver]] = "DOP853", + method: Optional[Union[str, OdeSolver, "AbstractSolver"]] = "DOP853", t_eval: Optional[Union[Tuple, List, Array]] = None, **kwargs, ): @@ -223,7 +234,13 @@ def solve_lmde( """ # delegate to solve_ode if necessary - if method in ODE_METHODS or (isinstance(method, type) and issubclass(method, OdeSolver)): + if method in ODE_METHODS or ( + isinstance(method, type) + and ( + issubclass(method, OdeSolver) + or (diffrax_installed and issubclass(method, AbstractSolver)) + ) + ): if isinstance(generator, BaseGeneratorModel): rhs = generator else: diff --git a/releasenotes/notes/add-diffrax-solvers-946869d5a304318a.yaml b/releasenotes/notes/add-diffrax-solvers-946869d5a304318a.yaml new file mode 100644 index 000000000..0920dbac7 --- /dev/null +++ b/releasenotes/notes/add-diffrax-solvers-946869d5a304318a.yaml @@ -0,0 +1,17 @@ +--- +features: + - | + Added support for solvers from the diffrax package: + https://github.com/patrick-kidger/diffrax. A new option + is enabled to pass in an object -- a solver from diffrax + instead of a string for a jax or scipy solver, for example:: + + from diffrax import Dopri5 + from qiskit-dynamics import solve_ode + + sol = solve_ode( + rhs: some_function, + t_span: some_t_span, + y0: some_initial_conditions, + method: Dopri5() + ) diff --git a/test/dynamics/common.py b/test/dynamics/common.py index 0c2c86b5f..3629d2c9c 100644 --- a/test/dynamics/common.py +++ b/test/dynamics/common.py @@ -101,6 +101,23 @@ def jit_grad_wrap(self, func_to_test: Callable) -> Callable: return wf(f) +class TestDiffraxBase(unittest.TestCase): + """Base class with setUpClass and tearDownClass for importing diffrax solvers + + Test cases that inherit from this class will automatically work with diffrax solvers + backend. + """ + + @classmethod + def setUpClass(cls): + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + import diffrax # pylint: disable=import-outside-toplevel,unused-import + except Exception as err: + raise unittest.SkipTest("Skipping diffrax tests.") from err + + class TestQutipBase(unittest.TestCase): """Base class for tests that utilize Qutip.""" diff --git a/test/dynamics/solvers/test_diffrax_DOP5.py b/test/dynamics/solvers/test_diffrax_DOP5.py new file mode 100644 index 000000000..540ece23b --- /dev/null +++ b/test/dynamics/solvers/test_diffrax_DOP5.py @@ -0,0 +1,160 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2020. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +# pylint: disable=invalid-name + +""" +Direct tests of diffrax_solver +""" + +import numpy as np + +from qiskit_dynamics.solvers.diffrax_solver import diffrax_solver + +from ..common import QiskitDynamicsTestCase, TestJaxBase + +try: + import jax.numpy as jnp + from jax.lax import cond + from diffrax import Dopri5, PIDController +# pylint: disable=broad-except +except Exception: + pass + + +class TestDiffraxDopri5(QiskitDynamicsTestCase, TestJaxBase): + """Test cases for diffrax_solver.""" + + def setUp(self): + + # pylint: disable=unused-argument + def simple_rhs(t, y): + return cond(t < 1.0, lambda s: s, lambda s: s**2, jnp.array([t])) + + self.simple_rhs = simple_rhs + + def test_t_eval_arg_no_overlap(self): + """Test handling of t_eval when no overlap with t_span.""" + + t_span = np.array([0.0, 2.0]) + t_eval = np.array([1.0, 1.5, 1.7]) + y0 = jnp.array([1.0]) + + stepsize_controller = PIDController(rtol=1e-10, atol=1e-10) + results = diffrax_solver( + self.simple_rhs, + t_span, + y0, + method=Dopri5(), + t_eval=t_eval, + stepsize_controller=stepsize_controller, + ) + + self.assertAllClose(t_eval, results.t) + + expected_y = jnp.array( + [ + [1 + 0.5], + [1 + 0.5 + (1.5**3 - 1.0**3) / 3], + [1 + 0.5 + (1.7**3 - 1.0**3) / 3], + ] + ) + + self.assertAllClose(expected_y, results.y) + + def test_t_eval_arg_no_overlap_backwards(self): + """Test handling of t_eval when no overlap with t_span with backwards integration.""" + + t_span = np.array([2.0, 0.0]) + t_eval = np.array([1.7, 1.5, 1.0]) + y0 = jnp.array([1 + 0.5 + (2.0**3 - 1.0**3) / 3]) + + stepsize_controller = PIDController(rtol=1e-10, atol=1e-10) + results = diffrax_solver( + self.simple_rhs, + t_span, + y0, + method=Dopri5(), + t_eval=t_eval, + stepsize_controller=stepsize_controller, + ) + + self.assertAllClose(t_eval, results.t) + + expected_y = jnp.array( + [ + [1 + 0.5 + (1.7**3 - 1.0**3) / 3], + [1 + 0.5 + (1.5**3 - 1.0**3) / 3], + [1 + 0.5], + ] + ) + + self.assertAllClose(expected_y, results.y) + + def test_t_eval_arg_overlap(self): + """Test handling of t_eval with overlap with t_span.""" + + t_span = np.array([0.0, 2.0]) + t_eval = np.array([1.0, 1.5, 1.7, 2.0]) + y0 = jnp.array([1.0]) + + stepsize_controller = PIDController(rtol=1e-10, atol=1e-10) + results = diffrax_solver( + self.simple_rhs, + t_span, + y0, + method=Dopri5(), + t_eval=t_eval, + stepsize_controller=stepsize_controller, + ) + + self.assertAllClose(t_eval, results.t) + + expected_y = jnp.array( + [ + [1 + 0.5], + [1 + 0.5 + (1.5**3 - 1.0**3) / 3], + [1 + 0.5 + (1.7**3 - 1.0**3) / 3], + [1 + 0.5 + (2**3 - 1.0**3) / 3], + ] + ) + + self.assertAllClose(expected_y, results.y) + + def test_t_eval_arg_overlap_backwards(self): + """Test handling of t_eval with overlap with t_span with backwards integration.""" + + t_span = np.array([2.0, 0.0]) + t_eval = np.array([2.0, 1.7, 1.5, 1.0]) + y0 = jnp.array([1 + 0.5 + (2.0**3 - 1.0**3) / 3]) + + stepsize_controller = PIDController(rtol=1e-10, atol=1e-10) + results = diffrax_solver( + self.simple_rhs, + t_span, + y0, + method=Dopri5(), + t_eval=t_eval, + stepsize_controller=stepsize_controller, + ) + + self.assertAllClose(t_eval, results.t) + + expected_y = jnp.array( + [ + [1 + 0.5 + (2**3 - 1.0**3) / 3], + [1 + 0.5 + (1.7**3 - 1.0**3) / 3], + [1 + 0.5 + (1.5**3 - 1.0**3) / 3], + [1 + 0.5], + ] + ) + + self.assertAllClose(expected_y, results.y) diff --git a/test/dynamics/solvers/test_solver_functions.py b/test/dynamics/solvers/test_solver_functions.py index 891be7c48..e7765c86c 100644 --- a/test/dynamics/solvers/test_solver_functions.py +++ b/test/dynamics/solvers/test_solver_functions.py @@ -29,7 +29,12 @@ from qiskit_dynamics import solve_ode, solve_lmde from qiskit_dynamics.array import Array -from ..common import QiskitDynamicsTestCase, TestJaxBase +from ..common import QiskitDynamicsTestCase, TestDiffraxBase, TestJaxBase + +try: + from diffrax import PIDController, Tsit5, Dopri5 +except ImportError: + pass class TestSolverMethod(ABC, QiskitDynamicsTestCase): @@ -140,6 +145,18 @@ def test_basic_model(self): self.assertAllClose(results.y[-1], expected, atol=self.tol, rtol=self.tol) + def test_backwards_solving(self): + """Test case for reversed basic model.""" + + reverse_t_span = self.t_span.copy() + reverse_t_span.reverse() + + reverse_y0 = expm(-1j * np.pi * self.X.data) + + results = self.solve(self.basic_rhs, t_span=reverse_t_span, y0=reverse_y0) + + self.assertAllClose(results.y[-1], self.y0, atol=self.tol, rtol=self.tol) + def test_w_GeneratorModel(self): """Test on a GeneratorModel.""" @@ -392,5 +409,45 @@ def is_ode_method(self): return True +class Testdiffrax_DOP5(TestSolverMethodJax, TestDiffraxBase): + """Tests for diffrax Dopri5 method.""" + + def solve(self, rhs, t_span, y0, t_eval=None, **kwargs): + stepsize_controller = PIDController(atol=1e-10, rtol=1e-10) + return solve_ode( + rhs=rhs, + t_span=t_span, + y0=y0, + method=Dopri5(), + t_eval=t_eval, + stepsize_controller=stepsize_controller, + **kwargs, + ) + + @property + def is_ode_method(self): + return True + + +class Testdiffrax_Tsit5(TestSolverMethodJax, TestDiffraxBase): + """Tests for diffrax Tsit5 method.""" + + def solve(self, rhs, t_span, y0, t_eval=None, **kwargs): + stepsize_controller = PIDController(atol=1e-10, rtol=1e-10) + return solve_ode( + rhs=rhs, + t_span=t_span, + y0=y0, + method=Tsit5(), + t_eval=t_eval, + stepsize_controller=stepsize_controller, + **kwargs, + ) + + @property + def is_ode_method(self): + return True + + # delete abstract classes so unittest doesn't attempt to run them del TestSolverMethod, TestSolverMethodJax diff --git a/tox.ini b/tox.ini index 147f81476..fbe9cc299 100644 --- a/tox.ini +++ b/tox.ini @@ -17,12 +17,14 @@ deps = -r{toxinidir}/requirements-dev.txt jax jaxlib + diffrax [testenv:lint] deps = -r{toxinidir}/requirements-dev.txt jax jaxlib + diffrax commands = black --check {posargs} qiskit_dynamics test pylint -rn -j 0 --rcfile={toxinidir}/.pylintrc qiskit_dynamics/ test/ @@ -37,6 +39,7 @@ deps = -r{toxinidir}/requirements-dev.txt jax jaxlib + diffrax commands = sphinx-build -b html -W {posargs} docs/ docs/_build/html