diff --git a/qiskit_dynamics/solvers/solver_functions.py b/qiskit_dynamics/solvers/solver_functions.py index 5143eb4cb..c28e6e333 100644 --- a/qiskit_dynamics/solvers/solver_functions.py +++ b/qiskit_dynamics/solvers/solver_functions.py @@ -25,7 +25,8 @@ from scipy.integrate._ivp.ivp import OdeResult # pylint: disable=unused-import from qiskit import QiskitError -from qiskit_dynamics.array import Array +from qiskit_dynamics import DYNAMICS_NUMPY as unp +from qiskit_dynamics.arraylias import ArrayLike from qiskit_dynamics.models import ( BaseGeneratorModel, @@ -95,11 +96,13 @@ def _is_diffrax_method(method: any) -> bool: def _lanczos_validation( rhs: Union[Callable, BaseGeneratorModel], - t_span: Array, - y0: Array, + t_span: ArrayLike, + y0: ArrayLike, k_dim: int, ): """Validation checks to run lanczos based solvers.""" + t_span = unp.asarray(t_span) + y0 = unp.asarray(y0) if isinstance(rhs, BaseGeneratorModel): if not isinstance(rhs, HamiltonianModel): raise QiskitError( @@ -124,10 +127,10 @@ def _lanczos_validation( def solve_ode( rhs: Union[Callable, BaseGeneratorModel], - t_span: Array, - y0: Array, + t_span: ArrayLike, + y0: ArrayLike, method: Optional[Union[str, OdeSolver, DiffraxAbstractSolver]] = "DOP853", - t_eval: Optional[Union[Tuple, List, Array]] = None, + t_eval: Optional[ArrayLike] = None, **kwargs, ): r"""General interface for solving Ordinary Differential Equations (ODEs). @@ -181,7 +184,7 @@ def solve_ode( ): raise QiskitError("Method " + str(method) + " not supported by solve_ode.") - y0 = Array(y0) + y0 = unp.asarray(y0) if isinstance(rhs, BaseGeneratorModel): _, solver_rhs, y0, model_in_frame_basis = setup_generator_model_rhs_y0_in_frame_basis( @@ -205,7 +208,7 @@ def solve_ode( # convert results out of frame basis if necessary if isinstance(rhs, BaseGeneratorModel): if not model_in_frame_basis: - results.y = results_y_out_of_frame_basis(rhs, Array(results.y), y0.ndim) + results.y = results_y_out_of_frame_basis(rhs, results.y, y0.ndim) # convert model back to original basis rhs.in_frame_basis = model_in_frame_basis @@ -215,10 +218,10 @@ def solve_ode( def solve_lmde( generator: Union[Callable, BaseGeneratorModel], - t_span: Array, - y0: Array, + t_span: ArrayLike, + y0: ArrayLike, method: Optional[Union[str, OdeSolver, DiffraxAbstractSolver]] = "DOP853", - t_eval: Optional[Union[Tuple, List, Array]] = None, + t_eval: Optional[ArrayLike] = None, **kwargs, ): r"""General interface for solving Linear Matrix Differential Equations (LMDEs) @@ -333,7 +336,7 @@ def rhs(t, y): vectorized evaluation mode.""" ) - y0 = Array(y0) + y0 = unp.asarray(y0) # setup generator and rhs functions to pass to numerical methods if isinstance(generator, BaseGeneratorModel): @@ -363,7 +366,7 @@ def rhs(t, y): # convert results to correct basis if necessary if isinstance(generator, BaseGeneratorModel): if not model_in_frame_basis: - results.y = results_y_out_of_frame_basis(generator, Array(results.y), y0.ndim) + results.y = results_y_out_of_frame_basis(generator, results.y, y0.ndim) generator.in_frame_basis = model_in_frame_basis @@ -371,8 +374,8 @@ def rhs(t, y): def setup_generator_model_rhs_y0_in_frame_basis( - generator_model: BaseGeneratorModel, y0: Array -) -> Tuple[Callable, Callable, Array]: + generator_model: BaseGeneratorModel, y0: ArrayLike +) -> Tuple[Callable, Callable, ArrayLike]: """Helper function for setting up a subclass of :class:`~qiskit_dynamics.models.BaseGeneratorModel` to be solved in the frame basis. @@ -416,8 +419,8 @@ def rhs(t, y): def results_y_out_of_frame_basis( - generator_model: BaseGeneratorModel, results_y: Array, y0_ndim: int -) -> Array: + generator_model: BaseGeneratorModel, results_y: ArrayLike, y0_ndim: int +) -> ArrayLike: """Convert the results of a simulation for :class:`~qiskit_dynamics.models.BaseGeneratorModel` out of the frame basis. diff --git a/test/dynamics/solvers/test_solver_functions.py b/test/dynamics/solvers/test_solver_functions.py index 66d8cdcdb..06ae63981 100644 --- a/test/dynamics/solvers/test_solver_functions.py +++ b/test/dynamics/solvers/test_solver_functions.py @@ -27,9 +27,8 @@ from qiskit_dynamics.models import GeneratorModel, HamiltonianModel from qiskit_dynamics.signals import Signal, DiscreteSignal from qiskit_dynamics import solve_ode, solve_lmde -from qiskit_dynamics.array import Array -from ..common import QiskitDynamicsTestCase, DiffraxTestBase, TestJaxBase +from ..common import QiskitDynamicsTestCase, DiffraxTestBase, JAXTestBase try: from diffrax import PIDController, Tsit5, Dopri5 @@ -44,10 +43,10 @@ def setUp(self): """Construct standardized RHS functions and models.""" self.t_span = [0.0, 1.0] - self.y0 = Array(np.eye(2, dtype=complex)) + self.y0 = np.eye(2, dtype=complex) - self.X = Array([[0.0, 1.0], [1.0, 0.0]], dtype=complex) - self.Z = Array([[1.0, 0.0], [0.0, -1.0]], dtype=complex) + self.X = np.array([[0.0, 1.0], [1.0, 0.0]], dtype=complex) + self.Z = np.array([[1.0, 0.0], [0.0, -1.0]], dtype=complex) op = -1j * 2 * np.pi * self.X / 2 @@ -129,10 +128,10 @@ def test_ode_method(self): if self.is_ode_method: # pylint: disable=unused-argument def quad_rhs(t, y): - return np.real(Array([t**2])) + return np.real(np.array([t**2])) - results = self.solve(quad_rhs, t_span=[0.0, 1.0], y0=Array([0.0])) - expected = Array([1.0 / 3]) + results = self.solve(quad_rhs, t_span=[0.0, 1.0], y0=np.array([0.0])) + expected = np.array([1.0 / 3]) self.assertAllClose(results.y[-1], expected) def test_basic_model_lmde_from_ode(self): @@ -143,7 +142,7 @@ def test_basic_model_lmde_from_ode(self): self.basic_rhs, t_span=self.t_span, y0=self.y0, solver_func=solve_lmde ) - expected = expm(-1j * np.pi * self.X.data) + expected = expm(-1j * np.pi * self.X) self.assertAllClose(results.y[-1], expected, atol=self.tol, rtol=self.tol) @@ -152,7 +151,7 @@ def test_basic_model(self): results = self.solve(self.basic_rhs, t_span=self.t_span, y0=self.y0) - expected = expm(-1j * np.pi * self.X.data) + expected = expm(-1j * np.pi * self.X) self.assertAllClose(results.y[-1], expected, atol=self.tol, rtol=self.tol) @@ -162,7 +161,7 @@ def test_backwards_solving(self): reverse_t_span = self.t_span.copy() reverse_t_span.reverse() - reverse_y0 = expm(-1j * np.pi * self.X.data) + reverse_y0 = expm(-1j * np.pi * self.X) results = self.solve(self.basic_rhs, t_span=reverse_t_span, y0=reverse_y0) @@ -173,7 +172,7 @@ def test_w_GeneratorModel(self): results = self.solve( self.simple_model, - y0=Array([0.0, 1.0], dtype=complex), + y0=np.array([0.0, 1.0], dtype=complex), t_span=[0, 1 / self.r], ) yf = results.y[-1] @@ -210,7 +209,7 @@ def test_pseudo_random_model(self): self.assertTrue(self.pseudo_random_model.in_frame_basis) -class TestSolverMethodJax(TestSolverMethod, TestJaxBase): +class TestSolverMethodJax(TestSolverMethod, JAXTestBase): """JAX version of TestSolverMethod. Adds additional jit/grad test.""" def test_pseudo_random_jit_grad(self): @@ -349,8 +348,8 @@ def setUp(self): evaluation_mode="sparse", ) - self.operators = self.pseudo_random_model.operators.data - self.static_operator = self.pseudo_random_model.static_operator.data + self.operators = self.pseudo_random_model.operators + self.static_operator = self.pseudo_random_model.static_operator # make hermitian self.operators = self.operators.conj().transpose(0, 2, 1) + self.operators