Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add diffrax solvers #104

Merged
merged 53 commits into from
Jun 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
b4a5a4b
diffrax solver testing
Apr 5, 2022
622ec10
Merge branch 'Qiskit:main' into diffrax
brosand Apr 5, 2022
eef98b5
adding some diffrax testing
Apr 6, 2022
fae5f2c
diffrax changes not working
Apr 25, 2022
2241387
about to remove old stuff from diffrax solver
Apr 27, 2022
45f273d
solver running with solve_ode
Apr 28, 2022
b55ebc4
diffrax solver test working
Apr 29, 2022
b51c986
black run and tests still working (still need to figure out t_direction)
Apr 29, 2022
096ec99
Merge branch 'Qiskit:main' into diffrax
brosand Apr 29, 2022
3077cd8
adding kwargs to diffrax solver by removing tol from kwargs
Apr 29, 2022
e1e3d5c
allow user to pass stepsize_controller
May 9, 2022
1ff510d
adding test for reversible differentiation
May 9, 2022
e16d550
adding diffrax testing without diffrax dependence
May 10, 2022
c496e39
Merge branch 'Qiskit:main' into diffrax
brosand May 11, 2022
6fdedc1
removing extra files
May 16, 2022
38508cb
removing gitignore changes
May 16, 2022
243cbf6
making diffrax solver mandatory
May 16, 2022
5ad0bca
removed Dopri5 import
May 16, 2022
6ec8a61
making diffrax work without install
May 17, 2022
b9f12c9
removing 1 boolean expression to solve lint error
May 17, 2022
57009b1
adding some more diffrax tests
May 17, 2022
1c2888e
linting
May 17, 2022
895c38d
adding releasenote
May 17, 2022
3416a44
slight readability change to release notes
May 17, 2022
6f1e51b
adding example to releasenotes
May 17, 2022
b6a1977
diffrax import bug potential solve
May 18, 2022
3cbad7b
diffrax import bug potentially fixed
May 18, 2022
afc53fa
linting and diffrax import stuff
May 18, 2022
3f94628
get diffrax wasn't being called in solver function one part
May 18, 2022
b7d1056
Revert "get diffrax wasn't being called in solver function one part"
May 18, 2022
2ea3c11
Revert "linting and diffrax import stuff"
May 18, 2022
8b1e1c4
Revert "diffrax import bug potentially fixed"
May 18, 2022
d9c6db7
Revert "diffrax import bug potential solve"
May 18, 2022
e6b871b
reverting get_diffrax and using new version of diffrax
May 18, 2022
4b0a6fa
Merge branch 'Qiskit:main' into diffrax
brosand May 18, 2022
d289342
Update tox.ini for diffrax update
brosand May 19, 2022
6d28544
removing atol from diffrax solver and using stepsize_controller
May 20, 2022
59239e2
add error for saveat and clean docs
May 20, 2022
1cb6c45
slight lint changes
May 20, 2022
e1b5054
changed saveat to take t_list not t_eval
May 20, 2022
3f39f58
docs
May 20, 2022
9f199fc
fixing t_eval stuff
May 20, 2022
2278ba6
removing class diffrax solver definitions
May 23, 2022
5c324b7
removing diffrax test dopri5 import
May 23, 2022
dc385ca
making type hint in diffrax solver not a string, see if that breaks t…
May 23, 2022
6caba34
importing abstract solver in diffrax_solver
May 24, 2022
a4ecc2f
trying import abstract solver for diffrax_solver
May 24, 2022
97af46b
abstract solver string again, disabling unused import
May 25, 2022
61d9666
slight change to solve definition list ends without blank line
May 25, 2022
b05e790
slight change to solve definition list ends without blank line
May 25, 2022
abe7223
Merge branch 'Qiskit:main' into diffrax
brosand Jun 3, 2022
d4d1476
diffrax slight changes
Jun 3, 2022
ce49721
switch to swapaxes
Jun 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions qiskit_dynamics/solvers/diffrax_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2022.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
# pylint: disable=invalid-name

