Skip to content

Commit

Permalink
upgrading diffrax
Browse files Browse the repository at this point in the history
  • Loading branch information
DanPuzzuoli committed Oct 30, 2023
1 parent a8ee2cf commit 6968818
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
20 changes: 9 additions & 11 deletions qiskit_dynamics/solvers/diffrax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from scipy.integrate._ivp.ivp import OdeResult
from qiskit import QiskitError

from qiskit_dynamics.arraylias import ArrayLike
from qiskit_dynamics.dispatch import requires_backend
from qiskit_dynamics.array import Array, wrap

try:
import jax.numpy as jnp
Expand All @@ -33,10 +33,10 @@
@requires_backend("jax")
def diffrax_solver(
rhs: Callable,
t_span: Array,
y0: Array,
t_span: ArrayLike,
y0: ArrayLike,
method: "AbstractSolver",
t_eval: Optional[Union[Tuple, List, Array]] = None,
t_eval: Optional[ArrayLike] = None,
**kwargs,
):
"""Routine for calling ``diffrax.diffeqsolve``
Expand All @@ -57,15 +57,13 @@ def diffrax_solver(
"""

from diffrax import ODETerm, SaveAt
from diffrax import diffeqsolve as _diffeqsolve

diffeqsolve = wrap(_diffeqsolve)
from diffrax import diffeqsolve

# convert rhs and y0 to real
rhs = real_rhs(rhs)
y0 = c2r(y0)

term = ODETerm(lambda t, y, _: Array(rhs(t.real, y), dtype=float).data)
term = ODETerm(lambda t, y, _: rhs(t.real, y))

if "saveat" in kwargs and t_eval is not None:
raise QiskitError(
Expand All @@ -82,7 +80,7 @@ def diffrax_solver(
t0=t_span[0],
t1=t_span[-1],
dt0=None,
y0=Array(y0, dtype=float),
y0=jnp.array(y0, dtype=float),
**kwargs,
)

Expand All @@ -92,7 +90,7 @@ def diffrax_solver(

ys = jnp.swapaxes(r2c(jnp.swapaxes(ys, 0, 1)), 0, 1)

results_out = OdeResult(t=ts, y=Array(ys, backend="jax", dtype=complex), **sol_dict)
results_out = OdeResult(t=ts, y=jnp.array(ys, dtype=complex), **sol_dict)

return results_out

Expand All @@ -108,7 +106,7 @@ def _real_rhs(t, y):

def c2r(arr):
"""Convert complex array to a real array"""
return jnp.concatenate([jnp.real(Array(arr).data), jnp.imag(Array(arr).data)])
return jnp.concatenate([jnp.real(arr), jnp.imag(arr)])


def r2c(arr):
Expand Down
4 changes: 2 additions & 2 deletions test/dynamics/solvers/test_diffrax_DOP5.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from qiskit_dynamics.solvers.diffrax_solver import diffrax_solver

from ..common import QiskitDynamicsTestCase, TestJaxBase
from ..common import JAXTestBase

try:
import jax.numpy as jnp
Expand All @@ -30,7 +30,7 @@
pass


class TestDiffraxDopri5(QiskitDynamicsTestCase, TestJaxBase):
class TestDiffraxDopri5(JAXTestBase):
"""Test cases for diffrax_solver."""

def setUp(self):
Expand Down

0 comments on commit 6968818

Please sign in to comment.