Skip to content

Commit

Permalink
Updating solvers module to use arraylias (#281)
Browse files Browse the repository at this point in the history
Co-authored-by: Kento Ueda <[email protected]>
  • Loading branch information
DanPuzzuoli and to24toro committed Feb 5, 2024
1 parent 8e5d884 commit 5558ed6
Show file tree
Hide file tree
Showing 19 changed files with 520 additions and 423 deletions.
2 changes: 1 addition & 1 deletion qiskit_dynamics/models/hamiltonian_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def is_hermitian(operator: ArrayLike, tol: Optional[float] = 1e-10) -> bool:
return is_hermitian(operator.todense())
elif isinstance(operator, ArrayLike):
adj = None
adj = np.transpose(np.conjugate(operator))
adj = unp.transpose(unp.conjugate(operator))
return np.linalg.norm(adj - operator) < tol

raise QiskitError("is_hermitian got an unexpected type.")
22 changes: 10 additions & 12 deletions qiskit_dynamics/solvers/diffrax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
Wrapper for diffrax solvers
"""

from typing import Callable, Optional, Union, Tuple, List
from typing import Callable, Optional
from scipy.integrate._ivp.ivp import OdeResult
from qiskit import QiskitError

from qiskit_dynamics.arraylias import ArrayLike
from qiskit_dynamics.dispatch import requires_backend
from qiskit_dynamics.array import Array, wrap

try:
import jax.numpy as jnp
Expand All @@ -33,10 +33,10 @@
@requires_backend("jax")
def diffrax_solver(
rhs: Callable,
t_span: Array,
y0: Array,
t_span: ArrayLike,
y0: ArrayLike,
method: "AbstractSolver",
t_eval: Optional[Union[Tuple, List, Array]] = None,
t_eval: Optional[ArrayLike] = None,
**kwargs,
):
"""Routine for calling ``diffrax.diffeqsolve``
Expand All @@ -57,15 +57,13 @@ def diffrax_solver(
"""

from diffrax import ODETerm, SaveAt
from diffrax import diffeqsolve as _diffeqsolve

diffeqsolve = wrap(_diffeqsolve)
from diffrax import diffeqsolve

# convert rhs and y0 to real
rhs = real_rhs(rhs)
y0 = c2r(y0)

term = ODETerm(lambda t, y, _: Array(rhs(t.real, y), dtype=float).data)
term = ODETerm(lambda t, y, _: rhs(t.real, y))

if "saveat" in kwargs and t_eval is not None:
raise QiskitError(
Expand All @@ -82,7 +80,7 @@ def diffrax_solver(
t0=t_span[0],
t1=t_span[-1],
dt0=None,
y0=Array(y0, dtype=float),
y0=jnp.array(y0, dtype=float),
**kwargs,
)

Expand All @@ -92,7 +90,7 @@ def diffrax_solver(

ys = jnp.swapaxes(r2c(jnp.swapaxes(ys, 0, 1)), 0, 1)

results_out = OdeResult(t=ts, y=Array(ys, backend="jax", dtype=complex), **sol_dict)
results_out = OdeResult(t=ts, y=jnp.array(ys, dtype=complex), **sol_dict)

return results_out

Expand All @@ -108,7 +106,7 @@ def _real_rhs(t, y):

def c2r(arr):
"""Convert complex array to a real array"""
return jnp.concatenate([jnp.real(Array(arr).data), jnp.imag(Array(arr).data)])
return jnp.concatenate([jnp.real(arr), jnp.imag(arr)])


def r2c(arr):
Expand Down
Loading

0 comments on commit 5558ed6

Please sign in to comment.