"""
Wrapper for diffrax solvers
"""

from typing import Callable, Optional, Union, Tuple, List
from scipy.integrate._ivp.ivp import OdeResult
from qiskit import QiskitError

from qiskit_dynamics.dispatch import requires_backend
from qiskit_dynamics.array import Array, wrap

from .solver_utils import merge_t_args


try:
from diffrax import ODETerm, SaveAt
from diffrax import diffeqsolve as _diffeqsolve

from diffrax.solver import AbstractSolver # pylint: disable=unused-import
import jax.numpy as jnp
except ImportError as err:
pass


@requires_backend("jax")
def diffrax_solver(
rhs: Callable,
t_span: Array,
y0: Array,
method: "AbstractSolver",
t_eval: Optional[Union[Tuple, List, Array]] = None,
**kwargs,
):
"""Routine for calling ``diffrax.diffeqsolve``

Args:
rhs: Callable of the form :math:`f(t, y)`.
t_span: Interval to solve over.
y0: Initial state.
method: Which diffeq solving method to use.
t_eval: Optional list of time points at which to return the solution.
**kwargs: Optional arguments to be passed to ``diffeqsolve``.

Returns:
OdeResult: Results object.

Raises:
QiskitError: Passing both `SaveAt` argument and `t_eval` argument.
"""

t_list = merge_t_args(t_span, t_eval)

# convert rhs and y0 to real
rhs = real_rhs(rhs)
y0 = c2r(y0)

term = ODETerm(lambda t, y, _: Array(rhs(t.real, y), dtype=float).data)

diffeqsolve = wrap(_diffeqsolve)

if "saveat" in kwargs and t_eval is not None:
raise QiskitError(
brosand marked this conversation as resolved.
Show resolved Hide resolved
"""Only one of t_eval or saveat can be passed when using
a diffrax solver, but both were specified."""
)

if t_eval is not None:
brosand marked this conversation as resolved.
Show resolved Hide resolved
kwargs["saveat"] = SaveAt(ts=t_eval)

results = diffeqsolve(
term,
solver=method,
t0=t_list[0],
t1=t_list[-1],
dt0=None,
y0=Array(y0, dtype=float),
**kwargs,
)

sol_dict = vars(results)
ys = sol_dict.pop("ys")

ys = jnp.swapaxes(r2c(jnp.swapaxes(ys, 0, 1)), 0, 1)

results_out = OdeResult(t=t_eval, y=Array(ys, backend="jax", dtype=complex), **sol_dict)

return results_out


def real_rhs(rhs):
brosand marked this conversation as resolved.
Show resolved Hide resolved
"""Convert complex RHS to real RHS function"""

def _real_rhs(t, y):
return c2r(rhs(t, r2c(y)))

return _real_rhs


def c2r(arr):
"""Convert complex array to a real array"""
return jnp.concatenate([jnp.real(Array(arr).data), jnp.imag(Array(arr).data)])


def r2c(arr):
"""Convert a real array to a complex array"""
size = arr.shape[0] // 2
return arr[:size] + 1j * arr[size:]
25 changes: 21 additions & 4 deletions qiskit_dynamics/solvers/solver_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@
)
from .scipy_solve_ivp import scipy_solve_ivp, SOLVE_IVP_METHODS
from .jax_odeint import jax_odeint
from .diffrax_solver import diffrax_solver

try:
from diffrax.solver import AbstractSolver

diffrax_installed = True
except ImportError:
diffrax_installed = False

