Skip to content

Commit

Permalink
partial progress on solver_functions, but blocked by models
Browse files Browse the repository at this point in the history
  • Loading branch information
DanPuzzuoli committed Oct 30, 2023
1 parent bee2878 commit 8f6e46f
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 32 deletions.
37 changes: 20 additions & 17 deletions qiskit_dynamics/solvers/solver_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from scipy.integrate._ivp.ivp import OdeResult # pylint: disable=unused-import

from qiskit import QiskitError
from qiskit_dynamics.array import Array
from qiskit_dynamics import DYNAMICS_NUMPY as unp
from qiskit_dynamics.arraylias import ArrayLike

from qiskit_dynamics.models import (
BaseGeneratorModel,
Expand Down Expand Up @@ -95,11 +96,13 @@ def _is_diffrax_method(method: any) -> bool:

def _lanczos_validation(
rhs: Union[Callable, BaseGeneratorModel],
t_span: Array,
y0: Array,
t_span: ArrayLike,
y0: ArrayLike,
k_dim: int,
):
"""Validation checks to run lanczos based solvers."""
t_span = unp.asarray(t_span)
y0 = unp.asarray(y0)
if isinstance(rhs, BaseGeneratorModel):
if not isinstance(rhs, HamiltonianModel):
raise QiskitError(
Expand All @@ -124,10 +127,10 @@ def _lanczos_validation(

def solve_ode(
rhs: Union[Callable, BaseGeneratorModel],
t_span: Array,
y0: Array,
t_span: ArrayLike,
y0: ArrayLike,
method: Optional[Union[str, OdeSolver, DiffraxAbstractSolver]] = "DOP853",
t_eval: Optional[Union[Tuple, List, Array]] = None,
t_eval: Optional[ArrayLike] = None,
**kwargs,
):
r"""General interface for solving Ordinary Differential Equations (ODEs).
Expand Down Expand Up @@ -181,7 +184,7 @@ def solve_ode(
):
raise QiskitError("Method " + str(method) + " not supported by solve_ode.")

y0 = Array(y0)
y0 = unp.asarray(y0)

if isinstance(rhs, BaseGeneratorModel):
_, solver_rhs, y0, model_in_frame_basis = setup_generator_model_rhs_y0_in_frame_basis(
Expand All @@ -205,7 +208,7 @@ def solve_ode(
# convert results out of frame basis if necessary
if isinstance(rhs, BaseGeneratorModel):
if not model_in_frame_basis:
results.y = results_y_out_of_frame_basis(rhs, Array(results.y), y0.ndim)
results.y = results_y_out_of_frame_basis(rhs, results.y, y0.ndim)

# convert model back to original basis
rhs.in_frame_basis = model_in_frame_basis
Expand All @@ -215,10 +218,10 @@ def solve_ode(

def solve_lmde(
generator: Union[Callable, BaseGeneratorModel],
t_span: Array,
y0: Array,
t_span: ArrayLike,
y0: ArrayLike,
method: Optional[Union[str, OdeSolver, DiffraxAbstractSolver]] = "DOP853",
t_eval: Optional[Union[Tuple, List, Array]] = None,
t_eval: Optional[ArrayLike] = None,
**kwargs,
):
r"""General interface for solving Linear Matrix Differential Equations (LMDEs)
Expand Down Expand Up @@ -333,7 +336,7 @@ def rhs(t, y):
vectorized evaluation mode."""
)

y0 = Array(y0)
y0 = unp.asarray(y0)

# setup generator and rhs functions to pass to numerical methods
if isinstance(generator, BaseGeneratorModel):
Expand Down Expand Up @@ -363,16 +366,16 @@ def rhs(t, y):
# convert results to correct basis if necessary
if isinstance(generator, BaseGeneratorModel):
if not model_in_frame_basis:
results.y = results_y_out_of_frame_basis(generator, Array(results.y), y0.ndim)
results.y = results_y_out_of_frame_basis(generator, results.y, y0.ndim)

generator.in_frame_basis = model_in_frame_basis

return results


def setup_generator_model_rhs_y0_in_frame_basis(
generator_model: BaseGeneratorModel, y0: Array
) -> Tuple[Callable, Callable, Array]:
generator_model: BaseGeneratorModel, y0: ArrayLike
) -> Tuple[Callable, Callable, ArrayLike]:
"""Helper function for setting up a subclass of
:class:`~qiskit_dynamics.models.BaseGeneratorModel` to be solved in the frame basis.
Expand Down Expand Up @@ -416,8 +419,8 @@ def rhs(t, y):


def results_y_out_of_frame_basis(
generator_model: BaseGeneratorModel, results_y: Array, y0_ndim: int
) -> Array:
generator_model: BaseGeneratorModel, results_y: ArrayLike, y0_ndim: int
) -> ArrayLike:
"""Convert the results of a simulation for :class:`~qiskit_dynamics.models.BaseGeneratorModel`
out of the frame basis.
Expand Down
29 changes: 14 additions & 15 deletions test/dynamics/solvers/test_solver_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@
from qiskit_dynamics.models import GeneratorModel, HamiltonianModel
from qiskit_dynamics.signals import Signal, DiscreteSignal
from qiskit_dynamics import solve_ode, solve_lmde
from qiskit_dynamics.array import Array

from ..common import QiskitDynamicsTestCase, DiffraxTestBase, TestJaxBase
from ..common import QiskitDynamicsTestCase, DiffraxTestBase, JAXTestBase

try:
from diffrax import PIDController, Tsit5, Dopri5
Expand All @@ -44,10 +43,10 @@ def setUp(self):
"""Construct standardized RHS functions and models."""

self.t_span = [0.0, 1.0]
self.y0 = Array(np.eye(2, dtype=complex))
self.y0 = np.eye(2, dtype=complex)

self.X = Array([[0.0, 1.0], [1.0, 0.0]], dtype=complex)
self.Z = Array([[1.0, 0.0], [0.0, -1.0]], dtype=complex)
self.X = np.array([[0.0, 1.0], [1.0, 0.0]], dtype=complex)
self.Z = np.array([[1.0, 0.0], [0.0, -1.0]], dtype=complex)

op = -1j * 2 * np.pi * self.X / 2

Expand Down Expand Up @@ -129,10 +128,10 @@ def test_ode_method(self):
if self.is_ode_method:
# pylint: disable=unused-argument
def quad_rhs(t, y):
return np.real(Array([t**2]))
return np.real(np.array([t**2]))

results = self.solve(quad_rhs, t_span=[0.0, 1.0], y0=Array([0.0]))
expected = Array([1.0 / 3])
results = self.solve(quad_rhs, t_span=[0.0, 1.0], y0=np.array([0.0]))
expected = np.array([1.0 / 3])
self.assertAllClose(results.y[-1], expected)

def test_basic_model_lmde_from_ode(self):
Expand All @@ -143,7 +142,7 @@ def test_basic_model_lmde_from_ode(self):
self.basic_rhs, t_span=self.t_span, y0=self.y0, solver_func=solve_lmde
)

expected = expm(-1j * np.pi * self.X.data)
expected = expm(-1j * np.pi * self.X)

self.assertAllClose(results.y[-1], expected, atol=self.tol, rtol=self.tol)

Expand All @@ -152,7 +151,7 @@ def test_basic_model(self):

results = self.solve(self.basic_rhs, t_span=self.t_span, y0=self.y0)

expected = expm(-1j * np.pi * self.X.data)
expected = expm(-1j * np.pi * self.X)

self.assertAllClose(results.y[-1], expected, atol=self.tol, rtol=self.tol)

Expand All @@ -162,7 +161,7 @@ def test_backwards_solving(self):
reverse_t_span = self.t_span.copy()
reverse_t_span.reverse()

reverse_y0 = expm(-1j * np.pi * self.X.data)
reverse_y0 = expm(-1j * np.pi * self.X)

results = self.solve(self.basic_rhs, t_span=reverse_t_span, y0=reverse_y0)

Expand All @@ -173,7 +172,7 @@ def test_w_GeneratorModel(self):

results = self.solve(
self.simple_model,
y0=Array([0.0, 1.0], dtype=complex),
y0=np.array([0.0, 1.0], dtype=complex),
t_span=[0, 1 / self.r],
)
yf = results.y[-1]
Expand Down Expand Up @@ -210,7 +209,7 @@ def test_pseudo_random_model(self):
self.assertTrue(self.pseudo_random_model.in_frame_basis)


class TestSolverMethodJax(TestSolverMethod, TestJaxBase):
class TestSolverMethodJax(TestSolverMethod, JAXTestBase):
"""JAX version of TestSolverMethod. Adds additional jit/grad test."""

def test_pseudo_random_jit_grad(self):
Expand Down Expand Up @@ -349,8 +348,8 @@ def setUp(self):
evaluation_mode="sparse",
)

self.operators = self.pseudo_random_model.operators.data
self.static_operator = self.pseudo_random_model.static_operator.data
self.operators = self.pseudo_random_model.operators
self.static_operator = self.pseudo_random_model.static_operator

# make hermitian
self.operators = self.operators.conj().transpose(0, 2, 1) + self.operators
Expand Down

0 comments on commit 8f6e46f

Please sign in to comment.