Skip to content

Commit

Permalink
updating solver_classes
Browse files Browse the repository at this point in the history
  • Loading branch information
DanPuzzuoli committed Dec 18, 2023
1 parent 298b606 commit 4eb6050
Showing 1 changed file with 39 additions and 37 deletions.
76 changes: 39 additions & 37 deletions qiskit_dynamics/solvers/solver_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import numpy as np

from scipy.integrate._ivp.ivp import OdeResult # pylint: disable=unused-import
from scipy.integrate._ivp.ivp import OdeResult

from qiskit import QiskitError
from qiskit.pulse import Schedule, ScheduleBlock
Expand All @@ -34,6 +34,9 @@
from qiskit.quantum_info.states.quantum_state import QuantumState
from qiskit.quantum_info import SuperOp, Operator, DensityMatrix

from qiskit_dynamics import ArrayLike
from qiskit_dynamics import DYNAMICS_NUMPY as unp

from qiskit_dynamics.models import (
HamiltonianModel,
LindbladModel,
Expand All @@ -42,8 +45,6 @@
)
from qiskit_dynamics.signals import Signal, DiscreteSignal, SignalList
from qiskit_dynamics.pulse import InstructionToSignals
from qiskit_dynamics.array import Array
from qiskit_dynamics.dispatch.dispatch import Dispatch

