diff --git a/qiskit_dynamics/models/hamiltonian_model.py b/qiskit_dynamics/models/hamiltonian_model.py index e47a0fac8..998f8c037 100644 --- a/qiskit_dynamics/models/hamiltonian_model.py +++ b/qiskit_dynamics/models/hamiltonian_model.py @@ -172,7 +172,7 @@ def is_hermitian(operator: ArrayLike, tol: Optional[float] = 1e-10) -> bool: return is_hermitian(operator.todense()) elif isinstance(operator, ArrayLike): adj = None - adj = np.transpose(np.conjugate(operator)) + adj = unp.transpose(unp.conjugate(operator)) return np.linalg.norm(adj - operator) < tol raise QiskitError("is_hermitian got an unexpected type.") diff --git a/qiskit_dynamics/solvers/diffrax_solver.py b/qiskit_dynamics/solvers/diffrax_solver.py index 4dc9f2efe..2aad56abd 100644 --- a/qiskit_dynamics/solvers/diffrax_solver.py +++ b/qiskit_dynamics/solvers/diffrax_solver.py @@ -17,12 +17,12 @@ Wrapper for diffrax solvers """ -from typing import Callable, Optional, Union, Tuple, List +from typing import Callable, Optional from scipy.integrate._ivp.ivp import OdeResult from qiskit import QiskitError +from qiskit_dynamics.arraylias import ArrayLike from qiskit_dynamics.dispatch import requires_backend -from qiskit_dynamics.array import Array, wrap try: import jax.numpy as jnp @@ -33,10 +33,10 @@ @requires_backend("jax") def diffrax_solver( rhs: Callable, - t_span: Array, - y0: Array, + t_span: ArrayLike, + y0: ArrayLike, method: "AbstractSolver", - t_eval: Optional[Union[Tuple, List, Array]] = None, + t_eval: Optional[ArrayLike] = None, **kwargs, ): """Routine for calling ``diffrax.diffeqsolve`` @@ -57,15 +57,13 @@ def diffrax_solver( """ from diffrax import ODETerm, SaveAt - from diffrax import diffeqsolve as _diffeqsolve - - diffeqsolve = wrap(_diffeqsolve) + from diffrax import diffeqsolve # convert rhs and y0 to real rhs = real_rhs(rhs) y0 = c2r(y0) - term = ODETerm(lambda t, y, _: Array(rhs(t.real, y), dtype=float).data) + term = ODETerm(lambda t, y, _: rhs(t.real, y)) if "saveat" in kwargs and t_eval is not None: raise QiskitError( @@ -82,7 +80,7 @@ def diffrax_solver( t0=t_span[0], t1=t_span[-1], dt0=None, - y0=Array(y0, dtype=float), + y0=jnp.array(y0, dtype=float), **kwargs, ) @@ -92,7 +90,7 @@ def diffrax_solver( ys = jnp.swapaxes(r2c(jnp.swapaxes(ys, 0, 1)), 0, 1) - results_out = OdeResult(t=ts, y=Array(ys, backend="jax", dtype=complex), **sol_dict) + results_out = OdeResult(t=ts, y=jnp.array(ys, dtype=complex), **sol_dict) return results_out @@ -108,7 +106,7 @@ def _real_rhs(t, y): def c2r(arr): """Convert complex array to a real array""" - return jnp.concatenate([jnp.real(Array(arr).data), jnp.imag(Array(arr).data)]) + return jnp.concatenate([jnp.real(arr), jnp.imag(arr)]) def r2c(arr): diff --git a/qiskit_dynamics/solvers/fixed_step_solvers.py b/qiskit_dynamics/solvers/fixed_step_solvers.py index b05c3e96b..6d6e15c05 100644 --- a/qiskit_dynamics/solvers/fixed_step_solvers.py +++ b/qiskit_dynamics/solvers/fixed_step_solvers.py @@ -15,7 +15,7 @@ Custom fixed step solvers. 
""" -from typing import Callable, Optional, Union, Tuple, List +from typing import Callable, Optional, Tuple from warnings import warn import numpy as np from scipy.integrate._ivp.ivp import OdeResult @@ -23,8 +23,9 @@ from qiskit import QiskitError +from qiskit_dynamics import DYNAMICS_NUMPY as unp from qiskit_dynamics.dispatch import requires_backend -from qiskit_dynamics.array import Array, wrap +from qiskit_dynamics.arraylias import ArrayLike try: import jax @@ -37,17 +38,15 @@ from .solver_utils import merge_t_args, trim_t_results from .lanczos import lanczos_expm -from .lanczos import jax_lanczos_expm as jax_lanczos_expm_ - -jax_lanczos_expm = wrap(jax_lanczos_expm_) +from .lanczos import jax_lanczos_expm def RK4_solver( rhs: Callable, - t_span: Array, - y0: Array, + t_span: ArrayLike, + y0: ArrayLike, max_dt: float, - t_eval: Optional[Union[Tuple, List, Array]] = None, + t_eval: Optional[ArrayLike] = None, ): """Fixed step RK4 solver. @@ -74,21 +73,17 @@ def take_step(rhs_func, t, y, h): return y + div6 * h * (k1 + 2 * k2 + 2 * k3 + k4) - # ensure the output of rhs_func is a raw array - def wrapped_rhs_func(*args): - return Array(rhs(*args)).data - return fixed_step_solver_template( - take_step, rhs_func=wrapped_rhs_func, t_span=t_span, y0=y0, max_dt=max_dt, t_eval=t_eval + take_step, rhs_func=rhs, t_span=t_span, y0=y0, max_dt=max_dt, t_eval=t_eval ) def scipy_expm_solver( generator: Callable, - t_span: Array, - y0: Array, + t_span: ArrayLike, + y0: ArrayLike, max_dt: float, - t_eval: Optional[Union[Tuple, List, Array]] = None, + t_eval: Optional[ArrayLike] = None, magnus_order: int = 1, ): """Fixed-step size matrix exponential based solver implemented with @@ -109,22 +104,18 @@ def scipy_expm_solver( """ take_step = get_exponential_take_step(magnus_order, expm_func=expm) - # ensure the output of rhs_func is a raw array - def wrapped_rhs_func(*args): - return Array(generator(*args)).data - return fixed_step_solver_template( - take_step, rhs_func=wrapped_rhs_func, t_span=t_span, y0=y0, max_dt=max_dt, t_eval=t_eval + take_step, rhs_func=generator, t_span=t_span, y0=y0, max_dt=max_dt, t_eval=t_eval ) def lanczos_diag_solver( generator: Callable, - t_span: Array, - y0: Array, + t_span: ArrayLike, + y0: ArrayLike, max_dt: float, k_dim: int, - t_eval: Optional[Union[Tuple, List, Array]] = None, + t_eval: Optional[ArrayLike] = None, ): """Fixed-step size matrix exponential based solver implemented using lanczos algorithm. 
Solves the specified problem by taking steps of @@ -156,19 +147,19 @@ def take_step(generator, t0, y, h): @requires_backend("jax") def jax_lanczos_diag_solver( generator: Callable, - t_span: Array, - y0: Array, + t_span: ArrayLike, + y0: ArrayLike, max_dt: float, k_dim: int, - t_eval: Optional[Union[Tuple, List, Array]] = None, + t_eval: Optional[ArrayLike] = None, ): """JAX version of lanczos_diag_solver.""" def take_step(generator, t0, y, h): eval_time = t0 + (h / 2) - return jax_lanczos_expm(generator(eval_time), y, k_dim, h).data + return jax_lanczos_expm(generator(eval_time), y, k_dim, h) - y0 = Array(y0, dtype=complex) + y0 = unp.asarray(y0, dtype=complex) return fixed_step_solver_template_jax( take_step, rhs_func=generator, t_span=t_span, y0=y0, max_dt=max_dt, t_eval=t_eval @@ -178,10 +169,10 @@ def take_step(generator, t0, y, h): @requires_backend("jax") def jax_RK4_solver( rhs: Callable, - t_span: Array, - y0: Array, + t_span: ArrayLike, + y0: ArrayLike, max_dt: float, - t_eval: Optional[Union[Tuple, List, Array]] = None, + t_eval: Optional[ArrayLike] = None, ): """JAX version of RK4_solver. @@ -208,21 +199,18 @@ def take_step(rhs_func, t, y, h): return y + div6 * h * (k1 + 2 * k2 + 2 * k3 + k4) - def wrapped_rhs_func(*args): - return Array(rhs(*args), backend="jax").data - return fixed_step_solver_template_jax( - take_step, rhs_func=wrapped_rhs_func, t_span=t_span, y0=y0, max_dt=max_dt, t_eval=t_eval + take_step, rhs_func=rhs, t_span=t_span, y0=y0, max_dt=max_dt, t_eval=t_eval ) @requires_backend("jax") def jax_RK4_parallel_solver( generator: Callable, - t_span: Array, - y0: Array, + t_span: ArrayLike, + y0: ArrayLike, max_dt: float, - t_eval: Optional[Union[Tuple, List, Array]] = None, + t_eval: Optional[ArrayLike] = None, ): """Parallel version of :func:`jax_RK4_solver` specialized to LMDEs. @@ -260,10 +248,10 @@ def take_step(generator, t, h): @requires_backend("jax") def jax_expm_solver( generator: Callable, - t_span: Array, - y0: Array, + t_span: ArrayLike, + y0: ArrayLike, max_dt: float, - t_eval: Optional[Union[Tuple, List, Array]] = None, + t_eval: Optional[ArrayLike] = None, magnus_order: int = 1, ): """Fixed-step size matrix exponential based solver implemented with ``jax``. @@ -283,21 +271,18 @@ def jax_expm_solver( """ take_step = get_exponential_take_step(magnus_order, expm_func=jexpm) - def wrapped_rhs_func(*args): - return Array(generator(*args), backend="jax").data - return fixed_step_solver_template_jax( - take_step, rhs_func=wrapped_rhs_func, t_span=t_span, y0=y0, max_dt=max_dt, t_eval=t_eval + take_step, rhs_func=generator, t_span=t_span, y0=y0, max_dt=max_dt, t_eval=t_eval ) @requires_backend("jax") def jax_expm_parallel_solver( generator: Callable, - t_span: Array, - y0: Array, + t_span: ArrayLike, + y0: ArrayLike, max_dt: float, - t_eval: Optional[Union[Tuple, List, Array]] = None, + t_eval: Optional[ArrayLike] = None, magnus_order: int = 1, ): """Parallel version of :func:`jax_expm_solver` implemented with JAX parallel operations. @@ -321,7 +306,7 @@ def jax_expm_parallel_solver( ) -def matrix_commutator(m1: Array, m2: Array) -> Array: +def matrix_commutator(m1: ArrayLike, m2: ArrayLike) -> ArrayLike: """Compute the commutator of two matrices. 
Args: @@ -422,10 +407,10 @@ def take_step(generator, t0, y, h): def fixed_step_solver_template( take_step: Callable, rhs_func: Callable, - t_span: Array, - y0: Array, + t_span: ArrayLike, + y0: ArrayLike, max_dt: float, - t_eval: Optional[Union[Tuple, List, Array]] = None, + t_eval: Optional[ArrayLike] = None, ): """Helper function for implementing fixed-step solvers supporting both ``t_span`` and ``max_dt`` arguments. ``take_step`` is assumed to be a @@ -456,7 +441,7 @@ def fixed_step_solver_template( OdeResult: Results object. """ - y0 = Array(y0).data + y0 = unp.asarray(y0) t_list, h_list, n_steps_list = get_fixed_step_sizes(t_span, t_eval, max_dt) @@ -468,7 +453,7 @@ def fixed_step_solver_template( y = take_step(rhs_func, inner_t, y, h) inner_t = inner_t + h ys.append(y) - ys = Array(ys) + ys = unp.asarray(ys) results = OdeResult(t=t_list, y=ys) @@ -478,10 +463,10 @@ def fixed_step_solver_template( def fixed_step_solver_template_jax( take_step: Callable, rhs_func: Callable, - t_span: Array, - y0: Array, + t_span: ArrayLike, + y0: ArrayLike, max_dt: float, - t_eval: Optional[Union[Tuple, List, Array]] = None, + t_eval: Optional[ArrayLike] = None, ): """This function is the jax control-flow version of :meth:`fixed_step_solver_template`. See the documentation of :meth:`fixed_step_solver_template` @@ -499,7 +484,7 @@ def fixed_step_solver_template_jax( OdeResult: Results object. """ - y0 = Array(y0).data + y0 = jnp.array(y0) t_list, h_list, n_steps_list = get_fixed_step_sizes(t_span, t_eval, max_dt) @@ -530,7 +515,7 @@ def scan_take_step(carry, step): xs=(jnp.array(t_list[:-1]), jnp.array(h_list), jnp.array(n_steps_list)), )[1] - ys = Array(jnp.append(jnp.expand_dims(y0, axis=0), ys, axis=0), backend="jax") + ys = jnp.append(jnp.expand_dims(y0, axis=0), ys, axis=0) results = OdeResult(t=t_list, y=ys) @@ -540,10 +525,10 @@ def scan_take_step(carry, step): def fixed_step_lmde_solver_parallel_template_jax( take_step: Callable, generator: Callable, - t_span: Array, - y0: Array, + t_span: ArrayLike, + y0: ArrayLike, max_dt: float, - t_eval: Optional[Union[Tuple, List, Array]] = None, + t_eval: Optional[ArrayLike] = None, ): """Parallelized and LMDE specific version of fixed_step_solver_template_jax. 
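For orientation, the pattern implemented by ``fixed_step_solver_template`` above can be sketched in a few lines: the interval is cut into steps no larger than ``max_dt`` and ``take_step`` is applied repeatedly. This is an illustrative sketch only (the helper names are hypothetical, and the real template also merges ``t_span`` with ``t_eval`` and trims the results):

import numpy as np

def euler_step(rhs_func, t, y, h):
    # simplest possible take_step rule, for illustration only
    return y + h * rhs_func(t, y)

def fixed_step_sketch(rhs_func, t_span, y0, max_dt):
    # number of equal steps no larger than max_dt, mirroring get_fixed_step_sizes
    n_steps = int(np.ceil((t_span[1] - t_span[0]) / max_dt))
    h = (t_span[1] - t_span[0]) / n_steps
    t, y = t_span[0], np.asarray(y0)
    for _ in range(n_steps):
        y = euler_step(rhs_func, t, y, h)
        t = t + h
    return y

# usage: propagate y' = -1j * X @ y over [0, 1] with steps of at most 0.01
X = np.array([[0.0, 1.0], [1.0, 0.0]])
y1 = fixed_step_sketch(lambda t, y: -1j * X @ y, [0.0, 1.0], np.array([1.0, 0.0], dtype=complex), 0.01)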
@@ -588,11 +573,7 @@ def fixed_step_lmde_solver_parallel_template_jax( stacklevel=2, ) - # ensure the output of rhs_func is a raw array - def wrapped_generator(*args): - return Array(generator(*args), backend="jax").data - - y0 = Array(y0).data + y0 = jnp.array(y0) t_list, h_list, n_steps_list = get_fixed_step_sizes(t_span, t_eval, max_dt) @@ -606,7 +587,7 @@ def wrapped_generator(*args): t_list_locations = np.append(t_list_locations, [t_list_locations[-1] + n_steps]) # compute propagators over each time step in parallel - step_propagators = vmap(lambda t, h: take_step(wrapped_generator, t, h))(all_times, all_h) + step_propagators = vmap(lambda t, h: take_step(generator, t, h))(all_times, all_h) # multiply propagators together in parallel ys = None @@ -625,12 +606,14 @@ def wrapped_generator(*args): intermediate_y = intermediate_props[t_list_locations[1:] - 1] @ y0 ys = jnp.append(jnp.array([y0]), intermediate_y, axis=0) - results = OdeResult(t=t_list, y=Array(ys, backend="jax")) + results = OdeResult(t=t_list, y=ys) return trim_t_results(results, t_eval) -def get_fixed_step_sizes(t_span: Array, t_eval: Array, max_dt: float) -> Tuple[Array, Array, Array]: +def get_fixed_step_sizes( + t_span: ArrayLike, t_eval: ArrayLike, max_dt: float +) -> Tuple[ArrayLike, ArrayLike, ArrayLike]: """Merge ``t_span`` and ``t_eval``, and determine the number of time steps and and step sizes (no larger than ``max_dt``) required to fixed-step integrate between each time point. @@ -645,8 +628,8 @@ def get_fixed_step_sizes(t_span: Array, t_eval: Array, max_dt: float) -> Tuple[A between time points, and list of corresponding number of steps to take between time steps. """ # time args are non-differentiable - t_span = Array(t_span, backend="numpy").data - max_dt = Array(max_dt, backend="numpy").data + t_span = np.array(t_span) + max_dt = np.array(max_dt) t_list = np.array(merge_t_args(t_span, t_eval)) # set the number of time steps required in each interval so that diff --git a/qiskit_dynamics/solvers/jax_odeint.py b/qiskit_dynamics/solvers/jax_odeint.py index 9efacf2d1..f5982527e 100644 --- a/qiskit_dynamics/solvers/jax_odeint.py +++ b/qiskit_dynamics/solvers/jax_odeint.py @@ -17,19 +17,17 @@ Wrapper for jax.experimental.ode.odeint """ -from typing import Callable, Optional, Union, Tuple, List -import numpy as np +from typing import Callable, Optional from scipy.integrate._ivp.ivp import OdeResult +from qiskit_dynamics import DYNAMICS_NUMPY as unp +from qiskit_dynamics.arraylias import ArrayLike from qiskit_dynamics.dispatch import requires_backend -from qiskit_dynamics.array import Array, wrap from .solver_utils import merge_t_args_jax, trim_t_results_jax try: - from jax.experimental.ode import odeint as _odeint - - odeint = wrap(_odeint) + from jax.experimental.ode import odeint except ImportError: pass @@ -37,9 +35,9 @@ @requires_backend("jax") def jax_odeint( rhs: Callable, - t_span: Array, - y0: Array, - t_eval: Optional[Union[Tuple, List, Array]] = None, + t_span: ArrayLike, + y0: ArrayLike, + t_eval: Optional[ArrayLike] = None, **kwargs, ): """Routine for calling `jax.experimental.ode.odeint` @@ -58,16 +56,15 @@ def jax_odeint( t_list = merge_t_args_jax(t_span, t_eval) # determine direction of integration - t_direction = np.sign(Array(t_list[-1] - t_list[0], backend="jax", dtype=complex)) - rhs = wrap(rhs) + t_direction = unp.sign(unp.asarray(t_list[-1] - t_list[0], dtype=complex)) results = odeint( - lambda y, t: rhs(np.real(t_direction * t), y) * t_direction, - y0=Array(y0, dtype=complex), - 
t=np.real(t_direction) * Array(t_list), + lambda y, t: rhs(unp.real(t_direction * t), y) * t_direction, + y0=unp.asarray(y0, dtype=complex), + t=unp.real(t_direction) * unp.asarray(t_list), **kwargs, ) - results = OdeResult(t=t_list, y=Array(results, backend="jax", dtype=complex)) + results = OdeResult(t=t_list, y=results) return trim_t_results_jax(results, t_eval) diff --git a/qiskit_dynamics/solvers/lanczos.py b/qiskit_dynamics/solvers/lanczos.py index f10364895..4221640bd 100644 --- a/qiskit_dynamics/solvers/lanczos.py +++ b/qiskit_dynamics/solvers/lanczos.py @@ -20,7 +20,6 @@ from scipy.sparse import csr_matrix from qiskit_dynamics.dispatch import requires_backend -from qiskit_dynamics.array import Array try: import jax.numpy as jnp @@ -148,7 +147,7 @@ def lanczos_expm( @requires_backend("jax") -def jax_lanczos_basis(A: Array, y0: Array, k_dim: int): +def jax_lanczos_basis(A: jnp.ndarray, y0: jnp.ndarray, k_dim: int): """JAX version of lanczos_basis.""" data_type = jnp.result_type(A.dtype, y0.dtype) @@ -204,7 +203,7 @@ def cond_func(qpb, _): @requires_backend("jax") -def jax_lanczos_eigh(A: Array, y0: Array, k_dim: int): +def jax_lanczos_eigh(A: jnp.ndarray, y0: jnp.ndarray, k_dim: int): """JAX version of lanczos_eigh.""" tridiagonal, q_basis = jax_lanczos_basis(A, y0, k_dim) @@ -215,8 +214,8 @@ def jax_lanczos_eigh(A: Array, y0: Array, k_dim: int): @requires_backend("jax") def jax_lanczos_expm( - A: Array, - y0: Array, + A: jnp.ndarray, + y0: jnp.ndarray, k_dim: int, scale_factor: Optional[float] = 1, ): diff --git a/qiskit_dynamics/solvers/scipy_solve_ivp.py b/qiskit_dynamics/solvers/scipy_solve_ivp.py index 656f726d1..b68b7d315 100644 --- a/qiskit_dynamics/solvers/scipy_solve_ivp.py +++ b/qiskit_dynamics/solvers/scipy_solve_ivp.py @@ -15,14 +15,14 @@ Wrapper for calling scipy.integrate.solve_ivp. """ -from typing import Callable, Union, Optional, Tuple, List +from typing import Callable, Union, Optional import numpy as np from scipy.integrate import solve_ivp, OdeSolver from scipy.integrate._ivp.ivp import OdeResult from qiskit import QiskitError -from qiskit_dynamics.array import Array +from qiskit_dynamics.arraylias import ArrayLike from ..type_utils import StateTypeConverter # Supported scipy ODE methods @@ -33,10 +33,10 @@ def scipy_solve_ivp( rhs: Callable, - t_span: Array, - y0: Array, + t_span: ArrayLike, + y0: ArrayLike, method: Union[str, OdeSolver], - t_eval: Optional[Union[Tuple, List, Array]] = None, + t_eval: Optional[ArrayLike] = None, **kwargs, ): """Routine for calling `scipy.integrate.solve_ivp`. 
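Stepping back to the ``jax_odeint`` wrapper above: the ``t_direction`` factor exists because ``odeint``-style solvers require strictly increasing times. A standalone, real-valued sketch of the same sign trick, using SciPy's ``odeint`` as a stand-in for ``jax.experimental.ode.odeint`` (names here are illustrative only):

import numpy as np
from scipy.integrate import odeint  # stand-in for jax.experimental.ode.odeint

def solve_any_direction(rhs, t_list, y0):
    # Integrate over s = sign * t, with dy/ds = sign * rhs(sign * s, y),
    # which recovers the original solution while keeping the time grid increasing.
    sign = np.sign(t_list[-1] - t_list[0])
    wrapped = lambda y, t: sign * rhs(sign * t, y)
    return odeint(wrapped, y0, sign * np.asarray(t_list))

# usage: integrate y' = -y backwards from t=1 to t=0
ys = solve_any_direction(lambda t, y: -y, [1.0, 0.5, 0.0], np.array([1.0]))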
@@ -84,7 +84,7 @@ def scipy_solve_ivp( # convert to the standardized results format # solve_ivp returns the states as a 2d array with columns being the states results.y = results.y.transpose() - results.y = Array([type_converter.inner_to_outer(y) for y in results.y]) + results.y = np.array([type_converter.inner_to_outer(y) for y in results.y]) return OdeResult(**dict(results)) diff --git a/qiskit_dynamics/solvers/solver_classes.py b/qiskit_dynamics/solvers/solver_classes.py index be7ed47ba..084ac9a58 100644 --- a/qiskit_dynamics/solvers/solver_classes.py +++ b/qiskit_dynamics/solvers/solver_classes.py @@ -19,10 +19,11 @@ from typing import Optional, Union, Tuple, Any, Type, List, Callable +from warnings import warn import numpy as np -from scipy.integrate._ivp.ivp import OdeResult # pylint: disable=unused-import +from scipy.integrate._ivp.ivp import OdeResult from qiskit import QiskitError from qiskit.pulse import Schedule, ScheduleBlock @@ -34,6 +35,10 @@ from qiskit.quantum_info.states.quantum_state import QuantumState from qiskit.quantum_info import SuperOp, Operator, DensityMatrix +from qiskit_dynamics import ArrayLike +from qiskit_dynamics import DYNAMICS_NUMPY as unp +from qiskit_dynamics import DYNAMICS_NUMPY_ALIAS as numpy_alias + from qiskit_dynamics.models import ( HamiltonianModel, LindbladModel, @@ -42,8 +47,6 @@ ) from qiskit_dynamics.signals import Signal, DiscreteSignal, SignalList from qiskit_dynamics.pulse import InstructionToSignals -from qiskit_dynamics.array import Array -from qiskit_dynamics.dispatch.dispatch import Dispatch from .solver_functions import solve_lmde, _is_diffrax_method from .solver_utils import ( @@ -169,9 +172,9 @@ class Solver: frequencies will be used for the RWA. * ``dt``: The envelope sample width. - If configured to simulate Pulse schedules while ``Array.default_backend() == 'jax'``, - calling :meth:`.Solver.solve` will automatically compile - simulation runs when calling with a JAX-based solver method. + If configured to simulate Pulse schedules, and a JAX-based solver method is chosen when calling + :meth:`.Solver.solve`, :meth:`.Solver.solve` will automatically attempt to compile a single + function to re-use for all schedule simulations. The evolution given by the model can be simulated by calling :meth:`.Solver.solve`, which calls :func:`.solve_lmde`, and does various automatic @@ -180,19 +183,20 @@ class Solver: def __init__( self, - static_hamiltonian: Optional[Array] = None, - hamiltonian_operators: Optional[Array] = None, - static_dissipators: Optional[Array] = None, - dissipator_operators: Optional[Array] = None, + static_hamiltonian: Optional[ArrayLike] = None, + hamiltonian_operators: Optional[ArrayLike] = None, + static_dissipators: Optional[ArrayLike] = None, + dissipator_operators: Optional[ArrayLike] = None, hamiltonian_channels: Optional[List[str]] = None, dissipator_channels: Optional[List[str]] = None, channel_carrier_freqs: Optional[dict] = None, dt: Optional[float] = None, - rotating_frame: Optional[Union[Array, RotatingFrame]] = None, + rotating_frame: Optional[Union[ArrayLike, RotatingFrame]] = None, in_frame_basis: bool = False, - evaluation_mode: str = "dense", + array_library: Optional[str] = None, + vectorized: Optional[bool] = None, rwa_cutoff_freq: Optional[float] = None, - rwa_carrier_freqs: Optional[Union[Array, Tuple[Array, Array]]] = None, + rwa_carrier_freqs: Optional[Union[ArrayLike, Tuple[ArrayLike, ArrayLike]]] = None, validate: bool = True, ): """Initialize solver with model information. 
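Before the argument documentation continues below, a minimal construction sketch of the new ``array_library`` and ``vectorized`` arguments that replace ``evaluation_mode`` (operator values are arbitrary; this mirrors the usage exercised in the updated tests):

import numpy as np
from qiskit_dynamics import Solver

X = np.array([[0.0, 1.0], [1.0, 0.0]], dtype=complex)
Z = np.array([[1.0, 0.0], [0.0, -1.0]], dtype=complex)

# Hamiltonian-only solver, storing operators as JAX arrays
ham_solver = Solver(
    static_hamiltonian=5 * Z,
    hamiltonian_operators=[X],
    rotating_frame=5 * Z,
    array_library="jax",
)

# Lindblad solver built in vectorized form (required e.g. for SuperOp simulation)
lindblad_solver = Solver(
    hamiltonian_operators=[X],
    static_dissipators=[0.01 * X],
    array_library="numpy",
    vectorized=True,
)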
@@ -218,10 +222,9 @@ def __init__( in_frame_basis: Whether to represent the model in the basis in which the rotating frame operator is diagonalized. See class documentation for a more detailed explanation on how this argument affects object behaviour. - evaluation_mode: Method for model evaluation. See documentation for - ``HamiltonianModel.evaluation_mode`` or - ``LindbladModel.evaluation_mode``. - (if dissipators in model) for valid modes. + array_library: Array library to use for storing operators of the underlying model. + vectorized: If including dissipator terms, whether or not to construct the + :class:`.LindbladModel` in vectorized form. rwa_cutoff_freq: Rotating wave approximation cutoff frequency. If ``None``, no approximation is made. rwa_carrier_freqs: Carrier frequencies to use for rotating wave approximation. @@ -312,7 +315,7 @@ def __init__( operators=hamiltonian_operators, rotating_frame=rotating_frame, in_frame_basis=in_frame_basis, - evaluation_mode=evaluation_mode, + array_library=array_library, validate=validate, ) else: @@ -323,7 +326,8 @@ def __init__( dissipator_operators=dissipator_operators, rotating_frame=rotating_frame, in_frame_basis=in_frame_basis, - evaluation_mode=evaluation_mode, + array_library=array_library, + vectorized=vectorized, validate=validate, ) @@ -385,8 +389,8 @@ def model(self) -> Union[HamiltonianModel, LindbladModel]: def solve( self, - t_span: Array, - y0: Union[Array, QuantumState, BaseOperator], + t_span: ArrayLike, + y0: Union[ArrayLike, QuantumState, BaseOperator], signals: Optional[ Union[ List[Union[Schedule, ScheduleBlock]], @@ -440,7 +444,7 @@ def solve( - Model type - ``yf`` type - Description - * - ``Array``, ``np.ndarray``, ``Operator`` + * - ``ArrayLike``, ``np.ndarray``, ``Operator`` - Any - Same as ``y0`` - Solves either the Schrodinger equation or Lindblad equation @@ -468,8 +472,8 @@ def solve( * - ``QuantumChannel`` - ``LindbladModel`` - ``SuperOp`` - - Solves the vectorized Lindblad equation with initial state ``y0``. - ``evaluation_mode`` must be set to a vectorized option. + - Solves the vectorized Lindblad equation with initial state ``y0``. ``vectorized`` + must be set to ``True``. In some cases (e.g. if using JAX), wrapping the returned states in the type given in the ``yf`` type column above may be undesirable. Setting @@ -520,12 +524,19 @@ def solve( all_results = None method = kwargs.get("method", "") if ( - Array.default_backend() == "jax" - and (method == "jax_odeint" or _is_diffrax_method(method)) + (method == "jax_odeint" or _is_diffrax_method(method)) and all(isinstance(x, Schedule) for x in signals_list) # check if jit transformation is already performed. and not (isinstance(jnp.array(0), core.Tracer)) ): + if self.model.array_library not in ["numpy", "jax", "jax_sparse"]: + warn( + "Attempting to internally JAX-compile simulation of schedules, with " + 'Solver.model.array_library not in ["numpy", "jax", "jax_sparse"]. If an error ' + "is not raised, explicitly set array_library at Solver instantiation to one of " + "these options to remove this warning."
+ ) + all_results = self._solve_schedule_list_jax( t_span_list=t_span_list, y0_list=y0_list, @@ -552,8 +563,8 @@ def solve( def _solve_list( self, - t_span_list: List[Array], - y0_list: List[Union[Array, QuantumState, BaseOperator]], + t_span_list: List[ArrayLike], + y0_list: List[Union[ArrayLike, QuantumState, BaseOperator]], signals_list: Optional[ Union[List[Schedule], List[List[Signal]], List[Tuple[List[Signal], List[Signal]]]] ] = None, @@ -588,8 +599,8 @@ def _solve_list( def _solve_schedule_list_jax( self, - t_span_list: List[Array], - y0_list: List[Union[Array, QuantumState, BaseOperator]], + t_span_list: List[ArrayLike], + y0_list: List[Union[ArrayLike, QuantumState, BaseOperator]], schedule_list: List[Schedule], convert_results: bool = True, **kwargs, @@ -637,7 +648,7 @@ def sim_function(t_span, y0, all_samples, y0_input, y0_cls): # reset signals to ensure purity self.model.signals = model_sigs - return Array(results.t).data, Array(results.y).data + return results.t, results.y jit_sim_function = jit(sim_function, static_argnums=(4,)) @@ -657,9 +668,13 @@ def sim_function(t_span, y0, all_samples, y0_input, y0_cls): all_samples[idx, 0 : len(sig.samples)] = np.array(sig.samples) results_t, results_y = jit_sim_function( - Array(t_span).data, Array(y0).data, all_samples, Array(y0_input).data, y0_cls + unp.asarray(t_span), + unp.asarray(y0), + unp.asarray(all_samples), + unp.asarray(y0_input), + y0_cls, ) - results = OdeResult(t=results_t, y=Array(results_y, backend="jax", dtype=complex)) + results = OdeResult(t=results_t, y=results_y) if y0_cls is not None and convert_results: results.y = [state_type_wrapper(yi) for yi in results.y] @@ -698,7 +713,7 @@ def _schedule_to_signals(self, schedule: Schedule): ) -def initial_state_converter(obj: Any) -> Tuple[Array, Type, Callable]: +def initial_state_converter(obj: Any) -> Tuple[ArrayLike, Type, Callable]: """Convert initial state object to an Array, the type of the initial input, and return function for constructing a state of the same type. @@ -710,23 +725,23 @@ def initial_state_converter(obj: Any) -> Tuple[Array, Type, Callable]: """ # pylint: disable=invalid-name y0_cls = None - if isinstance(obj, Array): + if isinstance(obj, ArrayLike): y0, y0_cls, wrapper = obj, None, lambda x: x if isinstance(obj, QuantumState): - y0, y0_cls = Array(obj.data), obj.__class__ + y0, y0_cls = obj.data, obj.__class__ wrapper = lambda x: y0_cls(np.array(x), dims=obj.dims()) elif isinstance(obj, QuantumChannel): - y0, y0_cls = Array(SuperOp(obj).data), SuperOp + y0, y0_cls = SuperOp(obj).data, SuperOp wrapper = lambda x: SuperOp( np.array(x), input_dims=obj.input_dims(), output_dims=obj.output_dims() ) elif isinstance(obj, (BaseOperator, Gate, QuantumCircuit)): - y0, y0_cls = Array(Operator(obj.data)), Operator + y0, y0_cls = Operator(obj.data), Operator wrapper = lambda x: Operator( np.array(x), input_dims=obj.input_dims(), output_dims=obj.output_dims() ) else: - y0, y0_cls, wrapper = Array(obj), None, lambda x: x + y0, y0_cls, wrapper = unp.asarray(obj), None, lambda x: x return y0, y0_cls, wrapper @@ -758,8 +773,8 @@ def validate_and_format_initial_state(y0: any, model: Union[HamiltonianModel, Li # validate types if (y0_cls is SuperOp) and is_lindblad_model_not_vectorized(model): raise QiskitError( - """Simulating SuperOp for a LindbladModel requires setting - vectorized evaluation. Set LindbladModel.evaluation_mode to a vectorized option. + """Simulating SuperOp for a LindbladModel requires setting vectorized evaluation. 
+ Set vectorized=True when constructing LindbladModel. """ ) @@ -767,11 +782,7 @@ def validate_and_format_initial_state(y0: any, model: Union[HamiltonianModel, Li if y0_cls in [DensityMatrix, SuperOp] and isinstance(model, HamiltonianModel): y0 = np.eye(model.dim, dtype=complex) # if LindbladModel is vectorized and simulating a density matrix, flatten - elif ( - (y0_cls is DensityMatrix) - and isinstance(model, LindbladModel) - and "vectorized" in model.evaluation_mode - ): + elif (y0_cls is DensityMatrix) and is_lindblad_model_vectorized(model): y0 = y0.flatten(order="F") # validate y0 shape before passing to solve_lmde @@ -794,7 +805,7 @@ def validate_and_format_initial_state(y0: any, model: Union[HamiltonianModel, Li def format_final_states(y, model, y0_input, y0_cls): """Format final states for a single simulation.""" - y = Array(y) + y = unp.asarray(y) if y0_cls is DensityMatrix and isinstance(model, HamiltonianModel): # conjugate by unitary @@ -802,9 +813,9 @@ def format_final_states(y, model, y0_input, y0_cls): elif y0_cls is SuperOp and isinstance(model, HamiltonianModel): # convert to SuperOp and compose return ( - np.einsum("nka,nlb->nklab", y.conj(), y).reshape( - y.shape[0], y.shape[1] ** 2, y.shape[1] ** 2 - ) + numpy_alias(like=y) + .einsum("nka,nlb->nklab", y.conj(), y) + .reshape(y.shape[0], y.shape[1] ** 2, y.shape[1] ** 2) @ y0_input ) elif (y0_cls is DensityMatrix) and is_lindblad_model_vectorized(model): @@ -903,7 +914,7 @@ def _nested_ndim(x): """Determine the 'ndim' of x, which could be composed of nested lists and array types.""" if isinstance(x, (list, tuple)): return 1 + _nested_ndim(x[0]) - elif issubclass(type(x), Dispatch.REGISTERED_TYPES) or isinstance(x, Array): + elif hasattr(x, "ndim"): return x.ndim # assume scalar diff --git a/qiskit_dynamics/solvers/solver_functions.py b/qiskit_dynamics/solvers/solver_functions.py index 5143eb4cb..48cc6a102 100644 --- a/qiskit_dynamics/solvers/solver_functions.py +++ b/qiskit_dynamics/solvers/solver_functions.py @@ -17,15 +17,16 @@ Solver functions. 
""" -from typing import Optional, Union, Callable, Tuple, List, TypeVar +from typing import Optional, Union, Callable, Tuple, TypeVar from warnings import warn from scipy.integrate import OdeSolver -from scipy.integrate._ivp.ivp import OdeResult # pylint: disable=unused-import +from scipy.integrate._ivp.ivp import OdeResult from qiskit import QiskitError -from qiskit_dynamics.array import Array +from qiskit_dynamics import DYNAMICS_NUMPY as unp +from qiskit_dynamics.arraylias import ArrayLike from qiskit_dynamics.models import ( BaseGeneratorModel, @@ -95,18 +96,20 @@ def _is_diffrax_method(method: any) -> bool: def _lanczos_validation( rhs: Union[Callable, BaseGeneratorModel], - t_span: Array, - y0: Array, + t_span: ArrayLike, + y0: ArrayLike, k_dim: int, ): """Validation checks to run lanczos based solvers.""" + t_span = unp.asarray(t_span) + y0 = unp.asarray(y0) if isinstance(rhs, BaseGeneratorModel): if not isinstance(rhs, HamiltonianModel): raise QiskitError( """Lanczos solver can only be used for HamiltonianModel or function-based anti-Hermitian generators.""" ) - if "sparse" not in rhs.evaluation_mode: + if "sparse" not in rhs.array_library: warn( """lanczos_diag should be used with a generator in sparse mode for better performance.""", @@ -124,12 +127,12 @@ def _lanczos_validation( def solve_ode( rhs: Union[Callable, BaseGeneratorModel], - t_span: Array, - y0: Array, + t_span: ArrayLike, + y0: ArrayLike, method: Optional[Union[str, OdeSolver, DiffraxAbstractSolver]] = "DOP853", - t_eval: Optional[Union[Tuple, List, Array]] = None, + t_eval: Optional[ArrayLike] = None, **kwargs, -): +) -> OdeResult: r"""General interface for solving Ordinary Differential Equations (ODEs). ODEs are differential equations of the form @@ -181,7 +184,7 @@ def solve_ode( ): raise QiskitError("Method " + str(method) + " not supported by solve_ode.") - y0 = Array(y0) + y0 = unp.asarray(y0) if isinstance(rhs, BaseGeneratorModel): _, solver_rhs, y0, model_in_frame_basis = setup_generator_model_rhs_y0_in_frame_basis( @@ -205,7 +208,7 @@ def solve_ode( # convert results out of frame basis if necessary if isinstance(rhs, BaseGeneratorModel): if not model_in_frame_basis: - results.y = results_y_out_of_frame_basis(rhs, Array(results.y), y0.ndim) + results.y = results_y_out_of_frame_basis(rhs, results.y, y0.ndim) # convert model back to original basis rhs.in_frame_basis = model_in_frame_basis @@ -215,12 +218,12 @@ def solve_ode( def solve_lmde( generator: Union[Callable, BaseGeneratorModel], - t_span: Array, - y0: Array, + t_span: ArrayLike, + y0: ArrayLike, method: Optional[Union[str, OdeSolver, DiffraxAbstractSolver]] = "DOP853", - t_eval: Optional[Union[Tuple, List, Array]] = None, + t_eval: Optional[ArrayLike] = None, **kwargs, -): +) -> OdeResult: r"""General interface for solving Linear Matrix Differential Equations (LMDEs) in standard form. @@ -241,7 +244,7 @@ def solve_lmde( Not all model classes are by-default in standard form. E.g. :class:`~qiskit_dynamics.models.LindbladModel` represents an LMDE which is not typically written in standard form. As such, using LMDE-specific methods with this generator requires - setting a vectorized evaluation mode. + the equation to be vectorized. The ``method`` argument exposes solvers specialized to both LMDEs, as well as general ODE solvers. If the method is not specific to LMDEs, the problem will be passed to @@ -270,12 +273,12 @@ def solve_lmde( behaviour. 
Note that this method contains calls to ``jax.numpy.eigh``, which may have limited validity when automatically differentiated. - ``'jax_expm'``: JAX-implemented version of ``'scipy_expm'``, with the same arguments and - behaviour. Note that this method cannot be used for a model in sparse evaluation mode. + behaviour. Note that this method cannot be used for a model using a sparse array library. - ``'jax_expm_parallel'``: Same as ``'jax_expm'``, however all loops are implemented using parallel operations. I.e. all matrix-exponentials for taking a single step are computed in parallel using ``jax.vmap``, and are subsequently multiplied together in parallel using ``jax.lax.associative_scan``. This method is only recommended for use with GPU execution. Note - that this method cannot be used for a model in sparse evaluation mode. + that this method cannot be used for a model using a sparse array library. - ``'jax_RK4_parallel'``: 4th order Runge-Kutta fixed step solver. Under the assumption of the structure of an LMDE, utilizes the same parallelization approach as ``'jax_expm_parallel'``, however the single step rule is the standard 4th order Runge-Kutta rule, rather than @@ -303,8 +306,8 @@ Additional Information: While all :class:`~qiskit_dynamics.models.BaseGeneratorModel` subclasses represent LMDEs, they are not all in standard form by defualt. Using an LMDE-specific models like - :class:`~qiskit_dynamics.models.LindbladModel` requires first setting a vectorized - evaluation mode. + :class:`~qiskit_dynamics.models.LindbladModel` requires first setting the model to be + vectorized. """ # delegate to solve_ode if necessary @@ -329,11 +332,10 @@ def rhs(t, y): # lmde-specific methods can't be used with LindbladModel unless vectorized if is_lindblad_model_not_vectorized(generator): raise QiskitError( - """LMDE-specific methods with LindbladModel requires setting a - vectorized evaluation mode.""" + "LMDE-specific methods with LindbladModel require setting vectorized=True." ) - y0 = Array(y0) + y0 = unp.asarray(y0) # setup generator and rhs functions to pass to numerical methods if isinstance(generator, BaseGeneratorModel): @@ -352,7 +354,7 @@ def rhs(t, y): elif method == "jax_lanczos_diag": results = jax_lanczos_diag_solver(solver_generator, t_span, y0, t_eval=t_eval, **kwargs) elif method == "jax_expm": - if isinstance(generator, BaseGeneratorModel) and "sparse" in generator.evaluation_mode: + if isinstance(generator, BaseGeneratorModel) and "sparse" in generator.array_library: raise QiskitError("jax_expm cannot be used with a generator in sparse mode.") results = jax_expm_solver(solver_generator, t_span, y0, t_eval=t_eval, **kwargs) elif method == "jax_expm_parallel": @@ -363,7 +365,7 @@ def rhs(t, y): # convert results to correct basis if necessary if isinstance(generator, BaseGeneratorModel): if not model_in_frame_basis: - results.y = results_y_out_of_frame_basis(generator, Array(results.y), y0.ndim) + results.y = results_y_out_of_frame_basis(generator, results.y, y0.ndim) generator.in_frame_basis = model_in_frame_basis @@ -371,8 +373,8 @@ def setup_generator_model_rhs_y0_in_frame_basis( - generator_model: BaseGeneratorModel, y0: Array -) -> Tuple[Callable, Callable, Array]: + generator_model: BaseGeneratorModel, y0: ArrayLike +) -> Tuple[Callable, Callable, ArrayLike]: """Helper function for setting up a subclass of :class:`~qiskit_dynamics.models.BaseGeneratorModel` to be solved in the frame basis.
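As a usage sketch of the :func:`.solve_lmde` interface documented above (constant generator and tolerances chosen arbitrarily), solving an LMDE with a fixed-step matrix-exponential method looks like:

import numpy as np
from qiskit_dynamics import solve_lmde

X = np.array([[0.0, 1.0], [1.0, 0.0]], dtype=complex)
generator = lambda t: -1j * 2 * np.pi * X / 2  # constant anti-Hermitian generator

results = solve_lmde(
    generator,
    t_span=[0.0, 1.0],
    y0=np.eye(2, dtype=complex),
    method="scipy_expm",
    max_dt=0.01,
)
final_unitary = results.y[-1]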
@@ -391,10 +393,7 @@ def setup_generator_model_rhs_y0_in_frame_basis( # if model not specified in frame basis, transform initial state into frame basis if not model_in_frame_basis: - if ( - isinstance(generator_model, LindbladModel) - and "vectorized" in generator_model.evaluation_mode - ): + if isinstance(generator_model, LindbladModel) and generator_model.vectorized: if generator_model.rotating_frame.frame_basis is not None: y0 = generator_model.rotating_frame.vectorized_frame_basis_adjoint @ y0 elif isinstance(generator_model, LindbladModel): @@ -416,8 +415,8 @@ def rhs(t, y): def results_y_out_of_frame_basis( - generator_model: BaseGeneratorModel, results_y: Array, y0_ndim: int -) -> Array: + generator_model: BaseGeneratorModel, results_y: ArrayLike, y0_ndim: int +) -> ArrayLike: """Convert the results of a simulation for :class:`~qiskit_dynamics.models.BaseGeneratorModel` out of the frame basis. @@ -436,10 +435,7 @@ def results_y_out_of_frame_basis( if y0_ndim == 1: results_y = results_y.T - if ( - isinstance(generator_model, LindbladModel) - and "vectorized" in generator_model.evaluation_mode - ): + if isinstance(generator_model, LindbladModel) and generator_model.vectorized: if generator_model.rotating_frame.frame_basis is not None: results_y = generator_model.rotating_frame.vectorized_frame_basis @ results_y elif isinstance(generator_model, LindbladModel): diff --git a/qiskit_dynamics/solvers/solver_utils.py b/qiskit_dynamics/solvers/solver_utils.py index 0e7a47290..c2e6d9ce8 100644 --- a/qiskit_dynamics/solvers/solver_utils.py +++ b/qiskit_dynamics/solvers/solver_utils.py @@ -17,13 +17,13 @@ Utility functions for solvers. """ -from typing import Optional, Union, List, Tuple, Callable +from typing import Optional, List, Tuple, Callable import numpy as np from scipy.integrate._ivp.ivp import OdeResult from qiskit import QiskitError -from qiskit_dynamics.array import Array +from qiskit_dynamics.arraylias import ArrayLike from qiskit_dynamics.models import LindbladModel try: @@ -35,17 +35,15 @@ def is_lindblad_model_vectorized(obj: any) -> bool: """Return True if obj is a vectorized LindbladModel.""" - return isinstance(obj, LindbladModel) and ("vectorized" in obj.evaluation_mode) + return isinstance(obj, LindbladModel) and obj.vectorized def is_lindblad_model_not_vectorized(obj: any) -> bool: """Return True if obj is a non-vectorized LindbladModel.""" - return isinstance(obj, LindbladModel) and ("vectorized" not in obj.evaluation_mode) + return isinstance(obj, LindbladModel) and not obj.vectorized -def merge_t_args( - t_span: Union[List, Tuple, Array], t_eval: Optional[Union[List, Tuple, Array]] = None -) -> Union[List, Tuple, Array]: +def merge_t_args(t_span: ArrayLike, t_eval: Optional[ArrayLike] = None) -> np.ndarray: """Merge ``t_span`` and ``t_eval`` into a single array. Validition is similar to scipy ``solve_ivp``: ``t_eval`` must be contained in ``t_span``, and be @@ -61,7 +59,7 @@ def merge_t_args( t_eval: Time points to include in returned results. Returns: - Union[List, Tuple, Array]: Combined list of times. + np.ndarray: Combined list of times. Raises: ValueError: If one of several validation checks fail. 
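The merging behaviour documented above is easiest to see on a small example (values arbitrary):

import numpy as np
from qiskit_dynamics.solvers.solver_utils import merge_t_args

# t_eval lies strictly inside t_span, so both endpoints are appended around it
times = merge_t_args(t_span=np.array([0.0, 1.0]), t_eval=np.array([0.25, 0.5, 0.75]))
# times == array([0.  , 0.25, 0.5 , 0.75, 1.  ])

# with t_eval=None, t_span is returned unchanged
print(merge_t_args(np.array([0.0, 1.0])))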
@@ -70,13 +68,13 @@ def merge_t_args( if t_eval is None: return t_span - t_span = Array(t_span, backend="numpy") + t_span = np.array(t_span) t_min = np.min(t_span) t_max = np.max(t_span) t_direction = np.sign(t_span[1] - t_span[0]) - t_eval = Array(t_eval, backend="numpy") + t_eval = np.array(t_eval) if t_eval.ndim > 1: raise ValueError("t_eval must be 1 dimensional.") @@ -92,12 +90,12 @@ def merge_t_args( # add endpoints t_eval = np.append(np.append(t_span[0], t_eval), t_span[1]) - return Array(t_eval, backend="numpy") + return t_eval def trim_t_results( results: OdeResult, - t_eval: Optional[Union[List, Tuple, Array]] = None, + t_eval: Optional[ArrayLike] = None, ) -> OdeResult: """Trim ``OdeResult`` object if ``t_eval is not None``. @@ -116,14 +114,12 @@ def trim_t_results( # remove endpoints results.t = results.t[1:-1] - results.y = Array(results.y[1:-1]) + results.y = results.y[1:-1] return results -def merge_t_args_jax( - t_span: Union[List, Tuple, Array], t_eval: Optional[Union[List, Tuple, Array]] = None -) -> Union[List, Tuple, Array]: +def merge_t_args_jax(t_span: ArrayLike, t_eval: Optional[ArrayLike] = None) -> jnp.ndarray: """JAX-compilable version of merge_t_args. Rather than raise errors, sets return values to ``jnp.nan`` to signal errors. @@ -139,17 +135,17 @@ def merge_t_args_jax( t_eval: Time points to include in returned results. Returns: - Union[List, Tuple, Array]: Combined list of times. + jnp.ndarray: Combined list of times. Raises: ValueError: If either argument is not one dimensional. """ if t_eval is None: - return Array(t_span, backend="jax") + return jnp.array(t_span) - t_span = Array(t_span, backend="jax").data - t_eval = Array(t_eval, backend="jax").data + t_span = jnp.array(t_span) + t_eval = jnp.array(t_eval) # raise error if not one dimensional if t_eval.ndim > 1: @@ -178,12 +174,12 @@ def merge_t_args_jax( # if out[-1] == out[-2], set out[-2] == (out[-3] + out[-1])/2 out = cond(out[-1] == out[-2], lambda x: x.at[-2].set((x[-3] + x[-1]) / 2), lambda x: x, out) - return Array(out) + return out def trim_t_results_jax( results: OdeResult, - t_eval: Optional[Union[List, Tuple, Array]] = None, + t_eval: Optional[ArrayLike] = None, ) -> OdeResult: """JAX-compilable version of trim_t_results. 
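Together with ``merge_t_args``, the trimming helpers bracket a solve: endpoints are appended before integration and stripped afterwards so the returned times line up with ``t_eval``. A rough round-trip sketch using the non-JAX helpers (the "result" values below are placeholders, not solver output):

import numpy as np
from scipy.integrate._ivp.ivp import OdeResult
from qiskit_dynamics.solvers.solver_utils import merge_t_args, trim_t_results

t_span, t_eval = np.array([0.0, 1.0]), np.array([0.25, 0.5])
t_list = merge_t_args(t_span, t_eval)  # array([0., 0.25, 0.5, 1.])

# stand-in for solver output evaluated at every entry of t_list
raw = OdeResult(t=t_list, y=np.array([[1.0], [2.0], [3.0], [4.0]]))

trimmed = trim_t_results(raw, t_eval)
# the appended endpoints are dropped, so trimmed.t lines up with t_eval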
@@ -203,35 +199,29 @@ if t_eval is not None: # remove second entry if t_eval[0] == results.t[0], as this indicates a repeated time - results.y = Array( - cond( - t_eval[0] == results.t[0], - lambda y: jnp.append(jnp.array([y[0]]), y[2:], axis=0), - lambda y: y[1:], - Array(results.y).data, - ) + results.y = cond( + t_eval[0] == results.t[0], + lambda y: jnp.append(jnp.array([y[0]]), y[2:], axis=0), + lambda y: y[1:], + jnp.array(results.y), ) # remove second last entry if t_eval[-1] == results.t[-1], as this indicates a repeated time - results.y = Array( - cond( - t_eval[-1] == results.t[-1], - lambda y: jnp.append(y[:-2], jnp.array([y[-1]]), axis=0), - lambda y: y[:-1], - Array(results.y).data, - ) + results.y = cond( + t_eval[-1] == results.t[-1], + lambda y: jnp.append(y[:-2], jnp.array([y[-1]]), axis=0), + lambda y: y[:-1], + jnp.array(results.y), ) - results.t = Array(t_eval) + results.t = t_eval # this handles the odd case that t_span == [a, a] - results.y = Array( - cond( - results.t[0] == results.t[-1], - lambda y: y.at[-1].set(y[0]), - lambda y: y, - Array(results.y).data, - ) + results.y = cond( + results.t[0] == results.t[-1], + lambda y: y.at[-1].set(y[0]), + lambda y: y, + jnp.array(results.y), ) return results diff --git a/releasenotes/notes/solver-class-arraylias-80781bb527c3bf52.yaml b/releasenotes/notes/solver-class-arraylias-80781bb527c3bf52.yaml new file mode 100644 index 000000000..3c792f92a --- /dev/null +++ b/releasenotes/notes/solver-class-arraylias-80781bb527c3bf52.yaml @@ -0,0 +1,15 @@ +--- +upgrade: + - | + In conjunction with the change to the ``evaluation_mode`` argument in the model classes, the + :class:`.Solver` class has been updated to take the ``array_library`` constructor argument, as + well as the ``vectorized`` constructor argument (for use when Lindblad terms are present). + - | + The logic in :meth:`.Solver.solve` for automatic ``jit`` compiling when using JAX and simulating + a list of schedules has been updated to no longer be based on whether + ``Array.default_backend() == "jax"``. The attempted automatic ``jit`` compiling in this case + is now based only on whether ``method="jax_odeint"`` or ``method`` is a Diffrax + integration method. A warning will be raised if the ``array_library`` is not known to be + compatible with the compilation routine. (For now, ``"scipy_sparse"`` is the only + ``array_library`` not compatible with this routine; however, a warning will still be raised if + no explicit ``array_library`` is provided, as in this case the JAX-compatibility is unknown.)
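A sketch of the workflow the release note above describes (pulse parameters arbitrary): with an explicit JAX-compatible ``array_library``, no warning is emitted and the schedule simulation is jit-compiled internally when a JAX-based method is used.

import numpy as np
from qiskit import pulse
from qiskit_dynamics import Solver

X = np.array([[0.0, 1.0], [1.0, 0.0]], dtype=complex)
Z = np.array([[1.0, 0.0], [0.0, -1.0]], dtype=complex)

solver = Solver(
    static_hamiltonian=5 * Z,
    hamiltonian_operators=[X],
    rotating_frame=5 * Z,
    hamiltonian_channels=["d0"],
    channel_carrier_freqs={"d0": 5.0},
    dt=0.1,
    array_library="jax",  # explicitly JAX-compatible, so no warning is raised
)

with pulse.build() as schedule:
    pulse.play(pulse.Constant(duration=10, amp=0.1), pulse.DriveChannel(0))

# method="jax_odeint" (or a Diffrax method) triggers the internal jit compilation
results = solver.solve(
    t_span=[0.0, 1.0],
    y0=np.array([0.0, 1.0], dtype=complex),
    signals=schedule,
    method="jax_odeint",
)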
\ No newline at end of file diff --git a/test/dynamics/common.py b/test/dynamics/common.py index a08883428..89ad86e77 100644 --- a/test/dynamics/common.py +++ b/test/dynamics/common.py @@ -312,6 +312,12 @@ def setUpClass(cls): with warnings.catch_warnings(): warnings.simplefilter("ignore") import diffrax # pylint: disable=import-outside-toplevel,unused-import + + # pylint: disable=import-outside-toplevel + import jax + + jax.config.update("jax_enable_x64", True) + jax.config.update("jax_platform_name", "cpu") except Exception as err: raise unittest.SkipTest("Skipping diffrax tests.") from err diff --git a/test/dynamics/solvers/test_diffrax_DOP5.py b/test/dynamics/solvers/test_diffrax_DOP5.py index de96ea3eb..3a0870b75 100644 --- a/test/dynamics/solvers/test_diffrax_DOP5.py +++ b/test/dynamics/solvers/test_diffrax_DOP5.py @@ -19,7 +19,7 @@ from qiskit_dynamics.solvers.diffrax_solver import diffrax_solver -from ..common import QiskitDynamicsTestCase, TestJaxBase +from ..common import JAXTestBase try: import jax.numpy as jnp @@ -30,7 +30,7 @@ pass -class TestDiffraxDopri5(QiskitDynamicsTestCase, TestJaxBase): +class TestDiffraxDopri5(JAXTestBase): """Test cases for diffrax_solver.""" def setUp(self): diff --git a/test/dynamics/solvers/test_fixed_step_solvers.py b/test/dynamics/solvers/test_fixed_step_solvers.py index 8f7196dac..f8e69a2ae 100644 --- a/test/dynamics/solvers/test_fixed_step_solvers.py +++ b/test/dynamics/solvers/test_fixed_step_solvers.py @@ -21,7 +21,7 @@ from scipy.linalg import expm -from qiskit_dynamics.array import Array +from qiskit_dynamics import DYNAMICS_NUMPY as unp from qiskit_dynamics.solvers.fixed_step_solvers import ( RK4_solver, scipy_expm_solver, @@ -37,10 +37,11 @@ jax_lanczos_expm, ) -from ..common import QiskitDynamicsTestCase, TestJaxBase +from ..common import QiskitDynamicsTestCase, JAXTestBase try: from jax.scipy.linalg import expm as jexpm + from jax import jit, grad # pylint: disable=broad-except except Exception: pass @@ -56,7 +57,8 @@ def setUp(self): """Setup RHS functions for testing of fixed step solvers. Constructed as LMDEs so that the tests can be used for both LMDE and ODE methods. """ - self.constant_generator = lambda t: -1j * Array([[0.0, 1.0], [1.0, 0.0]]).data + X = np.array([[0.0, 1.0], [1.0, 0.0]]) + self.constant_generator = lambda t: -1j * X def constant_rhs(t, y=None): if y is None: @@ -66,9 +68,8 @@ def constant_rhs(t, y=None): self.constant_rhs = constant_rhs - self.linear_generator = ( - lambda t: -1j * Array([[0.0, 1.0 - 1j * t], [1.0 + 1j * t, 0.0]]).data - ) + Y = np.array([[0.0, -1j], [1j, 0.0]]) + self.linear_generator = lambda t: -1j * (X + t * Y) def linear_rhs(t, y=None): if y is None: @@ -96,9 +97,7 @@ def linear_rhs(t, y=None): ) def random_generator(t): - t = Array(t) - output = np.sin(t) * rand_ops[0] + (t**5) * rand_ops[1] + np.exp(t) * rand_ops[2] - return Array(output).data + return unp.sin(t) * rand_ops[0] + (t**5) * rand_ops[1] + unp.exp(t) * rand_ops[2] self.random_generator = random_generator @@ -453,8 +452,8 @@ def test_case_iz(self): self.assertAllClose(result2[-1], expm(gen(0) * t_span[-1]) @ y02) -class TestJaxFixedStepBase(TestFixedStepBase, TestJaxBase): - """JAX version of TestFixedStepBase, adding JAX setup class TestJaxBase, +class TestJaxFixedStepBase(TestFixedStepBase, JAXTestBase): + """JAX version of TestFixedStepBase, adding JAX setup class JAXTestBase, and adding jit/grad test. 
""" @@ -466,14 +465,14 @@ def func(amp): results = self.solve( lambda *args: amp * self.constant_rhs(*args), t_span, self.id2, max_dt=0.1 ) - return Array(results.y[-1]).data + return results.y[-1] - jit_func = self.jit_wrap(func) + jit_func = jit(func) output = jit_func(1.0) expected_y = self.take_n_steps(self.constant_rhs, t=0.0, y=self.id2, h=0.1, n_steps=10) self.assertAllClose(expected_y, output) - grad_func = self.jit_grad_wrap(func) + grad_func = jit(grad(lambda *args: func(*args).real.sum())) grad_func(1.0) diff --git a/test/dynamics/solvers/test_jax_odeint.py b/test/dynamics/solvers/test_jax_odeint.py index 1ad32ffc7..095fa4454 100644 --- a/test/dynamics/solvers/test_jax_odeint.py +++ b/test/dynamics/solvers/test_jax_odeint.py @@ -19,17 +19,18 @@ from qiskit_dynamics.solvers.jax_odeint import jax_odeint -from ..common import QiskitDynamicsTestCase, TestJaxBase +from ..common import JAXTestBase try: import jax.numpy as jnp from jax.lax import cond + from jax import jit, grad # pylint: disable=broad-except except Exception: pass -class TestJaxOdeint(QiskitDynamicsTestCase, TestJaxBase): +class TestJaxOdeint(JAXTestBase): """Test cases for jax_odeint.""" def setUp(self): @@ -129,15 +130,13 @@ def test_transformations_w_t_span_t_eval_no_overlap(self): t_span = np.array([0.0, 2.0]) t_eval = np.array([1.0, 1.5, 1.7]) - y0 = jnp.array([1.0]) + y0 = jnp.array([1.0], dtype=complex) def func(t_s, t_e): results = jax_odeint(self.simple_rhs, t_s, y0, t_eval=t_e, atol=1e-10, rtol=1e-10) - return results.t.data, results.y.data - - jit_func = self.jit_wrap(func) + return results.t, results.y - t, y = jit_func(t_span, t_eval) + t, y = jit(func)(t_span, t_eval) self.assertAllClose(t_eval, t) @@ -150,7 +149,7 @@ def func(t_s, t_e): ) self.assertAllClose(expected_y, y) - jit_grad_func = self.jit_grad_wrap(lambda a: func(t_span, a)[1][-1]) + jit_grad_func = jit(grad(lambda a: func(t_span, a)[1][-1].real.sum())) out = jit_grad_func(t_eval) self.assertAllClose(out, np.array([0.0, 0.0, 1.7**2])) @@ -167,7 +166,8 @@ def sim_function(a): return results.y[-1].real.sum() self.assertAllClose( - self.jit_grad_wrap(sim_function)(2.0), 4 * (0.5 + (2.0**3 - 1.0**3) / 3) + jit(grad(lambda a: sim_function(a).real.sum()))(2.0), + 4 * (0.5 + (2.0**3 - 1.0**3) / 3), ) def test_empty_integration(self): diff --git a/test/dynamics/solvers/test_lanczos.py b/test/dynamics/solvers/test_lanczos.py index 3053d8d0f..05a6fc598 100644 --- a/test/dynamics/solvers/test_lanczos.py +++ b/test/dynamics/solvers/test_lanczos.py @@ -24,7 +24,7 @@ jax_lanczos_eigh, jax_lanczos_expm, ) -from ..common import QiskitDynamicsTestCase, TestJaxBase +from ..common import QiskitDynamicsTestCase, JAXTestBase class TestLanczos(QiskitDynamicsTestCase): @@ -77,7 +77,7 @@ def test_expm(self): self.assertAllClose(expAy_s, expAy_l) -class TestJaxLanczos(TestLanczos, TestJaxBase): +class TestJaxLanczos(TestLanczos, JAXTestBase): """Tests for jax functions in lanczos.py.""" def setUp(self): diff --git a/test/dynamics/solvers/test_solver_classes.py b/test/dynamics/solvers/test_solver_classes.py index 86f24884b..e75f0d430 100644 --- a/test/dynamics/solvers/test_solver_classes.py +++ b/test/dynamics/solvers/test_solver_classes.py @@ -9,12 +9,14 @@ # Any modifications or derivative works of this code must retain this # copyright notice, and modified files need to carry a notice indicating # that they have been altered from the originals. -# pylint: disable=invalid-name +# pylint: disable=invalid-name,no-member """ Tests for solver classes module. 
""" +from functools import partial + import numpy as np import sympy as sym from ddt import ddt, data, unpack @@ -24,10 +26,14 @@ from qiskit_dynamics import Solver, Signal, DiscreteSignal, solve_lmde from qiskit_dynamics.models import HamiltonianModel, LindbladModel, rotating_wave_approximation -from qiskit_dynamics.type_utils import to_array from qiskit_dynamics.solvers.solver_classes import organize_signals_to_channels -from ..common import QiskitDynamicsTestCase, TestJaxBase +from ..common import QiskitDynamicsTestCase, test_array_backends + +try: + from jax import jit +except ImportError: + pass class TestSolverValidation(QiskitDynamicsTestCase): @@ -189,7 +195,7 @@ def setUp(self): self.vec_lindblad_solver = Solver( hamiltonian_operators=[X], static_dissipators=[X], - evaluation_mode="dense_vectorized", + vectorized=True, ) def test_hamiltonian_shape_error(self): @@ -451,7 +457,8 @@ def test_signals_are_None(self): self.assertTrue(td_lindblad_solver.model.signals == (None, None)) -class TestSolverSimulation(QiskitDynamicsTestCase): +@partial(test_array_backends, array_libraries=["numpy", "jax"]) +class TestSolverSimulation: """Test cases for correct simulation for Solver class.""" def setUp(self): @@ -486,7 +493,7 @@ def setUp(self): static_dissipators=[0.01 * X], static_hamiltonian=5 * Z, rotating_frame=5 * Z, - evaluation_mode="dense_vectorized", + vectorized=True, ) # lindblad solver with no dissipation for testing @@ -495,9 +502,13 @@ def setUp(self): static_dissipators=[0.0 * X], static_hamiltonian=5 * Z, rotating_frame=5 * Z, - evaluation_mode="dense_vectorized", + vectorized=True, ) - self.method = "DOP853" + + if self.array_library() == "numpy": + self.method = "DOP853" + elif self.array_library() == "jax": + self.method = "jax_odeint" def test_state_dims_preservation(self): """Test that state shapes are correctly preserved.""" @@ -527,7 +538,7 @@ def test_state_dims_preservation(self): self.assertTrue(yf.dims() == (2, 3)) # SuperOp - solver.model.evaluation_mode = "dense_vectorized" + solver = Solver(static_dissipators=np.zeros((1, 6, 6)), vectorized=True) y0 = SuperOp(np.eye(36), input_dims=(2, 3), output_dims=(3, 2)) yf = solver.solve(t_span=[0.0, 0.1], y0=y0).y[-1] self.assertTrue(isinstance(yf, SuperOp)) @@ -685,40 +696,22 @@ def test_lindblad_solver_consistency(self): self.assertTrue(results.y[-1].data[0, 0] > 0.99 and results.y[-1].data[0, 0] < 0.999) -class TestSolverSimulationJax(TestSolverSimulation, TestJaxBase): - """JAX version of TestSolverSimulation.""" +@partial(test_array_backends, array_libraries=["jax", "jax_sparse"]) +class TestSolverSimulationJAXTransformations: + """Test Solver class within JAX transformations.""" def setUp(self): - """Set method to 'jax_odeint' to speed up running of jax version of tests.""" - super().setUp() - self.method = "jax_odeint" - - def test_transform_through_construction_when_validate_false(self): - """Test that a function building a Solver can be compiled if validate=False.""" - - Z = to_array(self.Z) - X = to_array(self.X) - - def func(a): - solver = Solver( - static_hamiltonian=5 * Z, - hamiltonian_operators=[X], - rotating_frame=5 * Z, - validate=False, - ) - yf = solver.solve( - t_span=np.array([0.0, 0.1]), - y0=np.array([0.0, 1.0]), - signals=[Signal(a, 5.0)], - method=self.method, - ).y[-1] - return yf - - jit_func = self.jit_wrap(func) - self.assertAllClose(jit_func(2.0), func(2.0)) - - jit_grad_func = self.jit_grad_wrap(func) - jit_grad_func(1.0) + """Set up some simple models.""" + X = 2 * np.pi * 
Operator.from_label("X") / 2 + Z = 2 * np.pi * Operator.from_label("Z") / 2 + self.X = X + self.Z = Z + self.ham_solver = Solver( + hamiltonian_operators=[X], + static_hamiltonian=5 * Z, + rotating_frame=5 * Z, + array_library=self.array_library(), + ) def test_jit_solve(self): """Test jitting setting signals and solving.""" @@ -728,12 +721,11 @@ def func(a): t_span=np.array([0.0, 1.0]), y0=np.array([0.0, 1.0]), signals=[Signal(lambda t: a, 5.0)], - method=self.method, + method="jax_odeint", ).y[-1] return yf - jit_func = self.jit_wrap(func) - self.assertAllClose(jit_func(2.0), func(2.0)) + self.assertAllClose(jit(func)(2.0), func(2.0)) def test_jit_grad_solve(self): """Test jitting setting signals and solving.""" @@ -746,16 +738,48 @@ def func(a): t_span=[0.0, 1.0], y0=np.array([[0.0, 1.0], [0.0, 1.0]], dtype=complex), signals=([Signal(lambda t: a, 5.0)], [1.0]), - method=self.method, + method="jax_odeint", + ).y[-1] + return yf + + self.jit_grad(func)(1.0) + + +@partial(test_array_backends, array_libraries=["jax"]) +class TestSolverConstructionJAXTransformations: + """Test construction of Solver within a function to be JAX transformed.""" + + def test_transform_through_construction_when_validate_false(self): + """Test that a function building a Solver can be compiled if validate=False.""" + + Z = np.array([[1.0, 0.0], [0.0, -1.0]]) + X = np.array([[0.0, 1.0], [1.0, 0.0]]) + + def func(a): + solver = Solver( + static_hamiltonian=5 * Z, + hamiltonian_operators=[X], + rotating_frame=5 * Z, + validate=False, + array_library=self.array_library(), + ) + yf = solver.solve( + t_span=np.array([0.0, 0.1]), + y0=np.array([0.0, 1.0]), + signals=[Signal(a, 5.0)], + method="jax_odeint", ).y[-1] return yf - jit_grad_func = self.jit_grad_wrap(func) - jit_grad_func(1.0) + self.assertAllClose(jit(func)(2.0), func(2.0)) + + # validate that grad can be compiled and evaluated + self.jit_grad(func)(1.0) +@partial(test_array_backends, array_libraries=["numpy", "jax"]) @ddt -class TestPulseSimulation(QiskitDynamicsTestCase): +class TestPulseSimulation: """Test simulation of pulse schedules.""" def setUp(self): @@ -765,7 +789,12 @@ def setUp(self): self.X = X self.Z = Z - self.static_ham_solver = Solver(static_hamiltonian=5 * Z, rotating_frame=5 * Z, dt=0.1) + self.static_ham_solver = Solver( + static_hamiltonian=5 * Z, + rotating_frame=5 * Z, + dt=0.1, + array_library=self.array_library(), + ) self.ham_solver = Solver( hamiltonian_operators=[X], @@ -774,10 +803,15 @@ def setUp(self): hamiltonian_channels=["d0"], channel_carrier_freqs={"d0": 5.0}, dt=0.1, + array_library=self.array_library(), ) self.static_lindblad_solver = Solver( - static_dissipators=[0.01 * X], static_hamiltonian=5 * Z, rotating_frame=5 * Z, dt=0.1 + static_dissipators=[0.01 * X], + static_hamiltonian=5 * Z, + rotating_frame=5 * Z, + dt=0.1, + array_library=self.array_library(), ) self.lindblad_solver = Solver( @@ -788,6 +822,7 @@ def setUp(self): hamiltonian_channels=["d0"], channel_carrier_freqs={"d0": 5.0}, dt=0.1, + array_library=self.array_library(), ) self.ham_solver_2_channels = Solver( @@ -797,6 +832,7 @@ def setUp(self): hamiltonian_channels=["d0", "d1"], channel_carrier_freqs={"d0": 5.0, "d1": 3.1}, dt=0.1, + array_library=self.array_library(), ) self.td_lindblad_solver = Solver( @@ -809,10 +845,14 @@ def setUp(self): dissipator_channels=["d1"], channel_carrier_freqs={"d0": 5.0, "d1": 3.1}, dt=0.1, - evaluation_mode="dense_vectorized", + array_library=self.array_library(), + vectorized=True, ) - self.method = "DOP853" + if 
self.array_library() == "numpy": + self.method = "DOP853" + elif self.array_library() == "jax": + self.method = "jax_odeint" @unpack @data(("static_ham_solver",), ("static_lindblad_solver",)) @@ -974,6 +1014,7 @@ def test_4_channel_schedule(self): dissipator_channels=["d1", "d3"], channel_carrier_freqs={"d0": 5.0, "d1": 3.1, "d2": 0, "d3": 4.0}, dt=dt, + array_library=self.array_library(), ) with pulse.build() as schedule: @@ -1026,6 +1067,7 @@ def test_rwa_ham_solver(self): channel_carrier_freqs={"d0": 5.0}, dt=0.1, rwa_cutoff_freq=1.5 * 5.0, + array_library=self.array_library(), ) ham_solver = Solver( @@ -1034,6 +1076,7 @@ def test_rwa_ham_solver(self): rotating_frame=5 * self.Z, rwa_cutoff_freq=1.5 * 5.0, rwa_carrier_freqs=[5.0], + array_library=self.array_library(), ) with pulse.build() as schedule: @@ -1072,6 +1115,7 @@ def test_rwa_lindblad_solver(self): channel_carrier_freqs={"d0": 5.0}, dt=0.1, rwa_cutoff_freq=1.5 * 5.0, + array_library=self.array_library(), ) lindblad_solver = Solver( @@ -1081,6 +1125,7 @@ def test_rwa_lindblad_solver(self): rotating_frame=5 * self.Z, rwa_cutoff_freq=1.5 * 5.0, rwa_carrier_freqs=[5.0], + array_library=self.array_library(), ) with pulse.build() as schedule: @@ -1225,12 +1270,24 @@ def _compare_schedule_to_signals( self.assertAllClose(pulse_res.y[-1], signal_res.y[-1], atol=test_tol, rtol=test_tol) -class TestPulseSimulationJAX(TestPulseSimulation, TestJaxBase): - """Test class for pulse simulation with JAX.""" +@partial(test_array_backends, array_libraries=["jax"]) +class TestPulseSimulationJAXPeculiarities: + """Test class for technical issues of pulse simulation with JAX.""" def setUp(self): - super().setUp() - self.method = "jax_odeint" + """Set up some simple models.""" + X = 2 * np.pi * Operator.from_label("X") / 2 + Z = 2 * np.pi * Operator.from_label("Z") / 2 + + self.ham_solver = Solver( + hamiltonian_operators=[X], + static_hamiltonian=5 * Z, + rotating_frame=5 * Z, + hamiltonian_channels=["d0"], + channel_carrier_freqs={"d0": 5.0}, + dt=0.1, + array_library=self.array_library(), + ) def test_t_eval_t_span_jax_odeint(self): """Test internal jitting works when specifying t_eval and t_span. This catches a bug @@ -1246,7 +1303,7 @@ def test_t_eval_t_span_jax_odeint(self): t_span=[0.0, 0.1], y0=np.array([0.0, 1.0]), t_eval=[0.0, 0.05, 0.1], - method=self.method, + method="jax_odeint", ) def test_t_eval_t_span_diffrax(self): @@ -1311,7 +1368,7 @@ def constant_pulse(amp): valid_amp_conditions=valid_amp_conditions_expr, ) - def jit_func(amp): + def func(amp): with pulse.build() as sched: pulse.play(constant_pulse(amp), pulse.DriveChannel(0)) @@ -1319,10 +1376,10 @@ def jit_func(amp): signals=sched, t_span=[0.0, 0.1], y0=np.array([0.0, 1.0]), - method=self.method, + method="jax_odeint", ) - self.jit_wrap(jit_func)(0.1) + jit(func)(0.1) @ddt diff --git a/test/dynamics/solvers/test_solver_functions.py b/test/dynamics/solvers/test_solver_functions.py index 66d8cdcdb..2bc49d112 100644 --- a/test/dynamics/solvers/test_solver_functions.py +++ b/test/dynamics/solvers/test_solver_functions.py @@ -9,7 +9,7 @@ # Any modifications or derivative works of this code must retain this # copyright notice, and modified files need to carry a notice indicating # that they have been altered from the originals. -# pylint: disable=invalid-name +# pylint: disable=invalid-name,no-member """Standardized test cases for results of calls to solve_lmde and solve_ode, for both variable and fixed-step methods. 
These tests set up common test cases @@ -19,17 +19,19 @@ test_solver_functions_interface.py """ +from functools import partial from abc import ABC, abstractmethod import numpy as np from scipy.linalg import expm +from qiskit_dynamics import DYNAMICS_NUMPY as unp + from qiskit_dynamics.models import GeneratorModel, HamiltonianModel from qiskit_dynamics.signals import Signal, DiscreteSignal from qiskit_dynamics import solve_ode, solve_lmde -from qiskit_dynamics.array import Array -from ..common import QiskitDynamicsTestCase, DiffraxTestBase, TestJaxBase +from ..common import test_array_backends, DiffraxTestBase try: from diffrax import PIDController, Tsit5, Dopri5 @@ -37,17 +39,18 @@ pass -class TestSolverMethod(ABC, QiskitDynamicsTestCase): - """Abstract base class for setting up models and RHS to be used in tests.""" +class TestSolverMethod(ABC): + """Abstract base class for setting up models and RHS to be used in tests. Note that this + assumes subclasses will use ``test_array_backends``.""" def setUp(self): """Construct standardized RHS functions and models.""" self.t_span = [0.0, 1.0] - self.y0 = Array(np.eye(2, dtype=complex)) + self.y0 = np.eye(2, dtype=complex) - self.X = Array([[0.0, 1.0], [1.0, 0.0]], dtype=complex) - self.Z = Array([[1.0, 0.0], [0.0, -1.0]], dtype=complex) + self.X = np.array([[0.0, 1.0], [1.0, 0.0]], dtype=complex) + self.Z = np.array([[1.0, 0.0], [0.0, -1.0]], dtype=complex) op = -1j * 2 * np.pi * self.X / 2 @@ -66,7 +69,9 @@ def basic_rhs(t, y=None): self.r = 0.1 signals = [self.w, Signal(lambda t: 1.0, self.w)] operators = [-1j * 2 * np.pi * self.Z / 2, -1j * 2 * np.pi * self.r * self.X / 2] - self.simple_model = GeneratorModel(operators=operators, signals=signals) + self.simple_model = GeneratorModel( + operators=operators, signals=signals, array_library=self.array_library() + ) # construct randomized RHS dim = 7 @@ -96,6 +101,7 @@ def basic_rhs(t, y=None): signals=[self.pseudo_random_signal], static_operator=static_operator, rotating_frame=frame_op, + array_library=self.array_library(), ) # simulate directly out of frame @@ -129,10 +135,11 @@ def test_ode_method(self): if self.is_ode_method: # pylint: disable=unused-argument def quad_rhs(t, y): - return np.real(Array([t**2])) + return unp.real(unp.array([t**2])) + + results = self.solve(quad_rhs, t_span=[0.0, 1.0], y0=np.array([0.0])) + expected = np.array([1.0 / 3]) - results = self.solve(quad_rhs, t_span=[0.0, 1.0], y0=Array([0.0])) - expected = Array([1.0 / 3]) self.assertAllClose(results.y[-1], expected) def test_basic_model_lmde_from_ode(self): @@ -143,7 +150,7 @@ def test_basic_model_lmde_from_ode(self): self.basic_rhs, t_span=self.t_span, y0=self.y0, solver_func=solve_lmde ) - expected = expm(-1j * np.pi * self.X.data) + expected = expm(-1j * np.pi * self.X) self.assertAllClose(results.y[-1], expected, atol=self.tol, rtol=self.tol) @@ -152,7 +159,7 @@ def test_basic_model(self): results = self.solve(self.basic_rhs, t_span=self.t_span, y0=self.y0) - expected = expm(-1j * np.pi * self.X.data) + expected = expm(-1j * np.pi * self.X) self.assertAllClose(results.y[-1], expected, atol=self.tol, rtol=self.tol) @@ -162,7 +169,7 @@ def test_backwards_solving(self): reverse_t_span = self.t_span.copy() reverse_t_span.reverse() - reverse_y0 = expm(-1j * np.pi * self.X.data) + reverse_y0 = expm(-1j * np.pi * self.X) results = self.solve(self.basic_rhs, t_span=reverse_t_span, y0=reverse_y0) @@ -173,7 +180,7 @@ def test_w_GeneratorModel(self): results = self.solve( self.simple_model, - y0=Array([0.0, 1.0], 
dtype=complex), + y0=np.array([0.0, 1.0], dtype=complex), t_span=[0, 1 / self.r], ) yf = results.y[-1] @@ -210,8 +217,10 @@ def test_pseudo_random_model(self): self.assertTrue(self.pseudo_random_model.in_frame_basis) -class TestSolverMethodJax(TestSolverMethod, TestJaxBase): - """JAX version of TestSolverMethod. Adds additional jit/grad test.""" +class TestSolverMethodJAX(TestSolverMethod): + """JAX version of TestSolverMethod. Adds additional jit/grad test. Assumes will be subclassed + using test_array_backends with JAX array librarys. + """ def test_pseudo_random_jit_grad(self): """Validate jitting and gradding through the method at the level of @@ -219,19 +228,23 @@ def test_pseudo_random_jit_grad(self): """ def func(a): - model_copy = self.pseudo_random_model.copy() - model_copy.signals = [Signal(a, carrier_freq=1.0)] - results = self.solve(model_copy, t_span=[0.0, 0.1], y0=self.pseudo_random_y0) + self.pseudo_random_model.signals = [Signal(a, carrier_freq=1.0)] + results = self.solve( + self.pseudo_random_model, t_span=[0.0, 0.1], y0=self.pseudo_random_y0 + ) + self.pseudo_random_model.signals = None return results.y[-1] - jit_func = self.jit_wrap(func) - self.assertAllClose(jit_func(1.0), func(1.0)) + # verify we can jit + from jax import jit + + self.assertAllClose(jit(func)(1.0), func(1.0)) # just verify that this runs without error - jit_grad_func = self.jit_grad_wrap(func) - jit_grad_func(1.0) + self.jit_grad(func)(1.0) +@partial(test_array_backends, array_libraries=["numpy"]) class TestRK4(TestSolverMethod): """Test class for RK4_solver.""" @@ -251,7 +264,8 @@ def is_ode_method(self): return True -class Testjax_RK4(TestSolverMethodJax): +@partial(test_array_backends, array_libraries=["jax"]) +class Testjax_RK4(TestSolverMethodJAX): """Test class for jax_RK4_solver.""" def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_ode, **kwargs): @@ -270,7 +284,8 @@ def is_ode_method(self): return True -class Testjax_RK4_parallel(TestSolverMethodJax): +@partial(test_array_backends, array_libraries=["jax"]) +class Testjax_RK4_parallel(TestSolverMethodJAX): """Test class for jax_RK4_parallel_solver.""" def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs): @@ -290,6 +305,7 @@ def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs): return results +@partial(test_array_backends, array_libraries=["numpy"]) class Testscipy_expm(TestSolverMethod): """Test class for scipy_expm_solver.""" @@ -305,6 +321,7 @@ def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs): ) +@partial(test_array_backends, array_libraries=["numpy"]) class Testscipy_expm_magnus2(TestSolverMethod): """Test class for scipy_expm_solver with magnus_order==2.""" @@ -321,6 +338,7 @@ def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs): ) +@partial(test_array_backends, array_libraries=["numpy"]) class Testscipy_expm_magnus3(TestSolverMethod): """Test class for scipy_expm_solver with magnus_order==3.""" @@ -337,6 +355,7 @@ def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs): ) +# test_array_backends is called later as we need to subclass this class Testlanczos_diag(TestSolverMethod): """Test class for lanczos_diag_solver.""" @@ -344,13 +363,13 @@ def setUp(self): super().setUp() self.simple_model = HamiltonianModel( - operators=1j * self.simple_model.operators, + operators=1j * np.array([op.todense() for op in self.simple_model.operators]), signals=self.simple_model.signals, - 
evaluation_mode="sparse", + array_library=self.array_library(), ) - self.operators = self.pseudo_random_model.operators.data - self.static_operator = self.pseudo_random_model.static_operator.data + self.operators = np.array(self.pseudo_random_model.operators) + self.static_operator = np.array(self.pseudo_random_model.static_operator) # make hermitian self.operators = self.operators.conj().transpose(0, 2, 1) + self.operators @@ -362,7 +381,7 @@ def setUp(self): signals=[self.pseudo_random_signal], static_operator=self.static_operator, rotating_frame=self.frame_op, - evaluation_mode="sparse", + array_library=self.array_library(), ) # simulate directly out of frame @@ -389,7 +408,8 @@ def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs): ) -class Testjax_lanczos_diag(Testlanczos_diag, TestSolverMethodJax): +@partial(test_array_backends, array_libraries=["jax_sparse"]) +class Testjax_lanczos_diag(Testlanczos_diag, TestSolverMethodJAX): """Test class for jax_expm_solver.""" def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs): @@ -405,7 +425,11 @@ def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs): ) -class Testjax_expm(TestSolverMethodJax): +test_array_backends(Testlanczos_diag, array_libraries=["scipy_sparse"]) + + +@partial(test_array_backends, array_libraries=["jax"]) +class Testjax_expm(TestSolverMethodJAX): """Test class for jax_expm_solver.""" def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs): @@ -420,7 +444,8 @@ def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs): ) -class Testjax_expm_magnus2(TestSolverMethodJax): +@partial(test_array_backends, array_libraries=["jax"]) +class Testjax_expm_magnus2(TestSolverMethodJAX): """Test class for jax_expm_solver with magnus_order==2.""" def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs): @@ -436,7 +461,8 @@ def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs): ) -class Testjax_expm_magnus3(TestSolverMethodJax): +@partial(test_array_backends, array_libraries=["jax"]) +class Testjax_expm_magnus3(TestSolverMethodJAX): """Test class for jax_expm_solver with magnus_order==3.""" def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs): @@ -452,7 +478,8 @@ def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs): ) -class Testjax_expm_parallel(TestSolverMethodJax): +@partial(test_array_backends, array_libraries=["jax"]) +class Testjax_expm_parallel(TestSolverMethodJAX): """Test class for jax_expm_parallel_solver.""" def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs): @@ -472,7 +499,8 @@ def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs): return results -class Testjax_expm_parallel_magnus2(TestSolverMethodJax): +@partial(test_array_backends, array_libraries=["jax"]) +class Testjax_expm_parallel_magnus2(TestSolverMethodJAX): """Test class for jax_expm_parallel_solver with magnus_order==2.""" def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs): @@ -493,7 +521,8 @@ def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs): return results -class Testjax_expm_parallel_magnus3(TestSolverMethodJax): +@partial(test_array_backends, array_libraries=["jax"]) +class Testjax_expm_parallel_magnus3(TestSolverMethodJAX): """Test class for jax_expm_parallel_solver with magnus_order==3.""" def solve(self, rhs, t_span, y0, 
t_eval=None, solver_func=solve_lmde, **kwargs): @@ -514,6 +543,7 @@ def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs): return results +@partial(test_array_backends, array_libraries=["numpy"]) class Testscipy_RK45(TestSolverMethod): """Tests for scipy solve_ivp RK45 method.""" @@ -534,6 +564,7 @@ def is_ode_method(self): return True +@partial(test_array_backends, array_libraries=["numpy"]) class Testscipy_RK23(TestSolverMethod): """Tests for scipy solve_ivp RK23 method.""" @@ -554,6 +585,7 @@ def is_ode_method(self): return True +@partial(test_array_backends, array_libraries=["numpy"]) class Testscipy_BDF(TestSolverMethod): """Tests for scipy solve_ivp BDF method.""" @@ -574,6 +606,7 @@ def is_ode_method(self): return True +@partial(test_array_backends, array_libraries=["numpy"]) class Testscipy_DOP853(TestSolverMethod): """Tests for scipy solve_ivp DOP853 method.""" @@ -594,7 +627,8 @@ def is_ode_method(self): return True -class Testjax_odeint(TestSolverMethodJax): +@partial(test_array_backends, array_libraries=["jax"]) +class Testjax_odeint(TestSolverMethodJAX): """Tests for jax odeint method.""" def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_ode, **kwargs): @@ -614,7 +648,8 @@ def is_ode_method(self): return True -class Testdiffrax_DOP5(TestSolverMethodJax, DiffraxTestBase): +@partial(test_array_backends, array_libraries=["jax"]) +class Testdiffrax_DOP5(TestSolverMethodJAX, DiffraxTestBase): """Tests for diffrax Dopri5 method.""" def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_ode, **kwargs): @@ -633,7 +668,8 @@ def is_ode_method(self): return True -class Testdiffrax_Tsit5(TestSolverMethodJax, DiffraxTestBase): +@partial(test_array_backends, array_libraries=["jax"]) +class Testdiffrax_Tsit5(TestSolverMethodJAX, DiffraxTestBase): """Tests for diffrax Tsit5 method.""" def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_ode, **kwargs): @@ -650,7 +686,3 @@ def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_ode, **kwargs): @property def is_ode_method(self): return True - - -# delete abstract classes so unittest doesn't attempt to run them -del TestSolverMethod, TestSolverMethodJax diff --git a/test/dynamics/solvers/test_solver_functions_interface.py b/test/dynamics/solvers/test_solver_functions_interface.py index 6f41c0faa..23392011e 100644 --- a/test/dynamics/solvers/test_solver_functions_interface.py +++ b/test/dynamics/solvers/test_solver_functions_interface.py @@ -29,7 +29,7 @@ results_y_out_of_frame_basis, ) -from ..common import QiskitDynamicsTestCase, TestJaxBase +from ..common import QiskitDynamicsTestCase, JAXTestBase class Testsolve_ode_exceptions(QiskitDynamicsTestCase): @@ -53,7 +53,7 @@ def setUp(self): def test_lmde_method_non_vectorized_lindblad(self): """Test error raising if LMDE method is specified for non-vectorized Lindblad.""" - with self.assertRaisesRegex(QiskitError, "vectorized evaluation"): + with self.assertRaisesRegex(QiskitError, "vectorized=True"): solve_lmde( self.lindblad_model, t_span=[0.0, 1.0], y0=np.diag([1.0, 0.0]), method="scipy_expm" ) @@ -109,7 +109,7 @@ def test_jax_expm_sparse_mode(self): in sparse mode.""" model = GeneratorModel( - static_operator=np.array([[0.0, 1.0], [1.0, 0.0]]), evaluation_mode="sparse" + static_operator=np.array([[0.0, 1.0], [1.0, 0.0]]), array_library="jax_sparse" ) with self.assertRaisesRegex(QiskitError, "jax_expm cannot be used"): @@ -131,7 +131,7 @@ def test_scipy_expm_magnus_order_exception(self): ) -class 
Testsolve_lmde_exceptionsJAX(QiskitDynamicsTestCase, TestJaxBase): +class Testsolve_lmde_exceptionsJAX(JAXTestBase): """Test solve_lmde exceptions requiring JAX to reach.""" def test_jax_expm_magnus_order_exception(self): @@ -180,14 +180,43 @@ def setUp(self): static_dissipators=[Operator.from_label("Y")], ) - self.vec_lindblad_model = self.lindblad_model.copy() - self.vec_lindblad_model.evaluation_mode = "dense_vectorized" + self.vec_lindblad_model = LindbladModel( + hamiltonian_operators=[Operator.from_label("X")], + hamiltonian_signals=[Signal(1.0, 5.0)], + static_hamiltonian=Operator.from_label("Z"), + static_dissipators=[Operator.from_label("Y")], + vectorized=True, + ) self.frame_op = 1.2 * Operator.from_label("X") - 3.132 * Operator.from_label("Y") _, U = np.linalg.eigh(self.frame_op) self.U = U self.Uadj = self.U.conj().transpose() + self.rf_ham_model = HamiltonianModel( + operators=[Operator.from_label("X")], + signals=[Signal(1.0, 5.0)], + static_operator=Operator.from_label("Z"), + rotating_frame=self.frame_op, + ) + + self.rf_lindblad_model = LindbladModel( + hamiltonian_operators=[Operator.from_label("X")], + hamiltonian_signals=[Signal(1.0, 5.0)], + static_hamiltonian=Operator.from_label("Z"), + static_dissipators=[Operator.from_label("Y")], + rotating_frame=self.frame_op, + ) + + self.rf_vec_lindblad_model = LindbladModel( + hamiltonian_operators=[Operator.from_label("X")], + hamiltonian_signals=[Signal(1.0, 5.0)], + static_hamiltonian=Operator.from_label("Z"), + static_dissipators=[Operator.from_label("Y")], + rotating_frame=self.frame_op, + vectorized=True, + ) + def test_hamiltonian_setup_no_frame(self): """Test functions for Hamiltonian with no frame.""" @@ -210,24 +239,21 @@ def test_hamiltonian_setup_no_frame(self): def test_hamiltonian_setup(self): """Test functions for Hamiltonian with frame.""" - ham_model = self.ham_model.copy() - ham_model.rotating_frame = self.frame_op y0 = np.array([3.43, 1.31]) gen, rhs, new_y0, model_in_frame_basis = setup_generator_model_rhs_y0_in_frame_basis( - ham_model, y0 + self.rf_ham_model, y0 ) # check frame parameters self.assertFalse(model_in_frame_basis) # check that model has been converted to being in frame basis - self.assertTrue(ham_model.in_frame_basis) - self.assertFalse(self.ham_model.in_frame_basis) + self.assertTrue(self.rf_ham_model.in_frame_basis) self.assertAllClose(self.Uadj @ y0, new_y0) t = 231.232 self.ham_model.in_frame_basis = True - self.assertAllClose(gen(t), ham_model(t)) - self.assertAllClose(rhs(t, y0), ham_model(t, y0)) + self.assertAllClose(gen(t), self.rf_ham_model(t)) + self.assertAllClose(rhs(t, y0), self.rf_ham_model(t, y0)) def test_lindblad_setup_no_frame(self): """Test functions for LindbladModel with no frame.""" @@ -246,17 +272,14 @@ def test_lindblad_setup_no_frame(self): def test_lindblad_setup(self): """Test functions for LindbladModel with frame.""" - lindblad_model = self.lindblad_model.copy() - lindblad_model.rotating_frame = self.frame_op - y0 = np.array([[3.43, 1.31], [3.0, 1.23]]) - _, rhs, new_y0, _ = setup_generator_model_rhs_y0_in_frame_basis(lindblad_model, y0) + _, rhs, new_y0, _ = setup_generator_model_rhs_y0_in_frame_basis(self.rf_lindblad_model, y0) # expect nothing to happen self.assertAllClose(self.Uadj @ y0 @ self.U, new_y0) t = 231.232 - self.assertTrue(lindblad_model.in_frame_basis) - self.assertAllClose(rhs(t, y0), lindblad_model(t, y0)) + self.assertTrue(self.rf_lindblad_model.in_frame_basis) + self.assertAllClose(rhs(t, y0), self.rf_lindblad_model(t, y0)) def 
test_vectorized_lindblad_setup_no_frame(self): """Test functions for vectorized LindbladModel with no frame.""" @@ -276,18 +299,17 @@ def test_vectorized_lindblad_setup_no_frame(self): def test_vectorized_lindblad_setup(self): """Test functions for vectorized LindbladModel with frame.""" - vec_lindblad_model = self.vec_lindblad_model.copy() - vec_lindblad_model.rotating_frame = self.frame_op - y0 = np.array([[3.43, 1.31], [3.0, 1.23]]).flatten() - gen, rhs, new_y0, _ = setup_generator_model_rhs_y0_in_frame_basis(vec_lindblad_model, y0) + gen, rhs, new_y0, _ = setup_generator_model_rhs_y0_in_frame_basis( + self.rf_vec_lindblad_model, y0 + ) # expect nothing to happen self.assertAllClose(np.kron(self.Uadj.conj(), self.Uadj) @ y0, new_y0) t = 231.232 - self.assertTrue(vec_lindblad_model.in_frame_basis) - self.assertAllClose(gen(t), vec_lindblad_model(t)) - self.assertAllClose(rhs(t, y0), vec_lindblad_model(t, y0)) + self.assertTrue(self.rf_vec_lindblad_model.in_frame_basis) + self.assertAllClose(gen(t), self.rf_vec_lindblad_model(t)) + self.assertAllClose(rhs(t, y0), self.rf_vec_lindblad_model(t, y0)) def test_hamiltonian_results_conversion_no_frame(self): """Test hamiltonian results conversion with no frame.""" @@ -305,18 +327,15 @@ def test_hamiltonian_results_conversion_no_frame(self): def test_hamiltonian_results_conversion(self): """Test hamiltonian results conversion.""" - ham_model = self.ham_model.copy() - ham_model.rotating_frame = self.frame_op - # test 1d input results_y = np.array([[1.0, 23.3], [2.32, 1.232]]) - output = results_y_out_of_frame_basis(ham_model, results_y, y0_ndim=1) + output = results_y_out_of_frame_basis(self.rf_ham_model, results_y, y0_ndim=1) expected = [self.U @ y for y in results_y] self.assertAllClose(expected, output) # test 2d input results_y = np.array([[[1.0, 23.3], [23, 231j]], [[2.32, 1.232], [1j, 2.0 + 3j]]]) - output = results_y_out_of_frame_basis(ham_model, results_y, y0_ndim=2) + output = results_y_out_of_frame_basis(self.rf_ham_model, results_y, y0_ndim=2) expected = [self.U @ y for y in results_y] self.assertAllClose(expected, output) @@ -330,11 +349,8 @@ def test_lindblad_results_conversion_no_frame(self): def test_lindblad_results_conversion(self): """Test lindblad results conversion.""" - lindblad_model = self.lindblad_model.copy() - lindblad_model.rotating_frame = self.frame_op - results_y = np.array([[[1.0, 23.3], [23, 231j]], [[2.32, 1.232], [1j, 2.0 + 3j]]]) - output = results_y_out_of_frame_basis(lindblad_model, results_y, y0_ndim=2) + output = results_y_out_of_frame_basis(self.rf_lindblad_model, results_y, y0_ndim=2) expected = [self.U @ y @ self.Uadj for y in results_y] self.assertAllClose(expected, output) @@ -359,13 +375,11 @@ def test_vectorized_lindblad_results_conversion_no_frame(self): def test_vectorized_lindblad_results_conversion(self): """Test vectorized lindblad results conversion.""" - vec_lindblad_model = self.vec_lindblad_model.copy() - vec_lindblad_model.rotating_frame = self.frame_op P = np.kron(self.U.conj(), self.U) # test 1d input results_y = np.array([[1.0, 23.3, 1.23, 0.123], [2.32, 1.232, 1j, 21.2]]) - output = results_y_out_of_frame_basis(vec_lindblad_model, results_y, y0_ndim=1) + output = results_y_out_of_frame_basis(self.rf_vec_lindblad_model, results_y, y0_ndim=1) expected = [P @ y for y in results_y] self.assertAllClose(expected, output) @@ -376,6 +390,6 @@ def test_vectorized_lindblad_results_conversion(self): [[2.32, 1.232], [1j, 2.0 + 3j], [2.32, 2.12314], [334.0, 34.3]], ] ) - output = 
results_y_out_of_frame_basis(vec_lindblad_model, results_y, y0_ndim=2) + output = results_y_out_of_frame_basis(self.rf_vec_lindblad_model, results_y, y0_ndim=2) expected = [P @ y for y in results_y] self.assertAllClose(expected, output) diff --git a/test/dynamics/solvers/test_solver_utils.py b/test/dynamics/solvers/test_solver_utils.py index 3bf4fe176..fbd1c0eae 100644 --- a/test/dynamics/solvers/test_solver_utils.py +++ b/test/dynamics/solvers/test_solver_utils.py @@ -26,7 +26,7 @@ trim_t_results_jax, ) -from ..common import QiskitDynamicsTestCase, TestJaxBase +from ..common import QiskitDynamicsTestCase, JAXTestBase try: import jax.numpy as jnp @@ -173,7 +173,7 @@ def test_trim_t_results_with_overlap(self): self.assertAllClose(trimmed_obj.y, np.array([[0.0, 1.0], [0.5, 0.5], [1.0, 0.0]])) -class TestTimeArgsHandlingJAX(TestTimeArgsHandling, TestJaxBase): +class TestTimeArgsHandlingJAX(TestTimeArgsHandling, JAXTestBase): """Tests for merge_t_args_jax and trim_t_results_jax functions.""" def test_merge_t_args_with_overlap(self): @@ -219,19 +219,19 @@ def trim_t_results(self, results, t_eval=None): def test_merge_t_args_interval_error(self): """Test output nan if t_eval not in t_span.""" out = self.merge_t_args(t_span=np.array([0.0, 1.0]), t_eval=np.array([1.5])) - self.assertTrue(jnp.isnan(out.data).all()) + self.assertTrue(jnp.isnan(out).all()) self.assertTrue(out.shape == (3,)) def test_merge_t_args_interval_error_backwards(self): """Test output nan if t_eval not in t_span for backwards integration.""" out = self.merge_t_args(t_span=np.array([0.0, -1.0]), t_eval=np.array([-1.5])) - self.assertTrue(jnp.isnan(out.data).all()) + self.assertTrue(jnp.isnan(out).all()) self.assertTrue(out.shape == (3,)) def test_merge_t_args_sort_error(self): """Test output nan if t_eval is not correctly sorted.""" out = self.merge_t_args(t_span=np.array([0.0, 1.0]), t_eval=np.array([0.75, 0.25])) - self.assertTrue(jnp.isnan(out.data).all()) + self.assertTrue(jnp.isnan(out).all()) self.assertTrue(out.shape == (4,)) def test_merge_t_args_sort_error_backwards(self): @@ -239,7 +239,7 @@ def test_merge_t_args_sort_error_backwards(self): backwards integration. """ out = self.merge_t_args(t_span=np.array([0.0, -1.0]), t_eval=np.array([-0.75, -0.25])) - self.assertTrue(jnp.isnan(out.data).all()) + self.assertTrue(jnp.isnan(out).all()) self.assertTrue(out.shape == (4,)) def test_trim_t_results_t0_duplicate(self):