Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating solvers module to use arraylias #281

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion qiskit_dynamics/models/hamiltonian_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def is_hermitian(operator: ArrayLike, tol: Optional[float] = 1e-10) -> bool:
return is_hermitian(operator.todense())
elif isinstance(operator, ArrayLike):
adj = None
adj = np.transpose(np.conjugate(operator))
adj = unp.transpose(unp.conjugate(operator))
return np.linalg.norm(adj - operator) < tol

raise QiskitError("is_hermitian got an unexpected type.")
22 changes: 10 additions & 12 deletions qiskit_dynamics/solvers/diffrax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
Wrapper for diffrax solvers
"""

from typing import Callable, Optional, Union, Tuple, List
from typing import Callable, Optional
from scipy.integrate._ivp.ivp import OdeResult
from qiskit import QiskitError

from qiskit_dynamics.arraylias import ArrayLike
from qiskit_dynamics.dispatch import requires_backend
from qiskit_dynamics.array import Array, wrap

try:
import jax.numpy as jnp
Expand All @@ -33,10 +33,10 @@
@requires_backend("jax")
def diffrax_solver(
rhs: Callable,
t_span: Array,
y0: Array,
t_span: ArrayLike,
y0: ArrayLike,
method: "AbstractSolver",
t_eval: Optional[Union[Tuple, List, Array]] = None,
t_eval: Optional[ArrayLike] = None,
**kwargs,
):
"""Routine for calling ``diffrax.diffeqsolve``
Expand All @@ -57,15 +57,13 @@ def diffrax_solver(
"""

from diffrax import ODETerm, SaveAt
from diffrax import diffeqsolve as _diffeqsolve

diffeqsolve = wrap(_diffeqsolve)
from diffrax import diffeqsolve

# convert rhs and y0 to real
rhs = real_rhs(rhs)
y0 = c2r(y0)

term = ODETerm(lambda t, y, _: Array(rhs(t.real, y), dtype=float).data)
term = ODETerm(lambda t, y, _: rhs(t.real, y))

if "saveat" in kwargs and t_eval is not None:
raise QiskitError(
Expand All @@ -82,7 +80,7 @@ def diffrax_solver(
t0=t_span[0],
t1=t_span[-1],
dt0=None,
y0=Array(y0, dtype=float),
y0=jnp.array(y0, dtype=float),
**kwargs,
)

Expand All @@ -92,7 +90,7 @@ def diffrax_solver(

ys = jnp.swapaxes(r2c(jnp.swapaxes(ys, 0, 1)), 0, 1)

results_out = OdeResult(t=ts, y=Array(ys, backend="jax", dtype=complex), **sol_dict)
results_out = OdeResult(t=ts, y=jnp.array(ys, dtype=complex), **sol_dict)

return results_out

Expand All @@ -108,7 +106,7 @@ def _real_rhs(t, y):

def c2r(arr):
"""Convert complex array to a real array"""
return jnp.concatenate([jnp.real(Array(arr).data), jnp.imag(Array(arr).data)])
return jnp.concatenate([jnp.real(arr), jnp.imag(arr)])


def r2c(arr):
Expand Down
Loading