from .solver_functions import solve_lmde, _is_diffrax_method
from .solver_utils import (
Expand Down Expand Up @@ -180,19 +181,20 @@ class Solver:

def __init__(
self,
static_hamiltonian: Optional[Array] = None,
hamiltonian_operators: Optional[Array] = None,
static_dissipators: Optional[Array] = None,
dissipator_operators: Optional[Array] = None,
static_hamiltonian: Optional[ArrayLike] = None,
hamiltonian_operators: Optional[ArrayLike] = None,
static_dissipators: Optional[ArrayLike] = None,
dissipator_operators: Optional[ArrayLike] = None,
hamiltonian_channels: Optional[List[str]] = None,
dissipator_channels: Optional[List[str]] = None,
channel_carrier_freqs: Optional[dict] = None,
dt: Optional[float] = None,
rotating_frame: Optional[Union[Array, RotatingFrame]] = None,
rotating_frame: Optional[Union[ArrayLike, RotatingFrame]] = None,
in_frame_basis: bool = False,
evaluation_mode: str = "dense",
array_library: Optional[str] = None,
vectorized: Optional[bool] = None,
rwa_cutoff_freq: Optional[float] = None,
rwa_carrier_freqs: Optional[Union[Array, Tuple[Array, Array]]] = None,
rwa_carrier_freqs: Optional[Union[ArrayLike, Tuple[ArrayLike, ArrayLike]]] = None,
validate: bool = True,
):
"""Initialize solver with model information.
Expand All @@ -218,10 +220,9 @@ def __init__(
in_frame_basis: Whether to represent the model in the basis in which the rotating
frame operator is diagonalized. See class documentation for a more
detailed explanation on how this argument affects object behaviour.
evaluation_mode: Method for model evaluation. See documentation for
``HamiltonianModel.evaluation_mode`` or
``LindbladModel.evaluation_mode``.
(if dissipators in model) for valid modes.
array_library: Array library to use for storing operators of underlying model.
vectorized: If including dissipator terms, whether or not to construct the
:class:`.LindbladModel` in vectorized form.
rwa_cutoff_freq: Rotating wave approximation cutoff frequency. If ``None``, no
approximation is made.
rwa_carrier_freqs: Carrier frequencies to use for rotating wave approximation.
Expand Down Expand Up @@ -312,7 +313,7 @@ def __init__(
operators=hamiltonian_operators,
rotating_frame=rotating_frame,
in_frame_basis=in_frame_basis,
evaluation_mode=evaluation_mode,
array_library=array_library,
validate=validate,
)
else:
Expand All @@ -323,7 +324,8 @@ def __init__(
dissipator_operators=dissipator_operators,
rotating_frame=rotating_frame,
in_frame_basis=in_frame_basis,
evaluation_mode=evaluation_mode,
array_library=array_library,
vectorized=vectorized,
validate=validate,
)

Expand Down Expand Up @@ -385,8 +387,8 @@ def model(self) -> Union[HamiltonianModel, LindbladModel]:

def solve(
self,
t_span: Array,
y0: Union[Array, QuantumState, BaseOperator],
t_span: ArrayLike,
y0: Union[ArrayLike, QuantumState, BaseOperator],
signals: Optional[
Union[
List[Union[Schedule, ScheduleBlock]],
Expand Down Expand Up @@ -440,7 +442,7 @@ def solve(
- Model type
- ``yf`` type
- Description
* - ``Array``, ``np.ndarray``, ``Operator``
* - ``ArrayLike``, ``np.ndarray``, ``Operator``
- Any
- Same as ``y0``
- Solves either the Schrodinger equation or Lindblad equation
Expand Down Expand Up @@ -468,8 +470,8 @@ def solve(
* - ``QuantumChannel``
- ``LindbladModel``
- ``SuperOp``
- Solves the vectorized Lindblad equation with initial state ``y0``.
``evaluation_mode`` must be set to a vectorized option.
- Solves the vectorized Lindblad equation with initial state ``y0``. ``vectorized``
must be set to ``True``.
In some cases (e.g. if using JAX), wrapping the returned states in the type
given in the ``yf`` type column above may be undesirable. Setting
Expand Down Expand Up @@ -552,8 +554,8 @@ def solve(

def _solve_list(
self,
t_span_list: List[Array],
y0_list: List[Union[Array, QuantumState, BaseOperator]],
t_span_list: List[ArrayLike],
y0_list: List[Union[ArrayLike, QuantumState, BaseOperator]],
signals_list: Optional[
Union[List[Schedule], List[List[Signal]], List[Tuple[List[Signal], List[Signal]]]]
] = None,
Expand Down Expand Up @@ -588,8 +590,8 @@ def _solve_list(

def _solve_schedule_list_jax(
self,
t_span_list: List[Array],
y0_list: List[Union[Array, QuantumState, BaseOperator]],
t_span_list: List[ArrayLike],
y0_list: List[Union[ArrayLike, QuantumState, BaseOperator]],
schedule_list: List[Schedule],
convert_results: bool = True,
**kwargs,
Expand Down Expand Up @@ -637,7 +639,7 @@ def sim_function(t_span, y0, all_samples, y0_input, y0_cls):
# reset signals to ensure purity
self.model.signals = model_sigs

return Array(results.t).data, Array(results.y).data
return results.t, results.y

jit_sim_function = jit(sim_function, static_argnums=(4,))

Expand All @@ -657,9 +659,9 @@ def sim_function(t_span, y0, all_samples, y0_input, y0_cls):
all_samples[idx, 0 : len(sig.samples)] = np.array(sig.samples)

results_t, results_y = jit_sim_function(
Array(t_span).data, Array(y0).data, all_samples, Array(y0_input).data, y0_cls
unp.asarray(t_span), unp.asarray(y0), unp.asarray(all_samples), unp.asarray(y0_input), y0_cls
)
results = OdeResult(t=results_t, y=Array(results_y, backend="jax", dtype=complex))
results = OdeResult(t=results_t, y=results_y)

if y0_cls is not None and convert_results:
results.y = [state_type_wrapper(yi) for yi in results.y]
Expand Down Expand Up @@ -698,7 +700,7 @@ def _schedule_to_signals(self, schedule: Schedule):
)


def initial_state_converter(obj: Any) -> Tuple[Array, Type, Callable]:
def initial_state_converter(obj: Any) -> Tuple[ArrayLike, Type, Callable]:
"""Convert initial state object to an Array, the type of the initial input, and return
function for constructing a state of the same type.
Expand All @@ -710,23 +712,23 @@ def initial_state_converter(obj: Any) -> Tuple[Array, Type, Callable]:
"""
# pylint: disable=invalid-name
y0_cls = None
if isinstance(obj, Array):
if isinstance(obj, ArrayLike):
y0, y0_cls, wrapper = obj, None, lambda x: x
if isinstance(obj, QuantumState):
y0, y0_cls = Array(obj.data), obj.__class__
y0, y0_cls = obj.data, obj.__class__
wrapper = lambda x: y0_cls(np.array(x), dims=obj.dims())
elif isinstance(obj, QuantumChannel):
y0, y0_cls = Array(SuperOp(obj).data), SuperOp
y0, y0_cls = SuperOp(obj).data, SuperOp
wrapper = lambda x: SuperOp(
np.array(x), input_dims=obj.input_dims(), output_dims=obj.output_dims()
)
elif isinstance(obj, (BaseOperator, Gate, QuantumCircuit)):
y0, y0_cls = Array(Operator(obj.data)), Operator
y0, y0_cls = Operator(obj.data), Operator
wrapper = lambda x: Operator(
np.array(x), input_dims=obj.input_dims(), output_dims=obj.output_dims()
)
else:
y0, y0_cls, wrapper = Array(obj), None, lambda x: x
y0, y0_cls, wrapper = unp.asarray(obj), None, lambda x: x

return y0, y0_cls, wrapper

Expand Down Expand Up @@ -794,15 +796,15 @@ def validate_and_format_initial_state(y0: any, model: Union[HamiltonianModel, Li
def format_final_states(y, model, y0_input, y0_cls):
"""Format final states for a single simulation."""

y = Array(y)
y = unp.asarray(y)

if y0_cls is DensityMatrix and isinstance(model, HamiltonianModel):
# conjugate by unitary
return y @ y0_input @ y.conj().transpose((0, 2, 1))
elif y0_cls is SuperOp and isinstance(model, HamiltonianModel):
# convert to SuperOp and compose
return (
np.einsum("nka,nlb->nklab", y.conj(), y).reshape(
unp.einsum("nka,nlb->nklab", y.conj(), y).reshape(
y.shape[0], y.shape[1] ** 2, y.shape[1] ** 2
)
@ y0_input
Expand Down Expand Up @@ -903,7 +905,7 @@ def _nested_ndim(x):
"""Determine the 'ndim' of x, which could be composed of nested lists and array types."""
if isinstance(x, (list, tuple)):
return 1 + _nested_ndim(x[0])
elif issubclass(type(x), Dispatch.REGISTERED_TYPES) or isinstance(x, Array):
elif hasattr(x, "ndim"):
return x.ndim

# assume scalar
Expand Down

0 comments on commit 4eb6050

Please sign in to comment.