From 7333814751928d86a4410f039ea2894e3f4b45aa Mon Sep 17 00:00:00 2001 From: Kento Ueda <38037695+to24toro@users.noreply.github.com> Date: Tue, 31 Oct 2023 00:45:35 +0900 Subject: [PATCH] Arraylias integration - Signal class - (#269) Co-authored-by: DanPuzzuoli --- docs/tutorials/optimizing_pulse_sequence.rst | 8 +- .../how_to_configure_simulations.rst | 6 +- docs/userguide/how_to_use_jax.rst | 1 - .../how_to_use_pulse_schedule_for_jax_jit.rst | 2 +- docs/userguide/perturbative_solvers.rst | 4 +- qiskit_dynamics/__init__.py | 1 + qiskit_dynamics/arraylias/__init__.py | 8 +- qiskit_dynamics/arraylias/alias.py | 51 +++- .../models/operator_collections.py | 3 +- qiskit_dynamics/pulse/pulse_to_signals.py | 29 +- qiskit_dynamics/signals/signals.py | 279 ++++++++---------- qiskit_dynamics/signals/transfer_functions.py | 15 +- test/dynamics/models/test_generator_model.py | 2 +- test/dynamics/signals/test_signals.py | 253 ++++++++-------- test/dynamics/signals/test_signals_algebra.py | 8 +- .../solvers/test_dyson_magnus_solvers.py | 10 +- test/dynamics/solvers/test_solver_classes.py | 3 +- .../dynamics/solvers/test_solver_functions.py | 4 +- 18 files changed, 352 insertions(+), 335 deletions(-) diff --git a/docs/tutorials/optimizing_pulse_sequence.rst b/docs/tutorials/optimizing_pulse_sequence.rst index 6ad577253..757465c3a 100644 --- a/docs/tutorials/optimizing_pulse_sequence.rst +++ b/docs/tutorials/optimizing_pulse_sequence.rst @@ -115,9 +115,10 @@ example of one. .. jupyter-execute:: from qiskit_dynamics import DiscreteSignal - from qiskit_dynamics.array import Array from qiskit_dynamics.signals import Convolution + import jax.numpy as jnp + # define convolution filter def gaus(t): sigma = 15 @@ -128,13 +129,12 @@ example of one. # define function mapping parameters to signals def signal_mapping(params): - samples = Array(params) # map samples into [-1, 1] - bounded_samples = np.arctan(samples) / (np.pi / 2) + bounded_samples = jnp.arctan(params) / (np.pi / 2) # pad with 0 at beginning - padded_samples = np.append(Array([0], dtype=complex), bounded_samples) + padded_samples = jnp.append(jnp.array([0], dtype=complex), bounded_samples) # apply filter output_signal = convolution(DiscreteSignal(dt=1., samples=padded_samples)) diff --git a/docs/userguide/how_to_configure_simulations.rst b/docs/userguide/how_to_configure_simulations.rst index 64b0babb6..21ffa965b 100644 --- a/docs/userguide/how_to_configure_simulations.rst +++ b/docs/userguide/how_to_configure_simulations.rst @@ -246,7 +246,7 @@ further highlight the benefits of the sparse representation. static_hamiltonian = 2 * np.pi * v * N + np.pi * anharm * N * (N - np.eye(dim)) drive_hamiltonian = 2 * np.pi * r * (a + adag) - drive_signal = Signal(Array(1.), carrier_freq=v) + drive_signal = Signal(1., carrier_freq=v) y0 = np.zeros(dim, dtype=complex) y0[1] = 1. @@ -266,7 +266,7 @@ amplitude, and just-in-time compile it using JAX. ) def dense_func(amp): - drive_signal = Signal(Array(amp), carrier_freq=v) + drive_signal = Signal(amp, carrier_freq=v) res = solver.solve( t_span=[0., T], y0=y0, @@ -292,7 +292,7 @@ diagonal, but we explicitly highlight the need for this. evaluation_mode='sparse') def sparse_func(amp): - drive_signal = Signal(Array(amp), carrier_freq=v) + drive_signal = Signal(amp, carrier_freq=v) res = sparse_solver.solve( t_span=[0., T], y0=y0, diff --git a/docs/userguide/how_to_use_jax.rst b/docs/userguide/how_to_use_jax.rst index 905013e0a..4efec44e6 100644 --- a/docs/userguide/how_to_use_jax.rst +++ b/docs/userguide/how_to_use_jax.rst @@ -138,7 +138,6 @@ before setting the signals, to ensure the simulation function remains pure. def sim_function(amp): # define a constant signal - amp = Array(amp) signals = [Signal(amp, carrier_freq=w)] # simulate and return results diff --git a/docs/userguide/how_to_use_pulse_schedule_for_jax_jit.rst b/docs/userguide/how_to_use_pulse_schedule_for_jax_jit.rst index 35a66c717..01914f19f 100644 --- a/docs/userguide/how_to_use_pulse_schedule_for_jax_jit.rst +++ b/docs/userguide/how_to_use_pulse_schedule_for_jax_jit.rst @@ -137,6 +137,6 @@ JAX-compiled (or more generally, JAX-transformed). # convert from a pulse schedule to a list of signals converter = InstructionToSignals(dt, carriers={"d0": w}) - return converter.get_signals(schedule)[0].samples.data + return converter.get_signals(schedule)[0].samples jax.jit(jit_func)(0.4) diff --git a/docs/userguide/perturbative_solvers.rst b/docs/userguide/perturbative_solvers.rst index df9c15417..a1d388bbb 100644 --- a/docs/userguide/perturbative_solvers.rst +++ b/docs/userguide/perturbative_solvers.rst @@ -179,7 +179,7 @@ these functions gives a sense of the speeds attainable by these solvers. """For a given envelope amplitude, simulate the final unitary using the Dyson solver. """ - drive_signal = Signal(lambda t: Array(amp) * envelope_func(t), carrier_freq=v) + drive_signal = Signal(lambda t: amp * envelope_func(t), carrier_freq=v) return dyson_solver.solve( signals=[drive_signal], y0=np.eye(dim, dtype=complex), @@ -220,7 +220,7 @@ accuracy and simulation speed. # specify tolerance as an argument to run the simulation at different tolerances def ode_sim(amp, tol): - drive_signal = Signal(lambda t: Array(amp) * envelope_func(t), carrier_freq=v) + drive_signal = Signal(lambda t: amp * envelope_func(t), carrier_freq=v) res = solver.solve( t_span=[0., int(T // dt) * dt], y0=np.eye(dim, dtype=complex), diff --git a/qiskit_dynamics/__init__.py b/qiskit_dynamics/__init__.py index 1e7ee2bd6..b707d440d 100644 --- a/qiskit_dynamics/__init__.py +++ b/qiskit_dynamics/__init__.py @@ -28,6 +28,7 @@ DYNAMICS_SCIPY_ALIAS, DYNAMICS_NUMPY, DYNAMICS_SCIPY, + ArrayLike, ) from .models.rotating_frame import RotatingFrame diff --git a/qiskit_dynamics/arraylias/__init__.py b/qiskit_dynamics/arraylias/__init__.py index 66102606e..4cd406e3b 100644 --- a/qiskit_dynamics/arraylias/__init__.py +++ b/qiskit_dynamics/arraylias/__init__.py @@ -22,4 +22,10 @@ Module for Qiskit Dynamics global NumPy and SciPy aliases. """ -from .alias import DYNAMICS_NUMPY_ALIAS, DYNAMICS_SCIPY_ALIAS, DYNAMICS_NUMPY, DYNAMICS_SCIPY +from .alias import ( + DYNAMICS_NUMPY_ALIAS, + DYNAMICS_SCIPY_ALIAS, + DYNAMICS_NUMPY, + DYNAMICS_SCIPY, + ArrayLike, +) diff --git a/qiskit_dynamics/arraylias/alias.py b/qiskit_dynamics/arraylias/alias.py index f7d2326bb..6115d1e72 100644 --- a/qiskit_dynamics/arraylias/alias.py +++ b/qiskit_dynamics/arraylias/alias.py @@ -20,6 +20,8 @@ from arraylias import numpy_alias, scipy_alias +from qiskit import QiskitError + from qiskit_dynamics.array import Array # global NumPy and SciPy aliases @@ -34,4 +36,51 @@ DYNAMICS_SCIPY = DYNAMICS_SCIPY_ALIAS() -ArrayLike = Union[DYNAMICS_NUMPY_ALIAS.registered_types()] +ArrayLike = Union[Union[DYNAMICS_NUMPY_ALIAS.registered_types()], list] + + +def _preferred_lib(*args, **kwargs): + """Given a list of args and kwargs with potentially mixed array types, determine the appropriate + library to dispatch to. + + For each argument, DYNAMICS_NUMPY_ALIAS.infer_libs is called to infer the library. If all are + "numpy", then it returns "numpy", and if any are "jax", it returns "jax". + + Args: + *args: Positional arguments. + **kwargs: Keyword arguments. + Returns: + str + Raises: + QiskitError: if none of the rules apply. + """ + args = list(args) + list(kwargs.values()) + if len(args) == 1: + return DYNAMICS_NUMPY_ALIAS.infer_libs(args[0]) + + lib0 = DYNAMICS_NUMPY_ALIAS.infer_libs(args[0])[0] + lib1 = _preferred_lib(args[1:])[0] + + if lib0 == "numpy" and lib1 == "numpy": + return "numpy" + elif lib0 == "jax" or lib1 == "jax": + return "jax" + + raise QiskitError("_preferred_lib could not resolve preferred library.") + + +def _numpy_multi_dispatch(*args, path, **kwargs): + """Multiple dispatching for NumPy. + + Given *args and **kwargs, dispatch the function specified by path, to the array library + specified by _preferred_lib. + + Args: + *args: Positional arguments to pass to function specified by path. + path: Path in numpy module structure. + **kwargs: Keyword arguments to pass to function specified by path. + Returns: + Result of evaluating the function at path on the arguments using the preferred library. + """ + lib = _preferred_lib(*args, **kwargs) + return DYNAMICS_NUMPY_ALIAS(like=lib, path=path)(*args, **kwargs) diff --git a/qiskit_dynamics/models/operator_collections.py b/qiskit_dynamics/models/operator_collections.py index d7863d7a5..d85fb539d 100644 --- a/qiskit_dynamics/models/operator_collections.py +++ b/qiskit_dynamics/models/operator_collections.py @@ -20,6 +20,7 @@ from qiskit import QiskitError from qiskit.quantum_info.operators.operator import Operator +from qiskit_dynamics.arraylias.alias import _numpy_multi_dispatch from qiskit_dynamics.array import Array, wrap from qiskit_dynamics.type_utils import to_array, to_csr, to_BCOO, vec_commutator, vec_dissipator @@ -1357,7 +1358,7 @@ def concatenate_signals( ) -> Array: """Concatenate hamiltonian and linblad signals.""" if self._hamiltonian_operators is not None and self._dissipator_operators is not None: - return np.append(ham_sig_vals, dis_sig_vals, axis=-1) + return _numpy_multi_dispatch(ham_sig_vals, dis_sig_vals, path="append", axis=-1) if self._hamiltonian_operators is not None and self._dissipator_operators is None: return ham_sig_vals if self._hamiltonian_operators is None and self._dissipator_operators is not None: diff --git a/qiskit_dynamics/pulse/pulse_to_signals.py b/qiskit_dynamics/pulse/pulse_to_signals.py index 9ac451295..39968ba2d 100644 --- a/qiskit_dynamics/pulse/pulse_to_signals.py +++ b/qiskit_dynamics/pulse/pulse_to_signals.py @@ -39,6 +39,9 @@ from qiskit import QiskitError from qiskit_dynamics.array import Array +from qiskit_dynamics import DYNAMICS_NUMPY as unp +from qiskit_dynamics import ArrayLike + from qiskit_dynamics.signals import DiscreteSignal try: @@ -186,12 +189,10 @@ def get_signals(self, schedule: Schedule) -> List[DiscreteSignal]: # build sample array to append to signal times = self._dt * (start_sample + np.arange(len(inst_samples))) - samples = inst_samples * np.exp( - Array( - 2.0j * np.pi * frequency_shifts[chan] * times - + 1.0j * phases[chan] - + 2.0j * np.pi * phase_accumulations[chan] - ) + samples = inst_samples * unp.exp( + 2.0j * np.pi * frequency_shifts[chan] * times + + 1.0j * phases[chan] + + 2.0j * np.pi * phase_accumulations[chan] ) signals[chan].add_samples(start_sample, samples) @@ -202,7 +203,7 @@ def get_signals(self, schedule: Schedule) -> List[DiscreteSignal]: phases[chan] = inst.phase if isinstance(inst, ShiftFrequency): - frequency_shifts[chan] = frequency_shifts[chan] + Array(inst.frequency) + frequency_shifts[chan] = frequency_shifts[chan] + inst.frequency phase_accumulations[chan] = ( phase_accumulations[chan] - inst.frequency * start_sample * self._dt ) @@ -273,7 +274,7 @@ def get_awg_signals( new_freq = sig.carrier_freq + if_modulation samples_i = sig.samples - samples_q = np.imag(samples_i) - 1.0j * np.real(samples_i) + samples_q = unp.imag(samples_i) - 1.0j * unp.real(samples_i) sig_i = DiscreteSignal( sig.dt, @@ -325,7 +326,7 @@ def _get_channel(self, channel_name: str): ) from error -def get_samples(pulse: SymbolicPulse) -> np.ndarray: +def get_samples(pulse: SymbolicPulse) -> ArrayLike: """Return samples filled according to the formula that the pulse represents and the parameter values it contains. @@ -349,8 +350,8 @@ def get_samples(pulse: SymbolicPulse) -> np.ndarray: args = [] for symbol in sorted(envelope.free_symbols, key=lambda s: s.name): if symbol.name == "t": - times = Array(np.arange(0, pulse_params["duration"]) + 1 / 2) - args.insert(0, times.data) + times = unp.arange(0, pulse_params["duration"]) + 1 / 2 + args.insert(0, times) continue try: args.append(pulse_params[symbol.name]) @@ -381,11 +382,11 @@ def _lru_cache_expr(expr: sym.Expr, backend) -> Callable: return sym.lambdify(params, expr, modules=backend) -def _nyquist_warn(frequency_shift: Array, dt: float, channel: str): +def _nyquist_warn(frequency_shift: ArrayLike, dt: float, channel: str): """Raise a warning if the frequency shift is above the Nyquist frequency given by ``dt``.""" - if ( - Array(frequency_shift).backend != "jax" or not isinstance(jnp.array(0), jax.core.Tracer) + isinstance(frequency_shift, (int, float, list, np.ndarray)) + or not isinstance(jnp.array(0), jax.core.Tracer) ) and np.abs(frequency_shift) > 0.5 / dt: warn( "Due to SetFrequency and ShiftFrequency instructions, the digital carrier frequency " diff --git a/qiskit_dynamics/signals/signals.py b/qiskit_dynamics/signals/signals.py index 493702227..acd9d301b 100644 --- a/qiskit_dynamics/signals/signals.py +++ b/qiskit_dynamics/signals/signals.py @@ -2,7 +2,7 @@ # This code is part of Qiskit. # -# (C) Copyright IBM 2021. +# (C) Copyright IBM 2021, 2023. # # This code is licensed under the Apache License, Version 2.0. You may # obtain a copy of this license in the LICENSE.txt file in the root directory @@ -24,13 +24,11 @@ import numpy as np from matplotlib import pyplot as plt -try: - import jax.numpy as jnp -except ImportError: - pass - from qiskit import QiskitError -from qiskit_dynamics.array import Array +from qiskit_dynamics import ArrayLike +from qiskit_dynamics import DYNAMICS_NUMPY_ALIAS as numpy_alias +from qiskit_dynamics import DYNAMICS_NUMPY as unp +from qiskit_dynamics.arraylias.alias import _numpy_multi_dispatch, _preferred_lib class Signal: @@ -72,9 +70,9 @@ class Signal: def __init__( self, - envelope: Union[Callable, complex, float, int, Array], - carrier_freq: Union[float, List, Array] = 0.0, - phase: Union[float, List, Array] = 0.0, + envelope: Union[Callable, ArrayLike], + carrier_freq: ArrayLike = 0.0, + phase: ArrayLike = 0.0, name: Optional[str] = None, ): """ @@ -91,10 +89,9 @@ def __init__( self._name = name self._is_constant = False - if isinstance(envelope, (complex, float, int)): - envelope = Array(complex(envelope)) - - if isinstance(envelope, Array): + if not callable(envelope): + # if not callable, we assume a constant + envelope = unp.asarray(envelope) # if envelope is constant and the carrier is zero, this is a constant signal try: # try block is for catching JAX tracer errors @@ -103,15 +100,9 @@ def __init__( except Exception: # pylint: disable=broad-except pass - if envelope.backend == "jax": - self._envelope = lambda t: envelope * jnp.ones_like(Array(t).data) - else: - self._envelope = lambda t: envelope * np.ones_like(t) - elif callable(envelope): - if Array.default_backend() == "jax": - self._envelope = lambda t: Array(envelope(t)) - else: - self._envelope = envelope + self._envelope = lambda t: envelope * unp.ones_like(t) + else: + self._envelope = envelope # set carrier and phase self.carrier_freq = carrier_freq @@ -128,45 +119,41 @@ def is_constant(self) -> bool: return self._is_constant @property - def carrier_freq(self) -> Array: + def carrier_freq(self) -> ArrayLike: """The carrier frequency of the signal.""" return self._carrier_freq @carrier_freq.setter - def carrier_freq(self, carrier_freq: Union[float, list, Array]): + def carrier_freq(self, carrier_freq: ArrayLike): """Carrier frequency setter. List handling is to support subclasses storing a list of frequencies.""" - if type(carrier_freq) == list: - carrier_freq = [Array(entry).data for entry in carrier_freq] - self._carrier_freq = Array(carrier_freq) + self._carrier_freq = unp.asarray(carrier_freq) self._carrier_arg = 1j * 2 * np.pi * self._carrier_freq @property - def phase(self) -> Array: + def phase(self) -> ArrayLike: """The phase of the signal.""" return self._phase @phase.setter - def phase(self, phase: Union[float, list, Array]): + def phase(self, phase: ArrayLike): """Phase setter. List handling is to support subclasses storing a list of phases.""" - if type(phase) == list: - phase = [Array(entry).data for entry in phase] - self._phase = Array(phase) + self._phase = unp.asarray(phase) self._phase_arg = 1j * self._phase - def envelope(self, t: Union[float, np.array, Array]) -> Union[complex, np.array, Array]: + def envelope(self, t: ArrayLike) -> ArrayLike: """Vectorized evaluation of the envelope at time t.""" return self._envelope(t) - def complex_value(self, t: Union[float, np.array, Array]) -> Union[complex, np.array, Array]: + def complex_value(self, t: ArrayLike) -> ArrayLike: """Vectorized evaluation of the complex value at time t.""" arg = self._carrier_arg * t + self._phase_arg - return self.envelope(t) * np.exp(arg) + return self.envelope(t) * unp.exp(arg) - def __call__(self, t: Union[float, np.array, Array]) -> Union[complex, np.array, Array]: + def __call__(self, t: ArrayLike) -> ArrayLike: """Vectorized evaluation of the signal at time(s) t.""" - return np.real(self.complex_value(t)) + return unp.real(self.complex_value(t)) def __str__(self) -> str: """Return string representation.""" @@ -203,7 +190,7 @@ def conjugate(self): """Return a new signal whose complex value is the complex conjugate of this one.""" def conj_env(t): - return np.conjugate(self.envelope(t)) + return unp.conjugate(self.envelope(t)) return Signal(conj_env, -self.carrier_freq, -self.phase) @@ -283,10 +270,10 @@ class DiscreteSignal(Signal): def __init__( self, dt: float, - samples: Union[Array, List], + samples: ArrayLike, start_time: float = 0.0, - carrier_freq: Union[float, List, Array] = 0.0, - phase: Union[float, List, Array] = 0.0, + carrier_freq: ArrayLike = 0.0, + phase: ArrayLike = 0.0, name: str = None, ): """Initialize a piecewise constant signal. @@ -303,38 +290,27 @@ def __init__( """ self._dt = dt - samples = Array(samples) + samples = unp.asarray(samples) if len(samples) == 0: - zero_pad = np.array([0]) + zero_pad = np.asarray([0]) else: - zero_pad = np.expand_dims(np.zeros_like(Array(samples[0])), 0) - self._padded_samples = np.append(samples, zero_pad, axis=0) + zero_pad = unp.expand_dims(unp.zeros_like(samples[0]), 0) + self._padded_samples = unp.append(samples, zero_pad, axis=0) self._start_time = start_time # define internal envelope function - if samples.backend == "jax": - - def envelope(t): - t = Array(t).data - idx = jnp.clip( - jnp.array((t - self._start_time) // self._dt, dtype=int), - -1, - len(self.samples), - ) - return self._padded_samples[idx] - - else: - - def envelope(t): - t = Array(t).data - idx = np.clip( - np.array((t - self._start_time) // self._dt, dtype=int), - -1, - len(self.samples), - ) - return self._padded_samples[idx] + def envelope(t): + t = unp.asarray(t) + idx = unp.clip( + unp.array((t - self._start_time) // self._dt, dtype=int), + -1, + len(self.samples), + ) + return numpy_alias(like=_preferred_lib(self._padded_samples, idx)).asarray( + self._padded_samples + )[idx] Signal.__init__(self, envelope=envelope, carrier_freq=carrier_freq, phase=phase, name=name) @@ -410,12 +386,12 @@ def dt(self) -> float: return self._dt @property - def samples(self) -> Array: + def samples(self) -> ArrayLike: """ Returns: samples: the samples of the piecewise constant signal. """ - return Array(self._padded_samples[:-1]) + return self._padded_samples[:-1] @property def start_time(self) -> float: @@ -428,7 +404,7 @@ def start_time(self) -> float: def conjugate(self): return self.__class__( dt=self._dt, - samples=np.conjugate(self.samples), + samples=unp.conjugate(self.samples), start_time=self._start_time, carrier_freq=-self.carrier_freq, phase=-self.phase, @@ -445,7 +421,7 @@ def add_samples(self, start_sample: int, samples: List): Raises: QiskitError: If start_sample is less than the current length of samples. """ - samples = Array(samples) + samples = unp.asarray(samples) if len(samples) < 1: return @@ -453,16 +429,16 @@ def add_samples(self, start_sample: int, samples: List): if start_sample < len(self.samples): raise QiskitError("Samples can only be added afer the last sample.") - zero_pad = np.expand_dims(np.zeros_like(Array(samples[0])), 0) + zero_pad = unp.expand_dims(unp.zeros_like(samples[0]), 0) new_samples = self.samples if len(self.samples) < start_sample: - new_samples = np.append( - new_samples, np.repeat(zero_pad, start_sample - len(self.samples)) + new_samples = unp.append( + new_samples, unp.repeat(zero_pad, start_sample - len(self.samples)) ) - new_samples = np.append(new_samples, samples) - self._padded_samples = np.append(new_samples, zero_pad, axis=0) + new_samples = _numpy_multi_dispatch(new_samples, samples, path="append") + self._padded_samples = unp.append(new_samples, zero_pad, axis=0) def __str__(self) -> str: """Return string representation.""" @@ -494,12 +470,10 @@ def __len__(self): """Number of components.""" return len(self.components) - def __getitem__( - self, idx: Union[int, List, np.array, slice] - ) -> Union[Signal, "SignalCollection"]: + def __getitem__(self, idx: Union[ArrayLike, slice]) -> Union[Signal, "SignalCollection"]: """Get item with NumPy-style subscripting, as if this class were a 1d array.""" - if type(idx) == np.ndarray and idx.ndim > 0: + if type(idx) != slice and unp.asarray(idx).ndim > 0: idx = list(idx) # get a list of the subcomponents @@ -543,9 +517,9 @@ class SignalSum(SignalCollection, Signal): - ``__call__`` evaluates the sum. - ``complex_value`` evaluates the sum of the complex values of the individual summands. - Attributes ``carrier_freq`` and ``phase`` here correspond to an ``Array`` of + Attributes ``carrier_freq`` and ``phase`` here correspond to an ``ArrayLike`` of frequencies/phases for each term in the sum, and the ``envelope`` method returns an - ``Array`` of the envelopes for each summand. + ``ArrayLike`` of the envelopes for each summand. Internally, the signals are stored as a list in the ``components`` attribute, which can be accessed via direct subscripting of the object. @@ -572,28 +546,20 @@ def __init__(self, *signals, name: Optional[str] = None): components += sig.components elif isinstance(sig, Signal): components.append(sig) - elif isinstance(sig, (int, float, complex)) or ( - isinstance(sig, Array) and sig.ndim == 0 - ): - components.append(Signal(sig)) else: - raise QiskitError( - "Components of a SignalSum must be instances of a Signal subclass." - ) + try: + if unp.asarray(sig).ndim == 0: + components.append(Signal(sig)) + except QiskitError as qe: + raise QiskitError( + "Components of a SignalSum must be instances " + "of a Signal subclass or a scalar." + ) from qe SignalCollection.__init__(self, components) - # set up routine for evaluating envelopes if jax - if Array.default_backend() == "jax": - jax_arraylist_eval = array_funclist_evaluate([sig.envelope for sig in self.components]) - - def envelope(t): - return np.moveaxis(jax_arraylist_eval(t), 0, -1) - - else: - - def envelope(t): - return np.moveaxis([sig.envelope(t) for sig in self.components], 0, -1) + def envelope(t): + return unp.moveaxis(unp.asarray([sig.envelope(t) for sig in self.components]), 0, -1) carrier_freqs = [] for sig in self.components: @@ -607,12 +573,10 @@ def envelope(t): self, envelope=envelope, carrier_freq=carrier_freqs, phase=phases, name=name ) - def complex_value(self, t: Union[float, np.array, Array]) -> Union[complex, np.array, Array]: + def complex_value(self, t: ArrayLike) -> ArrayLike: """Return the sum of the complex values of each component.""" - if Array.default_backend() == "jax": - t = Array(t) - exp_phases = np.exp(np.expand_dims(t, -1) * self._carrier_arg + self._phase_arg) - return np.sum(self.envelope(t) * exp_phases, axis=-1) + exp_phases = unp.exp(unp.expand_dims(t, -1) * self._carrier_arg + self._phase_arg) + return unp.sum(self.envelope(t) * exp_phases, axis=-1) def __str__(self): if self.name is not None: @@ -637,12 +601,12 @@ def flatten(self) -> Signal: elif len(self) == 1: return self.components[0] - ave_freq = np.sum(self.carrier_freq) / len(self) + ave_freq = unp.sum(self.carrier_freq) / len(self) shifted_arg = self._carrier_arg - (1j * 2 * np.pi * ave_freq) def merged_env(t): - exp_phases = np.exp(np.expand_dims(Array(t), -1) * shifted_arg + self._phase_arg) - return np.sum(self.envelope(t) * exp_phases, axis=-1) + exp_phases = unp.exp(unp.expand_dims(t, -1) * shifted_arg + self._phase_arg) + return unp.sum(self.envelope(t) * exp_phases, axis=-1) return Signal(envelope=merged_env, carrier_freq=ave_freq, name=str(self)) @@ -655,10 +619,10 @@ class DiscreteSignalSum(DiscreteSignal, SignalSum): def __init__( self, dt: float, - samples: Union[List, Array], + samples: ArrayLike, start_time: float = 0.0, - carrier_freq: Union[List, np.array, Array] = None, - phase: Union[List, np.array, Array] = None, + carrier_freq: ArrayLike = None, + phase: ArrayLike = None, name: str = None, ): r"""Directly initialize a ``DiscreteSignalSum``\. Samples of all terms in the @@ -675,7 +639,7 @@ def __init__( name: name of the signal. """ - samples = Array(samples) + samples = unp.asarray(samples) if carrier_freq is None: carrier_freq = np.zeros(samples.shape[-1], dtype=float) @@ -750,7 +714,7 @@ def from_SignalSum( if sample_carrier: freq = 0.0 * freq - exp_phases = np.exp(np.expand_dims(Array(times), -1) * signal_sum._carrier_arg) + exp_phases = unp.exp(unp.expand_dims(times, -1) * signal_sum._carrier_arg) samples = signal_sum.envelope(times) * exp_phases else: samples = signal_sum.envelope(times) @@ -778,7 +742,7 @@ def __str__(self): return default_str - def __getitem__(self, idx: Union[int, List, np.array, slice]) -> Signal: + def __getitem__(self, idx: ArrayLike) -> Signal: """Enables numpy-style subscripting, as if this class were a 1d array.""" if type(idx) == int and idx >= len(self): @@ -789,13 +753,13 @@ def __getitem__(self, idx: Union[int, List, np.array, slice]) -> Signal: phases = self.phase[idx] if samples.ndim == 1: - samples = Array([samples]) + samples = unp.asarray([samples]) if carrier_freqs.ndim == 0: - carrier_freqs = Array([carrier_freqs]) + carrier_freqs = unp.asarray([carrier_freqs]) if phases.ndim == 0: - phases = Array([phases]) + phases = unp.asarray([phases]) if len(samples) == 1: return DiscreteSignal( @@ -827,22 +791,18 @@ def __init__(self, signal_list: List[Signal]): super().__init__(signal_list) # setup complex value and full signal evaluation - if Array.default_backend() == "jax": - self._eval_complex_value = array_funclist_evaluate( - [sig.complex_value for sig in self.components] - ) - self._eval_signals = array_funclist_evaluate(self.components) - else: - self._eval_complex_value = lambda t: [sig.complex_value(t) for sig in self.components] - self._eval_signals = lambda t: [sig(t) for sig in self.components] + self._eval_complex_value = lambda t: unp.asarray( + [sig.complex_value(t) for sig in self.components] + ) + self._eval_signals = lambda t: unp.asarray([sig(t) for sig in self.components]) - def complex_value(self, t: Union[float, np.array, Array]) -> Union[np.array, Array]: + def complex_value(self, t: ArrayLike) -> ArrayLike: """Vectorized evaluation of complex value of components.""" - return np.moveaxis(self._eval_complex_value(t), 0, -1) + return unp.moveaxis(self._eval_complex_value(t), 0, -1) - def __call__(self, t: Union[float, np.array, Array]) -> Union[np.array, Array]: + def __call__(self, t: ArrayLike) -> ArrayLike: """Vectorized evaluation of all components.""" - return np.moveaxis(self._eval_signals(t), 0, -1) + return unp.moveaxis(self._eval_signals(t), 0, -1) def flatten(self) -> "SignalList": """Return a ``SignalList`` with each component flattened.""" @@ -856,8 +816,8 @@ def flatten(self) -> "SignalList": return SignalList(flattened_list) @property - def drift(self) -> Array: - r"""Return the drift ``Array``\, i.e. return an ``Array`` whose entries are the sum + def drift(self) -> ArrayLike: + r"""Return the drift ``ArrayLike``\, i.e. return an ``ArrayLike`` whose entries are the sum of the constant parts of the corresponding component of this ``SignalList``\. """ @@ -870,11 +830,11 @@ def drift(self) -> Array: for term in sig_entry: if term.is_constant: - val += Array(term(0.0)).data + val += term(0.0) drift_array.append(val) - return Array(drift_array) + return unp.asarray(drift_array) def signal_add(sig1: Signal, sig2: Signal) -> SignalSum: @@ -895,9 +855,11 @@ def signal_add(sig1: Signal, sig2: Signal) -> SignalSum: and sig1.start_time == sig2.start_time and sig1.duration == sig2.duration ): - samples = np.append(sig1.samples, sig2.samples, axis=1) - carrier_freq = np.append(sig1.carrier_freq, sig2.carrier_freq) - phase = np.append(sig1.phase, sig2.phase) + samples = _numpy_multi_dispatch(sig1.samples, sig2.samples, axis=1, path="append") + carrier_freq = _numpy_multi_dispatch( + sig1.carrier_freq, sig2.carrier_freq, path="append" + ) + phase = _numpy_multi_dispatch(sig1.phase, sig2.phase, path="append") return DiscreteSignalSum( dt=sig1.dt, samples=samples, @@ -951,29 +913,33 @@ def signal_multiply(sig1: Signal, sig2: Signal) -> SignalSum: ): # this vectorized operation produces a 2d array whose columns are the products of # the original columns - new_samples = Array( + new_samples = unp.asarray( 0.5 * (sig1.samples[:, :, None] * sig2.samples[:, None, :]).reshape( (sig1.samples.shape[0], sig1.samples.shape[1] * sig2.samples.shape[1]), order="C", ) ) - new_samples_conj = Array( + new_samples_conj = unp.asarray( 0.5 * (sig1.samples[:, :, None] * sig2.samples[:, None, :].conj()).reshape( (sig1.samples.shape[0], sig1.samples.shape[1] * sig2.samples.shape[1]), order="C", ) ) - samples = np.append(new_samples, new_samples_conj, axis=1) + samples = _numpy_multi_dispatch(new_samples, new_samples_conj, axis=1, path="append") new_freqs = sig1.carrier_freq + sig2.carrier_freq new_freqs_conj = sig1.carrier_freq - sig2.carrier_freq - freqs = np.append(Array(new_freqs), Array(new_freqs_conj)) + freqs = _numpy_multi_dispatch( + unp.asarray(new_freqs), unp.asarray(new_freqs_conj), path="append" + ) new_phases = sig1.phase + sig2.phase new_phases_conj = sig1.phase - sig2.phase - phases = np.append(Array(new_phases), Array(new_phases_conj)) + phases = _numpy_multi_dispatch( + unp.asarray(new_phases), unp.asarray(new_phases_conj), path="append" + ) return DiscreteSignalSum( dt=sig1.dt, @@ -1059,7 +1025,7 @@ def new_env(t): ) pwc2 = DiscreteSignal( dt=sig2.dt, - samples=0.5 * sig1.samples * np.conjugate(sig2.samples), + samples=0.5 * sig1.samples * unp.conjugate(sig2.samples), start_time=sig2.start_time, carrier_freq=sig1.carrier_freq - sig2.carrier_freq, phase=sig1.phase - sig2.phase, @@ -1071,7 +1037,7 @@ def new_env1(t): return 0.5 * sig1.envelope(t) * sig2.envelope(t) def new_env2(t): - return 0.5 * sig1.envelope(t) * np.conjugate(sig2.envelope(t)) + return 0.5 * sig1.envelope(t) * unp.conjugate(sig2.envelope(t)) prod1 = Signal( envelope=new_env1, @@ -1116,7 +1082,7 @@ def sort_signals(sig1: Signal, sig2: Signal) -> Tuple[Signal, Signal]: return sig1, sig2 -def to_SignalSum(sig: Union[int, float, complex, Array, Signal]) -> SignalSum: +def to_SignalSum(sig: Union[ArrayLike, Signal]) -> SignalSum: r"""Convert the input to a SignalSum according to: - If it is already a ``SignalSum``\, do nothing. @@ -1134,19 +1100,21 @@ def to_SignalSum(sig: Union[int, float, complex, Array, Signal]) -> SignalSum: QiskitError: If the input type is incompatible with SignalSum. """ - if isinstance(sig, (int, float, complex)) or (isinstance(sig, Array) and sig.ndim == 0): + if isinstance(sig, (int, float, complex)) or ( + not isinstance(sig, (int, float, complex, list, Signal)) and sig.ndim == 0 + ): return SignalSum(Signal(sig)) elif isinstance(sig, DiscreteSignal) and not isinstance(sig, DiscreteSignalSum): - if Array(sig.samples.data).shape == (0,): - new_samples = Array([sig.samples.data]) + if sig.samples.shape == (0,): + new_samples = unp.asarray([sig.samples]) else: - new_samples = Array([sig.samples.data]).transpose(1, 0) + new_samples = unp.asarray([sig.samples]).transpose(1, 0) return DiscreteSignalSum( dt=sig.dt, samples=new_samples, start_time=sig.start_time, - carrier_freq=Array([sig.carrier_freq.data]), - phase=Array([sig.phase.data]), + carrier_freq=unp.asarray([sig.carrier_freq]), + phase=unp.asarray([sig.phase]), ) elif isinstance(sig, Signal) and not isinstance(sig, SignalSum): return SignalSum(sig) @@ -1154,14 +1122,3 @@ def to_SignalSum(sig: Union[int, float, complex, Array, Signal]) -> SignalSum: return sig raise QiskitError("Input type incompatible with SignalSum.") - - -def array_funclist_evaluate(func_list: List[Callable]) -> Callable: - """Utility for evaluating a list of functions in a way that respects Arrays. - Currently relevant for JAX evaluation. - """ - - def eval_func(t): - return Array([Array(func(t)).data for func in func_list]) - - return eval_func diff --git a/qiskit_dynamics/signals/transfer_functions.py b/qiskit_dynamics/signals/transfer_functions.py index f9b137a5a..da9188b65 100644 --- a/qiskit_dynamics/signals/transfer_functions.py +++ b/qiskit_dynamics/signals/transfer_functions.py @@ -22,7 +22,8 @@ import numpy as np from qiskit import QiskitError -from qiskit_dynamics.array import Array +from qiskit_dynamics.arraylias import DYNAMICS_NUMPY as unp +from qiskit_dynamics.arraylias.alias import _numpy_multi_dispatch from .signals import Signal, DiscreteSignal @@ -117,11 +118,11 @@ def _apply(self, signal: Signal) -> Signal: if isinstance(signal, DiscreteSignal): # Perform a discrete time convolution. dt = signal.dt - func_samples = Array([self._func(dt * i) for i in range(signal.duration)]) - func_samples = func_samples / sum(func_samples) - sig_samples = signal(dt * np.arange(signal.duration)) + func_samples = unp.asarray([self._func(dt * i) for i in range(signal.duration)]) + func_samples = func_samples / unp.sum(func_samples) + sig_samples = signal(dt * unp.arange(signal.duration)) - convoluted_samples = list(np.convolve(func_samples, sig_samples)) + convoluted_samples = _numpy_multi_dispatch(func_samples, sig_samples, path="convolve") return DiscreteSignal(dt, convoluted_samples, carrier_freq=0.0, phase=0.0) else: @@ -233,8 +234,8 @@ def _apply(self, si: Signal, sq: Signal) -> Signal: def mixer_func(t): """Function of the IQ mixer.""" - osc_i = np.cos(wp * t + phi_i) + np.cos(wm * t + phi_i) - osc_q = np.cos(wp * t + phi_q - np.pi / 2) + np.cos(wm * t + phi_q + np.pi / 2) + osc_i = unp.cos(wp * t + phi_i) + unp.cos(wm * t + phi_i) + osc_q = unp.cos(wp * t + phi_q - np.pi / 2) + unp.cos(wm * t + phi_q + np.pi / 2) return si.envelope(t) * osc_i / 2 + sq.envelope(t) * osc_q / 2 return Signal(mixer_func, carrier_freq=0, phase=0) diff --git a/test/dynamics/models/test_generator_model.py b/test/dynamics/models/test_generator_model.py index b387ca311..5d1b6bb93 100644 --- a/test/dynamics/models/test_generator_model.py +++ b/test/dynamics/models/test_generator_model.py @@ -702,7 +702,7 @@ def test_jit_grad(self): def func(a): model_copy = model.copy() - model_copy.signals = [Signal(Array(a))] + model_copy.signals = [Signal(a)] return model_copy(0.232, y) jitted_func = self.jit_wrap(func) diff --git a/test/dynamics/signals/test_signals.py b/test/dynamics/signals/test_signals.py index cd1b107a2..c291dd4fd 100644 --- a/test/dynamics/signals/test_signals.py +++ b/test/dynamics/signals/test_signals.py @@ -15,26 +15,32 @@ Tests for signals. """ +from functools import partial + import numpy as np from qiskit_dynamics.signals import Signal, DiscreteSignal, DiscreteSignalSum, SignalList from qiskit_dynamics.signals.signals import to_SignalSum -from qiskit_dynamics.array import Array +from qiskit_dynamics.arraylias import DYNAMICS_NUMPY as unp -from ..common import QiskitDynamicsTestCase, TestJaxBase +from ..common import test_array_backends try: from jax import jit, grad import jax.numpy as jnp except ImportError: pass +# Classes that don't explicitly inherit QiskitDynamicsTestCase get no-member errors +# pylint: disable=no-member -class TestSignal(QiskitDynamicsTestCase): +@partial(test_array_backends, array_libraries=["numpy", "jax", "array_numpy", "array_jax"]) +class TestSignal: """Tests for Signal object.""" def setUp(self): - self.signal1 = Signal(lambda t: 0.25, carrier_freq=0.3) + """Setup Signals""" + self.signal1 = Signal(lambda _: 0.25, carrier_freq=0.3) self.signal2 = Signal(lambda t: 2.0 * (t**2), carrier_freq=0.1) self.signal3 = Signal(lambda t: 2.0 * (t**2) + 1j * t, carrier_freq=0.1, phase=-0.1) @@ -51,25 +57,27 @@ def test_envelope(self): def test_envelope_vectorized(self): """Test vectorized evaluation of envelope.""" - t_vals = np.array([1.1, 1.23]) - self.assertAllClose(self.signal1.envelope(t_vals), np.array([0.25, 0.25])) + t_vals = self.asarray([1.1, 1.23]) + self.assertAllClose(self.signal1.envelope(t_vals), self.asarray([0.25, 0.25])) self.assertAllClose( - self.signal2.envelope(t_vals), np.array([2 * (1.1**2), 2 * (1.23**2)]) + self.signal2.envelope(t_vals), self.asarray([2 * (1.1**2), 2 * (1.23**2)]) ) self.assertAllClose( self.signal3.envelope(t_vals), - np.array([2 * (1.1**2) + 1j * 1.1, 2 * (1.23**2) + 1j * 1.23]), + self.asarray([2 * (1.1**2) + 1j * 1.1, 2 * (1.23**2) + 1j * 1.23]), ) - t_vals = np.array([[1.1, 1.23], [0.1, 0.24]]) - self.assertAllClose(self.signal1.envelope(t_vals), np.array([[0.25, 0.25], [0.25, 0.25]])) + t_vals = self.asarray([[1.1, 1.23], [0.1, 0.24]]) + self.assertAllClose( + self.signal1.envelope(t_vals), self.asarray([[0.25, 0.25], [0.25, 0.25]]) + ) self.assertAllClose( self.signal2.envelope(t_vals), - np.array([[2 * (1.1**2), 2 * (1.23**2)], [2 * (0.1**2), 2 * (0.24**2)]]), + self.asarray([[2 * (1.1**2), 2 * (1.23**2)], [2 * (0.1**2), 2 * (0.24**2)]]), ) self.assertAllClose( self.signal3.envelope(t_vals), - np.array( + self.asarray( [ [2 * (1.1**2) + 1j * 1.1, 2 * (1.23**2) + 1j * 1.23], [2 * (0.1**2) + 1j * 0.1, 2 * (0.24**2) + 1j * 0.24], @@ -104,10 +112,10 @@ def test_complex_value(self): def test_complex_value_vectorized(self): """Test vectorized complex_value evaluation.""" - t_vals = np.array([1.1, 1.23]) + t_vals = self.asarray([1.1, 1.23]) self.assertAllClose( self.signal1.complex_value(t_vals), - np.array( + self.asarray( [ 0.25 * np.exp(1j * 2 * np.pi * 0.3 * 1.1), 0.25 * np.exp(1j * 2 * np.pi * 0.3 * 1.23), @@ -116,7 +124,7 @@ def test_complex_value_vectorized(self): ) self.assertAllClose( self.signal2.complex_value(t_vals), - np.array( + self.asarray( [ 2 * (1.1**2) * np.exp(1j * 2 * np.pi * 0.1 * 1.1), 2 * (1.23**2) * np.exp(1j * 2 * np.pi * 0.1 * 1.23), @@ -125,7 +133,7 @@ def test_complex_value_vectorized(self): ) self.assertAllClose( self.signal3.complex_value(t_vals), - np.array( + self.asarray( [ (2 * (1.1**2) + 1j * 1.1) * np.exp(1j * 2 * np.pi * 0.1 * 1.1 + 1j * (-0.1)), (2 * (1.23**2) + 1j * 1.23) @@ -134,10 +142,10 @@ def test_complex_value_vectorized(self): ), ) - t_vals = np.array([[1.1, 1.23], [0.1, 0.24]]) + t_vals = self.asarray([[1.1, 1.23], [0.1, 0.24]]) self.assertAllClose( self.signal1.complex_value(t_vals), - np.array( + self.asarray( [ [ 0.25 * np.exp(1j * 2 * np.pi * 0.3 * 1.1), @@ -152,7 +160,7 @@ def test_complex_value_vectorized(self): ) self.assertAllClose( self.signal2.complex_value(t_vals), - np.array( + self.asarray( [ [ 2 * (1.1**2) * np.exp(1j * 2 * np.pi * 0.1 * 1.1), @@ -167,7 +175,7 @@ def test_complex_value_vectorized(self): ) self.assertAllClose( self.signal3.complex_value(t_vals), - np.array( + self.asarray( [ [ (2 * (1.1**2) + 1j * 1.1) @@ -210,10 +218,10 @@ def test_call(self): def test_call_vectorized(self): """Test vectorized __call__.""" - t_vals = np.array([1.1, 1.23]) + t_vals = self.asarray([1.1, 1.23]) self.assertAllClose( self.signal1(t_vals), - np.array( + self.asarray( [ 0.25 * np.exp(1j * 2 * np.pi * 0.3 * 1.1), 0.25 * np.exp(1j * 2 * np.pi * 0.3 * 1.23), @@ -222,7 +230,7 @@ def test_call_vectorized(self): ) self.assertAllClose( self.signal2(t_vals), - np.array( + self.asarray( [ 2 * (1.1**2) * np.exp(1j * 2 * np.pi * 0.1 * 1.1), 2 * (1.23**2) * np.exp(1j * 2 * np.pi * 0.1 * 1.23), @@ -231,7 +239,7 @@ def test_call_vectorized(self): ) self.assertAllClose( self.signal3(t_vals), - np.array( + self.asarray( [ (2 * (1.1**2) + 1j * 1.1) * np.exp(1j * 2 * np.pi * 0.1 * 1.1 + 1j * (-0.1)), (2 * (1.23**2) + 1j * 1.23) @@ -240,10 +248,10 @@ def test_call_vectorized(self): ).real, ) - t_vals = np.array([[1.1, 1.23], [0.1, 0.24]]) + t_vals = self.asarray([[1.1, 1.23], [0.1, 0.24]]) self.assertAllClose( self.signal1(t_vals), - np.array( + self.asarray( [ [ 0.25 * np.exp(1j * 2 * np.pi * 0.3 * 1.1), @@ -258,7 +266,7 @@ def test_call_vectorized(self): ) self.assertAllClose( self.signal2(t_vals), - np.array( + self.asarray( [ [ 2 * (1.1**2) * np.exp(1j * 2 * np.pi * 0.1 * 1.1), @@ -273,7 +281,7 @@ def test_call_vectorized(self): ) self.assertAllClose( self.signal3(t_vals), - np.array( + self.asarray( [ [ (2 * (1.1**2) + 1j * 1.1) @@ -307,10 +315,12 @@ def test_conjugate(self): ) -class TestConstant(QiskitDynamicsTestCase): +@partial(test_array_backends, array_libraries=["numpy", "jax", "array_numpy", "array_jax"]) +class TestConstant: """Tests for constant signal object.""" def setUp(self): + """Setup constant Signals""" self.constant1 = Signal(1.0) self.constant2 = Signal(3.0 + 1j * 2) @@ -324,12 +334,12 @@ def test_envelope(self): def test_envelope_vectorized(self): """Test vectorized evaluation of envelope.""" - t_vals = np.array([1.1, 1.23]) - self.assertAllClose(self.constant1.envelope(t_vals), np.array([1.0, 1.0])) + t_vals = self.asarray([1.1, 1.23]) + self.assertAllClose(self.constant1.envelope(t_vals), self.asarray([1.0, 1.0])) self.assertAllClose(self.constant2.envelope(t_vals), (3.0 + 2j) * np.ones_like(t_vals)) - t_vals = np.array([[1.1, 1.23], [0.1, 0.24]]) - self.assertAllClose(self.constant1.envelope(t_vals), np.array([[1.0, 1.0], [1.0, 1.0]])) + t_vals = self.asarray([[1.1, 1.23], [0.1, 0.24]]) + self.assertAllClose(self.constant1.envelope(t_vals), self.asarray([[1.0, 1.0], [1.0, 1.0]])) self.assertAllClose(self.constant2.envelope(t_vals), (3.0 + 2j) * np.ones_like(t_vals)) def test_complex_value(self): @@ -342,13 +352,13 @@ def test_complex_value(self): def test_complex_value_vectorized(self): """Test vectorized complex_value evaluation.""" - t_vals = np.array([1.1, 1.23]) - self.assertAllClose(self.constant1.complex_value(t_vals), np.array([1.0, 1.0])) + t_vals = self.asarray([1.1, 1.23]) + self.assertAllClose(self.constant1.complex_value(t_vals), self.asarray([1.0, 1.0])) self.assertAllClose(self.constant2.complex_value(t_vals), (3.0 + 2j) * np.ones_like(t_vals)) - t_vals = np.array([[1.1, 1.23], [0.1, 0.24]]) + t_vals = self.asarray([[1.1, 1.23], [0.1, 0.24]]) self.assertAllClose( - self.constant1.complex_value(t_vals), np.array([[1.0, 1.0], [1.0, 1.0]]) + self.constant1.complex_value(t_vals), self.asarray([[1.0, 1.0], [1.0, 1.0]]) ) self.assertAllClose(self.constant2.complex_value(t_vals), (3.0 + 2j) * np.ones_like(t_vals)) @@ -362,13 +372,13 @@ def test_call(self): def test_call_vectorized(self): """Test vectorized __call__.""" - t_vals = np.array([1.1, 1.23]) - self.assertAllClose(self.constant1(t_vals), np.array([1.0, 1.0])) - self.assertAllClose(self.constant2(t_vals), np.array([3.0, 3.0])) + t_vals = self.asarray([1.1, 1.23]) + self.assertAllClose(self.constant1(t_vals), self.asarray([1.0, 1.0])) + self.assertAllClose(self.constant2(t_vals), self.asarray([3.0, 3.0])) - t_vals = np.array([[1.1, 1.23], [0.1, 0.24]]) - self.assertAllClose(self.constant1(t_vals), np.array([[1.0, 1.0], [1.0, 1.0]])) - self.assertAllClose(self.constant2(t_vals), np.array([[3.0, 3.0], [3.0, 3.0]])) + t_vals = self.asarray([[1.1, 1.23], [0.1, 0.24]]) + self.assertAllClose(self.constant1(t_vals), self.asarray([[1.0, 1.0], [1.0, 1.0]])) + self.assertAllClose(self.constant2(t_vals), self.asarray([[3.0, 3.0], [3.0, 3.0]])) def test_conjugate(self): """Verify conjugate() functioning correctly.""" @@ -377,13 +387,17 @@ def test_conjugate(self): self.assertAllClose(const_conj(1.1), 3.0) -class TestDiscreteSignal(QiskitDynamicsTestCase): +@partial(test_array_backends, array_libraries=["numpy", "jax", "array_numpy", "array_jax"]) +class TestDiscreteSignal: """Tests for DiscreteSignal object.""" def setUp(self): - self.discrete1 = DiscreteSignal(dt=0.5, samples=np.array([1.0, 2.0, 3.0]), carrier_freq=3.0) + """Setup DiscreteSignals""" + self.discrete1 = DiscreteSignal( + dt=0.5, samples=self.asarray([1.0, 2.0, 3.0]), carrier_freq=3.0 + ) self.discrete2 = DiscreteSignal( - dt=0.5, samples=np.array([1.0 + 2j, 2.0 + 1j, 3.0]), carrier_freq=1.0, phase=3.0 + dt=0.5, samples=self.asarray([1.0 + 2j, 2.0 + 1j, 3.0]), carrier_freq=1.0, phase=3.0 ) def test_envelope(self): @@ -405,14 +419,14 @@ def test_envelope_outside(self): def test_envelope_vectorized(self): """Test vectorized evaluation of envelope.""" - t_vals = np.array([0.1, 1.23]) - self.assertAllClose(self.discrete1.envelope(t_vals), np.array([1.0, 3.0])) - self.assertAllClose(self.discrete2.envelope(t_vals), np.array([1.0 + 2j, 3.0])) + t_vals = self.asarray([0.1, 1.23]) + self.assertAllClose(self.discrete1.envelope(t_vals), self.asarray([1.0, 3.0])) + self.assertAllClose(self.discrete2.envelope(t_vals), self.asarray([1.0 + 2j, 3.0])) - t_vals = np.array([[0.8, 1.23], [0.1, 0.24]]) - self.assertAllClose(self.discrete1.envelope(t_vals), np.array([[2.0, 3.0], [1.0, 1.0]])) + t_vals = self.asarray([[0.8, 1.23], [0.1, 0.24]]) + self.assertAllClose(self.discrete1.envelope(t_vals), self.asarray([[2.0, 3.0], [1.0, 1.0]])) self.assertAllClose( - self.discrete2.envelope(t_vals), np.array([[2.0 + 1j, 3.0], [1 + 2j, 1.0 + 2j]]) + self.discrete2.envelope(t_vals), self.asarray([[2.0 + 1j, 3.0], [1 + 2j, 1.0 + 2j]]) ) def test_complex_value(self): @@ -432,23 +446,23 @@ def test_complex_value(self): def test_complex_value_vectorized(self): """Test vectorized complex_value evaluation.""" - t_vals = np.array([0.1, 1.23]) + t_vals = self.asarray([0.1, 1.23]) phases = np.exp(1j * 2 * np.pi * 3.0 * t_vals) - self.assertAllClose(self.discrete1.complex_value(t_vals), np.array([1.0, 3.0]) * phases) + self.assertAllClose(self.discrete1.complex_value(t_vals), self.asarray([1.0, 3.0]) * phases) phases = np.exp(1j * 2 * np.pi * 1.0 * t_vals + 1j * 3.0) self.assertAllClose( - self.discrete2.complex_value(t_vals), np.array([1.0 + 2j, 3.0]) * phases + self.discrete2.complex_value(t_vals), self.asarray([1.0 + 2j, 3.0]) * phases ) - t_vals = np.array([[0.8, 1.23], [0.1, 0.24]]) + t_vals = self.asarray([[0.8, 1.23], [0.1, 0.24]]) phases = np.exp(1j * 2 * np.pi * 3.0 * t_vals) self.assertAllClose( - self.discrete1.complex_value(t_vals), np.array([[2.0, 3.0], [1.0, 1.0]]) * phases + self.discrete1.complex_value(t_vals), self.asarray([[2.0, 3.0], [1.0, 1.0]]) * phases ) phases = np.exp(1j * 2 * np.pi * 1.0 * t_vals + 1j * 3.0) self.assertAllClose( self.discrete2.complex_value(t_vals), - np.array([[2.0 + 1j, 3.0], [1 + 2j, 1.0 + 2j]]) * phases, + self.asarray([[2.0 + 1j, 3.0], [1 + 2j, 1.0 + 2j]]) * phases, ) def test_call(self): @@ -467,21 +481,21 @@ def test_call(self): def test_call_vectorized(self): """Test vectorized __call__.""" - t_vals = np.array([0.1, 1.23]) + t_vals = self.asarray([0.1, 1.23]) phases = np.exp(1j * 2 * np.pi * 3.0 * t_vals) - self.assertAllClose(self.discrete1(t_vals), np.real(np.array([1.0, 3.0]) * phases)) + self.assertAllClose(self.discrete1(t_vals), np.real(self.asarray([1.0, 3.0]) * phases)) phases = np.exp(1j * 2 * np.pi * 1.0 * t_vals + 1j * 3.0) - self.assertAllClose(self.discrete2(t_vals), np.real(np.array([1.0 + 2j, 3.0]) * phases)) + self.assertAllClose(self.discrete2(t_vals), np.real(self.asarray([1.0 + 2j, 3.0]) * phases)) - t_vals = np.array([[0.8, 1.23], [0.1, 0.24]]) + t_vals = self.asarray([[0.8, 1.23], [0.1, 0.24]]) phases = np.exp(1j * 2 * np.pi * 3.0 * t_vals) self.assertAllClose( - self.discrete1(t_vals), np.real(np.array([[2.0, 3.0], [1.0, 1.0]]) * phases) + self.discrete1(t_vals), np.real(self.asarray([[2.0, 3.0], [1.0, 1.0]]) * phases) ) phases = np.exp(1j * 2 * np.pi * 1.0 * t_vals + 1j * 3.0) self.assertAllClose( self.discrete2(t_vals), - np.real(np.array([[2.0 + 1j, 3.0], [1 + 2j, 1.0 + 2j]]) * phases), + np.real(self.asarray([[2.0 + 1j, 3.0], [1 + 2j, 1.0 + 2j]]) * phases), ) def test_conjugate(self): @@ -495,9 +509,9 @@ def test_conjugate(self): def test_add_samples(self): """Verify that add_samples function works correctly""" - discrete1 = DiscreteSignal(dt=0.5, samples=np.array([]), carrier_freq=3.0) + discrete1 = DiscreteSignal(dt=0.5, samples=self.asarray([]), carrier_freq=3.0) discrete2 = DiscreteSignal( - dt=0.5, samples=np.array([1.0 + 2j, 2.0 + 1j, 3.0]), carrier_freq=1.0, phase=3.0 + dt=0.5, samples=self.asarray([1.0 + 2j, 2.0 + 1j, 3.0]), carrier_freq=1.0, phase=3.0 ) discrete3 = DiscreteSignal(dt=0.5, samples=[], carrier_freq=3.0) @@ -511,10 +525,11 @@ def test_add_samples(self): self.assertAllClose(discrete3.samples, [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0]) -class TestSignalSum(QiskitDynamicsTestCase): +class TestSignalSum: """Test evaluation functions for ``SignalSum``.""" def setUp(self): + """Setup SignalSums""" self.signal1 = Signal(np.vectorize(lambda t: 0.25), carrier_freq=0.3) self.signal2 = Signal(lambda t: 2.0 * (t**2), carrier_freq=0.1) self.signal3 = Signal(lambda t: 2.0 * (t**2) + 1j * t, carrier_freq=0.1, phase=-0.1) @@ -560,7 +575,7 @@ def test_envelope(self): def test_envelope_vectorized(self): """Test vectorized envelope evaluation.""" - t_vals = np.array([0.0, 1.23]) + t_vals = self.asarray([0.0, 1.23]) self.assertAllClose( self.sig_sum1.envelope(t_vals), [[self.signal1.envelope(t), self.signal2.envelope(t)] for t in t_vals], @@ -581,7 +596,7 @@ def test_envelope_vectorized(self): for t in t_vals ], ) - t_vals = np.array([[0.0, 1.23], [0.1, 2.0]]) + t_vals = self.asarray([[0.0, 1.23], [0.1, 2.0]]) self.assertAllClose( self.sig_sum1.envelope(t_vals), [ @@ -643,7 +658,7 @@ def test_complex_value(self): def test_complex_value_vectorized(self): """Test vectorized complex_value evaluation.""" - t_vals = np.array([0.0, 1.23]) + t_vals = self.asarray([0.0, 1.23]) self.assertAllClose( self.sig_sum1.complex_value(t_vals), [self.signal1.complex_value(t) + self.signal2.complex_value(t) for t in t_vals], @@ -656,7 +671,7 @@ def test_complex_value_vectorized(self): self.double_sig_sum.complex_value(t_vals), [self.signal1.complex_value(t) + self.signal3.complex_value(t) for t in t_vals], ) - t_vals = np.array([[0.0, 1.23], [0.1, 2.0]]) + t_vals = self.asarray([[0.0, 1.23], [0.1, 2.0]]) self.assertAllClose( self.sig_sum1.complex_value(t_vals), [ @@ -692,7 +707,7 @@ def test_call(self): def test_call_vectorized(self): """Test vectorized __call__.""" - t_vals = np.array([0.0, 1.23]) + t_vals = self.asarray([0.0, 1.23]) self.assertAllClose( self.sig_sum1(t_vals), [self.signal1(t) + self.signal2(t) for t in t_vals] ) @@ -702,7 +717,7 @@ def test_call_vectorized(self): self.assertAllClose( self.double_sig_sum(t_vals), [self.signal1(t) + self.signal3(t) for t in t_vals] ) - t_vals = np.array([[0.0, 1.23], [0.1, 2.0]]) + t_vals = self.asarray([[0.0, 1.23], [0.1, 2.0]]) self.assertAllClose( self.sig_sum1(t_vals), [[self.signal1(t) + self.signal2(t) for t in t_row] for t_row in t_vals], @@ -732,6 +747,7 @@ class TestDiscreteSignalSum(TestSignalSum): """Tests for DiscreteSignalSum.""" def setUp(self): + """Setup DiscreteSignalSums""" self.signal1 = Signal(np.vectorize(lambda t: 0.25), carrier_freq=0.3) self.signal2 = Signal(lambda t: 2.0 * (t**2), carrier_freq=0.1) self.signal3 = Signal(lambda t: 2.0 * (t**2) + 1j * t, carrier_freq=0.1, phase=-0.1) @@ -757,10 +773,18 @@ def test_empty_DiscreteSignal_to_sum(self): self.assertTrue(empty_sum.samples.shape == (1, 0)) -class TestSignalList(QiskitDynamicsTestCase): +test_array_backends(TestSignalSum, array_libraries=["numpy", "jax", "array_numpy", "array_jax"]) +test_array_backends( + TestDiscreteSignalSum, array_libraries=["numpy", "jax", "array_numpy", "array_jax"] +) + + +@partial(test_array_backends, array_libraries=["numpy", "jax", "array_numpy", "array_jax"]) +class TestSignalList: """Test cases for SignalList class.""" def setUp(self): + """Setup a SignalList""" self.sig = Signal(lambda t: t, carrier_freq=3.0) self.const = Signal(5.0) self.discrete_sig = DiscreteSignal( @@ -774,9 +798,9 @@ def setUp(self): def test_eval(self): """Test evaluation of signal sum.""" - t_vals = np.array([0.12, 0.23, 1.23]) + t_vals = self.asarray([0.12, 0.23, 1.23]) - expected = np.array( + expected = self.asarray( [ self.sig(t_vals) + self.const(t_vals), self.sig(t_vals) * self.discrete_sig(t_vals), @@ -789,9 +813,9 @@ def test_eval(self): def test_complex_value(self): """Test evaluation of signal sum.""" - t_vals = np.array([0.12, 0.23, 1.23]) + t_vals = self.asarray([0.12, 0.23, 1.23]) - expected = np.array( + expected = self.asarray( [ self.sig.complex_value(t_vals) + self.const.complex_value(t_vals), np.real(self.sig.complex_value(t_vals)) * self.discrete_sig.complex_value(t_vals), @@ -804,27 +828,26 @@ def test_complex_value(self): def test_drift(self): """Test drift evaluation.""" - expected = np.array([self.const(0.0), 0, self.const(0.0)]) + expected = self.asarray([self.const(0.0), 0, self.const(0.0)]) self.assertAllClose(self.sig_list.drift, expected) def test_construction_with_numbers(self): """Test construction with non-wrapped constant values.""" sig_list = SignalList([4.0, 2.0, Signal(lambda t: t)]) - # pylint: disable=no-member self.assertTrue(sig_list[0][0].is_constant) - # pylint: disable=no-member self.assertTrue(sig_list[1][0].is_constant) - # pylint: disable=no-member self.assertFalse(sig_list[2][0].is_constant) - self.assertAllClose(sig_list(3.0), np.array([4.0, 2.0, 3.0])) + self.assertAllClose(sig_list(3.0), self.asarray([4.0, 2.0, 3.0])) -class TestSignalCollection(QiskitDynamicsTestCase): +@partial(test_array_backends, array_libraries=["numpy", "jax", "array_numpy", "array_jax"]) +class TestSignalCollection: """Test cases for SignalCollection functionality.""" def setUp(self): + """Setup SignalCollections""" self.sig1 = Signal(lambda t: t, carrier_freq=0.1) self.sig2 = Signal(lambda t: t + 1j * t**2, carrier_freq=3.0, phase=1.0) self.sig3 = Signal(lambda t: t + 1j * t**2, carrier_freq=3.0, phase=1.2) @@ -843,7 +866,7 @@ def test_SignalSum_subscript(self): sub02 = self.sig_sum[[0, 2]] self.assertTrue(len(sub02) == 2) - t_vals = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + t_vals = self.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) self.assertAllClose(sub02(t_vals), self.sig1(t_vals) + self.sig3(t_vals)) def test_DiscreteSignalSum_subscript(self): @@ -851,7 +874,7 @@ def test_DiscreteSignalSum_subscript(self): sub02 = self.discrete_sig_sum[[0, 2]] self.assertTrue(len(sub02) == 2) - t_vals = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) / 4.0 + t_vals = self.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) / 4.0 self.assertAllClose(sub02(t_vals), self.discrete_sig1(t_vals) + self.discrete_sig3(t_vals)) def test_SignalSum_iterator(self): @@ -873,34 +896,12 @@ def test_DiscreteSignalSum_iterator(self): self.assertAllClose(sum_val, self.discrete_sig_sum(3.0)) -class TestSignalJax(TestSignal, TestJaxBase): - """Jax version of TestSignal.""" - - -class TestConstantJax(TestSignal, TestJaxBase): - """Jax version of TestConstant.""" - - -class TestDiscreteSignalJax(TestDiscreteSignal, TestJaxBase): - """Jax version of TestDiscreteSignal.""" - - -class TestSignalSumJax(TestSignalSum, TestJaxBase): - """Jax version of TestSignalSum.""" - - -class TestDiscreteSignalSumJax(TestDiscreteSignalSum, TestJaxBase): - """Jax version of TestSignalSum.""" - - -class TestSignalListJax(TestSignalList, TestJaxBase): - """Jax version of TestSignalList.""" - - -class TestSignalsJaxTransformations(QiskitDynamicsTestCase, TestJaxBase): +@partial(test_array_backends, array_libraries=["jax"]) +class TestSignalsJaxTransformations: """Test cases for jax transformations of signals.""" def setUp(self): + """Setup Signals""" self.signal = Signal(lambda t: t**2, carrier_freq=3.0) self.constant = Signal(3 * np.pi) self.discrete_signal = DiscreteSignal( @@ -924,8 +925,8 @@ def test_jit_grad_constant_construct(self): """Test jitting and grad through a function which constructs a constant signal.""" def eval_const(a): - a = Array(a) - return Signal(a)(1.1).data + a = unp.asarray(a) + return Signal(a)(1.1) jit_eval = jit(eval_const) self.assertAllClose(jit_eval(3.0), 3.0) @@ -935,7 +936,7 @@ def eval_const(a): # validate that is_constant is being properly set def eval_const_conditional(a): - a = Array(a) + a = unp.asarray(a) sig = Signal(a) if sig.is_constant: @@ -952,9 +953,9 @@ def test_jit_grad_carrier_freq_construct(self): """ def eval_sig(a, v, t): - a = Array(a) - v = Array(v) - return Array(Signal(a, v)(t)).data + a = unp.asarray(a) + v = unp.asarray(v) + return unp.asarray(Signal(a, v)(t)) jit_eval = jit(eval_sig) self.assertAllClose(jit_eval(1.0, 1.0, 1.0), 1.0) @@ -964,9 +965,9 @@ def eval_sig(a, v, t): def test_signal_list_jit_eval(self): """Test jit-compilation of SignalList evaluation.""" - call_jit = jit(lambda t: Array(self.signal_list(t)).data) + call_jit = jit(lambda t: unp.asarray(self.signal_list(t))) - t_vals = np.array([0.123, 0.5324, 1.232]) + t_vals = self.asarray([0.123, 0.5324, 1.232]) self.assertAllClose(call_jit(t_vals), self.signal_list(t_vals)) def test_jit_grad_eval(self): @@ -1022,20 +1023,22 @@ def test_jit_grad_eval(self): def _test_jit_signal_eval(self, signal, t=2.1): """jit compilation and evaluation of main signal functions.""" - sig_call_jit = jit(lambda t: Array(signal(t)).data) + sig_call_jit = jit(lambda t: self.asarray(signal(t))) self.assertAllClose(sig_call_jit(t), signal(t)) - sig_envelope_jit = jit(lambda t: Array(signal.envelope(t)).data) + sig_envelope_jit = jit(lambda t: unp.asarray(signal.envelope(t))) self.assertAllClose(sig_envelope_jit(t), signal.envelope(t)) - sig_complex_value_jit = jit(lambda t: Array(signal.complex_value(t)).data) + sig_complex_value_jit = jit(lambda t: unp.asarray(signal.complex_value(t))) self.assertAllClose(sig_complex_value_jit(t), signal.complex_value(t)) def _test_grad_eval(self, signal, t, sig_deriv_val, complex_deriv_val): """Test chained grad and jit compilation.""" - sig_call_jit = jit(grad(lambda t: Array(signal(t)).data)) + sig_call_jit = jit(grad(lambda t: unp.asarray(signal(t)))) self.assertAllClose(sig_call_jit(t), sig_deriv_val) - sig_complex_value_jit_re = jit(grad(lambda t: np.real(Array(signal.complex_value(t))).data)) + sig_complex_value_jit_re = jit( + grad(lambda t: unp.real(unp.asarray(signal.complex_value(t)))) + ) sig_complex_value_jit_imag = jit( - grad(lambda t: np.imag(Array(signal.complex_value(t))).data) + grad(lambda t: np.imag(unp.asarray(signal.complex_value(t)))) ) self.assertAllClose(sig_complex_value_jit_re(t), np.real(complex_deriv_val)) self.assertAllClose(sig_complex_value_jit_imag(t), np.imag(complex_deriv_val)) diff --git a/test/dynamics/signals/test_signals_algebra.py b/test/dynamics/signals/test_signals_algebra.py index e3e974df6..434046756 100644 --- a/test/dynamics/signals/test_signals_algebra.py +++ b/test/dynamics/signals/test_signals_algebra.py @@ -269,7 +269,7 @@ def _test_jit_sum_eval(self, sig1, sig2, t_vals): def eval_func(t): sig_sum = sig1 + sig2 - return sig_sum(t).data + return sig_sum(t) jit_eval_func = jit(eval_func) self.assertAllClose(jit_eval_func(t_vals), eval_func(t_vals)) @@ -279,7 +279,7 @@ def _test_grad_jit_sum_eval(self, sig1, sig2, t): def eval_func(t): sig_sum = sig1 + sig2 - return sig_sum(t).data + return sig_sum(t) jit_eval_func = jit(grad(eval_func)) jit_eval_func(t) @@ -289,7 +289,7 @@ def _test_jit_prod_eval(self, sig1, sig2, t_vals): def eval_func(t): sig_sum = sig1 * sig2 - return sig_sum(t).data + return sig_sum(t) jit_eval_func = jit(eval_func) self.assertAllClose(jit_eval_func(t_vals), eval_func(t_vals)) @@ -299,7 +299,7 @@ def _test_grad_jit_prod_eval(self, sig1, sig2, t): def eval_func(t): sig_sum = sig1 * sig2 - return sig_sum(t).data + return sig_sum(t) jit_eval_func = jit(grad(eval_func)) jit_eval_func(t) diff --git a/test/dynamics/solvers/test_dyson_magnus_solvers.py b/test/dynamics/solvers/test_dyson_magnus_solvers.py index ec9989f37..7243a0392 100644 --- a/test/dynamics/solvers/test_dyson_magnus_solvers.py +++ b/test/dynamics/solvers/test_dyson_magnus_solvers.py @@ -21,6 +21,7 @@ from qiskit_dynamics import Signal, Solver, DysonSolver, MagnusSolver from qiskit_dynamics.array import Array +from qiskit_dynamics import DYNAMICS_NUMPY as unp from qiskit_dynamics.solvers.perturbative_solvers.expansion_model import ( _construct_DCT, @@ -85,7 +86,7 @@ def build_testing_objects(obj, integration_method="DOP853"): r = 0.2 def gaussian(amp, sig, t0, t): - return amp * np.exp(-((t - t0) ** 2) / (2 * sig**2)) + return amp * unp.exp(-((t - t0) ** 2) / (2 * sig**2)) # specifications for generating envelope amp = 1.0 # amplitude @@ -94,13 +95,12 @@ def gaussian(amp, sig, t0, t): T = 7 * sig # end of signal # Function to define gaussian envelope, using gaussian wave function - gaussian_envelope = lambda t: gaussian(Array(amp), Array(sig), Array(t0), Array(t)) + gaussian_envelope = lambda t: gaussian(amp, sig, t0, t) obj.gauss_signal = Signal(gaussian_envelope, carrier_freq=5.0) dt = 0.0125 obj.n_steps = int(T // dt) // 3 - hamiltonian_operators = 2 * np.pi * r * np.array([[[0.0, 1.0], [1.0, 0.0]]]) / 2 static_hamiltonian = 2 * np.pi * 5.0 * np.array([[1.0, 0.0], [0.0, -1.0]]) / 2 @@ -344,7 +344,7 @@ def func(c): t0=0.0, n_steps=self.n_steps, y0=np.eye(2, dtype=complex), - signals=[Signal(Array(c), carrier_freq=5.0)], + signals=[Signal(c, carrier_freq=5.0)], ).y[-1] return dyson_yf @@ -362,7 +362,7 @@ def func(c): t0=0.0, n_steps=self.n_steps, y0=np.eye(2, dtype=complex), - signals=[Signal(Array(c), carrier_freq=5.0)], + signals=[Signal(c, carrier_freq=5.0)], ).y[-1] return magnus_yf diff --git a/test/dynamics/solvers/test_solver_classes.py b/test/dynamics/solvers/test_solver_classes.py index 492d3cbdb..86f24884b 100644 --- a/test/dynamics/solvers/test_solver_classes.py +++ b/test/dynamics/solvers/test_solver_classes.py @@ -24,7 +24,6 @@ from qiskit_dynamics import Solver, Signal, DiscreteSignal, solve_lmde from qiskit_dynamics.models import HamiltonianModel, LindbladModel, rotating_wave_approximation -from qiskit_dynamics.array import Array from qiskit_dynamics.type_utils import to_array from qiskit_dynamics.solvers.solver_classes import organize_signals_to_channels @@ -710,7 +709,7 @@ def func(a): yf = solver.solve( t_span=np.array([0.0, 0.1]), y0=np.array([0.0, 1.0]), - signals=[Signal(Array(a), 5.0)], + signals=[Signal(a, 5.0)], method=self.method, ).y[-1] return yf diff --git a/test/dynamics/solvers/test_solver_functions.py b/test/dynamics/solvers/test_solver_functions.py index 45e29d735..66d8cdcdb 100644 --- a/test/dynamics/solvers/test_solver_functions.py +++ b/test/dynamics/solvers/test_solver_functions.py @@ -220,7 +220,7 @@ def test_pseudo_random_jit_grad(self): def func(a): model_copy = self.pseudo_random_model.copy() - model_copy.signals = [Signal(Array(a), carrier_freq=1.0)] + model_copy.signals = [Signal(a, carrier_freq=1.0)] results = self.solve(model_copy, t_span=[0.0, 0.1], y0=self.pseudo_random_y0) return results.y[-1] @@ -367,7 +367,7 @@ def setUp(self): # simulate directly out of frame def pseudo_random_rhs(t, y=None): - op = self.static_operator + self.pseudo_random_signal(t).data * self.operators[0] + op = self.static_operator + self.pseudo_random_signal(t) * self.operators[0] op = -1j * op if y is None: return op