ODE_METHODS = (
["RK45", "RK23", "BDF", "DOP853", "Radau", "LSODA"] # scipy solvers
Expand All @@ -56,7 +64,7 @@ def solve_ode(
rhs: Union[Callable, BaseGeneratorModel],
t_span: Array,
y0: Array,
method: Optional[Union[str, OdeSolver]] = "DOP853",
method: Optional[Union[str, OdeSolver, "AbstractSolver"]] = "DOP853",
t_eval: Optional[Union[Tuple, List, Array]] = None,
**kwargs,
):
Expand Down Expand Up @@ -106,7 +114,8 @@ def solve_ode(
"""

if method not in ODE_METHODS and not (
isinstance(method, type) and issubclass(method, OdeSolver)
(isinstance(method, type) and (issubclass(method, OdeSolver)))
or (diffrax_installed and isinstance(method, AbstractSolver))
):
raise QiskitError("Method " + str(method) + " not supported by solve_ode.")

Expand All @@ -122,6 +131,8 @@ def solve_ode(
# solve the problem using specified method
if method in SOLVE_IVP_METHODS or (isinstance(method, type) and issubclass(method, OdeSolver)):
results = scipy_solve_ivp(solver_rhs, t_span, y0, method, t_eval=t_eval, **kwargs)
elif diffrax_installed and isinstance(method, AbstractSolver):
results = diffrax_solver(solver_rhs, t_span, y0, method=method, t_eval=t_eval, **kwargs)
elif isinstance(method, str) and method == "RK4":
results = RK4_solver(solver_rhs, t_span, y0, t_eval=t_eval, **kwargs)
elif isinstance(method, str) and method == "jax_RK4":
Expand All @@ -144,7 +155,7 @@ def solve_lmde(
generator: Union[Callable, BaseGeneratorModel],
t_span: Array,
y0: Array,
method: Optional[Union[str, OdeSolver]] = "DOP853",
method: Optional[Union[str, OdeSolver, "AbstractSolver"]] = "DOP853",
t_eval: Optional[Union[Tuple, List, Array]] = None,
**kwargs,
):
Expand Down Expand Up @@ -223,7 +234,13 @@ def solve_lmde(
"""

# delegate to solve_ode if necessary
if method in ODE_METHODS or (isinstance(method, type) and issubclass(method, OdeSolver)):
if method in ODE_METHODS or (
isinstance(method, type)
and (
issubclass(method, OdeSolver)
or (diffrax_installed and issubclass(method, AbstractSolver))
)
):
if isinstance(generator, BaseGeneratorModel):
rhs = generator
else:
Expand Down
17 changes: 17 additions & 0 deletions releasenotes/notes/add-diffrax-solvers-946869d5a304318a.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
---
features:
- |
Added support for solvers from the diffrax package:
https://github.com/patrick-kidger/diffrax. A new option
is enabled to pass in an object -- a solver from diffrax
instead of a string for a jax or scipy solver, for example::

from diffrax import Dopri5
from qiskit-dynamics import solve_ode

sol = solve_ode(
brosand marked this conversation as resolved.
Show resolved Hide resolved
rhs: some_function,
t_span: some_t_span,
y0: some_initial_conditions,
method: Dopri5()
)
17 changes: 17 additions & 0 deletions test/dynamics/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,23 @@ def jit_grad_wrap(self, func_to_test: Callable) -> Callable:
return wf(f)


class TestDiffraxBase(unittest.TestCase):
"""Base class with setUpClass and tearDownClass for importing diffrax solvers

Test cases that inherit from this class will automatically work with diffrax solvers
backend.
"""

@classmethod
def setUpClass(cls):
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
import diffrax # pylint: disable=import-outside-toplevel,unused-import
except Exception as err:
raise unittest.SkipTest("Skipping diffrax tests.") from err


class TestQutipBase(unittest.TestCase):
"""Base class for tests that utilize Qutip."""

Expand Down
160 changes: 160 additions & 0 deletions test/dynamics/solvers/test_diffrax_DOP5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# This code is part of Qiskit.
#
# (C) Copyright IBM 2020.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
# pylint: disable=invalid-name

"""
Direct tests of diffrax_solver
"""

import numpy as np

from qiskit_dynamics.solvers.diffrax_solver import diffrax_solver

from ..common import QiskitDynamicsTestCase, TestJaxBase

try:
import jax.numpy as jnp
from jax.lax import cond
from diffrax import Dopri5, PIDController
# pylint: disable=broad-except
except Exception:
pass


class TestDiffraxDopri5(QiskitDynamicsTestCase, TestJaxBase):
"""Test cases for diffrax_solver."""

def setUp(self):

# pylint: disable=unused-argument
def simple_rhs(t, y):
return cond(t < 1.0, lambda s: s, lambda s: s**2, jnp.array([t]))

self.simple_rhs = simple_rhs

def test_t_eval_arg_no_overlap(self):
"""Test handling of t_eval when no overlap with t_span."""

t_span = np.array([0.0, 2.0])
t_eval = np.array([1.0, 1.5, 1.7])
y0 = jnp.array([1.0])

stepsize_controller = PIDController(rtol=1e-10, atol=1e-10)
results = diffrax_solver(
self.simple_rhs,
t_span,
y0,
method=Dopri5(),
t_eval=t_eval,
stepsize_controller=stepsize_controller,
)

self.assertAllClose(t_eval, results.t)

expected_y = jnp.array(
[
[1 + 0.5],
[1 + 0.5 + (1.5**3 - 1.0**3) / 3],
[1 + 0.5 + (1.7**3 - 1.0**3) / 3],
]
)

self.assertAllClose(expected_y, results.y)

def test_t_eval_arg_no_overlap_backwards(self):
"""Test handling of t_eval when no overlap with t_span with backwards integration."""

t_span = np.array([2.0, 0.0])
t_eval = np.array([1.7, 1.5, 1.0])
y0 = jnp.array([1 + 0.5 + (2.0**3 - 1.0**3) / 3])

stepsize_controller = PIDController(rtol=1e-10, atol=1e-10)
results = diffrax_solver(
self.simple_rhs,
t_span,
y0,
method=Dopri5(),
t_eval=t_eval,
stepsize_controller=stepsize_controller,
)

self.assertAllClose(t_eval, results.t)

expected_y = jnp.array(
[
[1 + 0.5 + (1.7**3 - 1.0**3) / 3],
[1 + 0.5 + (1.5**3 - 1.0**3) / 3],
[1 + 0.5],
]
)

self.assertAllClose(expected_y, results.y)

def test_t_eval_arg_overlap(self):
"""Test handling of t_eval with overlap with t_span."""

t_span = np.array([0.0, 2.0])
t_eval = np.array([1.0, 1.5, 1.7, 2.0])
y0 = jnp.array([1.0])

stepsize_controller = PIDController(rtol=1e-10, atol=1e-10)
results = diffrax_solver(
self.simple_rhs,
t_span,
y0,
method=Dopri5(),
t_eval=t_eval,
stepsize_controller=stepsize_controller,
)

self.assertAllClose(t_eval, results.t)

expected_y = jnp.array(
[
[1 + 0.5],
[1 + 0.5 + (1.5**3 - 1.0**3) / 3],
[1 + 0.5 + (1.7**3 - 1.0**3) / 3],
[1 + 0.5 + (2**3 - 1.0**3) / 3],
]
)

self.assertAllClose(expected_y, results.y)

def test_t_eval_arg_overlap_backwards(self):
"""Test handling of t_eval with overlap with t_span with backwards integration."""

t_span = np.array([2.0, 0.0])
t_eval = np.array([2.0, 1.7, 1.5, 1.0])
y0 = jnp.array([1 + 0.5 + (2.0**3 - 1.0**3) / 3])

stepsize_controller = PIDController(rtol=1e-10, atol=1e-10)
results = diffrax_solver(
self.simple_rhs,
t_span,
y0,
method=Dopri5(),
t_eval=t_eval,
stepsize_controller=stepsize_controller,
)

self.assertAllClose(t_eval, results.t)

expected_y = jnp.array(
[
[1 + 0.5 + (2**3 - 1.0**3) / 3],
[1 + 0.5 + (1.7**3 - 1.0**3) / 3],
[1 + 0.5 + (1.5**3 - 1.0**3) / 3],
[1 + 0.5],
]
)

self.assertAllClose(expected_y, results.y)
Loading