Skip to content

Commit

Permalink
Arraylias integration - Signal class - (#269)
Browse files Browse the repository at this point in the history
Co-authored-by: DanPuzzuoli <[email protected]>
  • Loading branch information
to24toro and DanPuzzuoli authored Oct 30, 2023
1 parent 594481c commit 7333814
Show file tree
Hide file tree
Showing 18 changed files with 352 additions and 335 deletions.
8 changes: 4 additions & 4 deletions docs/tutorials/optimizing_pulse_sequence.rst
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,10 @@ example of one.
.. jupyter-execute::

from qiskit_dynamics import DiscreteSignal
from qiskit_dynamics.array import Array
from qiskit_dynamics.signals import Convolution

import jax.numpy as jnp

# define convolution filter
def gaus(t):
sigma = 15
Expand All @@ -128,13 +129,12 @@ example of one.

# define function mapping parameters to signals
def signal_mapping(params):
samples = Array(params)

# map samples into [-1, 1]
bounded_samples = np.arctan(samples) / (np.pi / 2)
bounded_samples = jnp.arctan(params) / (np.pi / 2)

# pad with 0 at beginning
padded_samples = np.append(Array([0], dtype=complex), bounded_samples)
padded_samples = jnp.append(jnp.array([0], dtype=complex), bounded_samples)

# apply filter
output_signal = convolution(DiscreteSignal(dt=1., samples=padded_samples))
Expand Down
6 changes: 3 additions & 3 deletions docs/userguide/how_to_configure_simulations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ further highlight the benefits of the sparse representation.

static_hamiltonian = 2 * np.pi * v * N + np.pi * anharm * N * (N - np.eye(dim))
drive_hamiltonian = 2 * np.pi * r * (a + adag)
drive_signal = Signal(Array(1.), carrier_freq=v)
drive_signal = Signal(1., carrier_freq=v)

y0 = np.zeros(dim, dtype=complex)
y0[1] = 1.
Expand All @@ -266,7 +266,7 @@ amplitude, and just-in-time compile it using JAX.
)

def dense_func(amp):
drive_signal = Signal(Array(amp), carrier_freq=v)
drive_signal = Signal(amp, carrier_freq=v)
res = solver.solve(
t_span=[0., T],
y0=y0,
Expand All @@ -292,7 +292,7 @@ diagonal, but we explicitly highlight the need for this.
evaluation_mode='sparse')

def sparse_func(amp):
drive_signal = Signal(Array(amp), carrier_freq=v)
drive_signal = Signal(amp, carrier_freq=v)
res = sparse_solver.solve(
t_span=[0., T],
y0=y0,
Expand Down
1 change: 0 additions & 1 deletion docs/userguide/how_to_use_jax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ before setting the signals, to ensure the simulation function remains pure.
def sim_function(amp):

# define a constant signal
amp = Array(amp)
signals = [Signal(amp, carrier_freq=w)]

# simulate and return results
Expand Down
2 changes: 1 addition & 1 deletion docs/userguide/how_to_use_pulse_schedule_for_jax_jit.rst
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,6 @@ JAX-compiled (or more generally, JAX-transformed).
# convert from a pulse schedule to a list of signals
converter = InstructionToSignals(dt, carriers={"d0": w})

return converter.get_signals(schedule)[0].samples.data
return converter.get_signals(schedule)[0].samples

jax.jit(jit_func)(0.4)
4 changes: 2 additions & 2 deletions docs/userguide/perturbative_solvers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ these functions gives a sense of the speeds attainable by these solvers.
"""For a given envelope amplitude, simulate the final unitary using the
Dyson solver.
"""
drive_signal = Signal(lambda t: Array(amp) * envelope_func(t), carrier_freq=v)
drive_signal = Signal(lambda t: amp * envelope_func(t), carrier_freq=v)
return dyson_solver.solve(
signals=[drive_signal],
y0=np.eye(dim, dtype=complex),
Expand Down Expand Up @@ -220,7 +220,7 @@ accuracy and simulation speed.

# specify tolerance as an argument to run the simulation at different tolerances
def ode_sim(amp, tol):
drive_signal = Signal(lambda t: Array(amp) * envelope_func(t), carrier_freq=v)
drive_signal = Signal(lambda t: amp * envelope_func(t), carrier_freq=v)
res = solver.solve(
t_span=[0., int(T // dt) * dt],
y0=np.eye(dim, dtype=complex),
Expand Down
1 change: 1 addition & 0 deletions qiskit_dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
DYNAMICS_SCIPY_ALIAS,
DYNAMICS_NUMPY,
DYNAMICS_SCIPY,
ArrayLike,
)

from .models.rotating_frame import RotatingFrame
Expand Down
8 changes: 7 additions & 1 deletion qiskit_dynamics/arraylias/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,10 @@
Module for Qiskit Dynamics global NumPy and SciPy aliases.
"""

from .alias import DYNAMICS_NUMPY_ALIAS, DYNAMICS_SCIPY_ALIAS, DYNAMICS_NUMPY, DYNAMICS_SCIPY
from .alias import (
DYNAMICS_NUMPY_ALIAS,
DYNAMICS_SCIPY_ALIAS,
DYNAMICS_NUMPY,
DYNAMICS_SCIPY,
ArrayLike,
)
51 changes: 50 additions & 1 deletion qiskit_dynamics/arraylias/alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

from arraylias import numpy_alias, scipy_alias

from qiskit import QiskitError

from qiskit_dynamics.array import Array

# global NumPy and SciPy aliases
Expand All @@ -34,4 +36,51 @@
DYNAMICS_SCIPY = DYNAMICS_SCIPY_ALIAS()


ArrayLike = Union[DYNAMICS_NUMPY_ALIAS.registered_types()]
ArrayLike = Union[Union[DYNAMICS_NUMPY_ALIAS.registered_types()], list]


def _preferred_lib(*args, **kwargs):
"""Given a list of args and kwargs with potentially mixed array types, determine the appropriate
library to dispatch to.
For each argument, DYNAMICS_NUMPY_ALIAS.infer_libs is called to infer the library. If all are
"numpy", then it returns "numpy", and if any are "jax", it returns "jax".
Args:
*args: Positional arguments.
**kwargs: Keyword arguments.
Returns:
str
Raises:
QiskitError: if none of the rules apply.
"""
args = list(args) + list(kwargs.values())
if len(args) == 1:
return DYNAMICS_NUMPY_ALIAS.infer_libs(args[0])

lib0 = DYNAMICS_NUMPY_ALIAS.infer_libs(args[0])[0]
lib1 = _preferred_lib(args[1:])[0]

if lib0 == "numpy" and lib1 == "numpy":
return "numpy"
elif lib0 == "jax" or lib1 == "jax":
return "jax"

raise QiskitError("_preferred_lib could not resolve preferred library.")


def _numpy_multi_dispatch(*args, path, **kwargs):
"""Multiple dispatching for NumPy.
Given *args and **kwargs, dispatch the function specified by path, to the array library
specified by _preferred_lib.
Args:
*args: Positional arguments to pass to function specified by path.
path: Path in numpy module structure.
**kwargs: Keyword arguments to pass to function specified by path.
Returns:
Result of evaluating the function at path on the arguments using the preferred library.
"""
lib = _preferred_lib(*args, **kwargs)
return DYNAMICS_NUMPY_ALIAS(like=lib, path=path)(*args, **kwargs)
3 changes: 2 additions & 1 deletion qiskit_dynamics/models/operator_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from qiskit import QiskitError
from qiskit.quantum_info.operators.operator import Operator
from qiskit_dynamics.arraylias.alias import _numpy_multi_dispatch
from qiskit_dynamics.array import Array, wrap
from qiskit_dynamics.type_utils import to_array, to_csr, to_BCOO, vec_commutator, vec_dissipator

Expand Down Expand Up @@ -1357,7 +1358,7 @@ def concatenate_signals(
) -> Array:
"""Concatenate hamiltonian and linblad signals."""
if self._hamiltonian_operators is not None and self._dissipator_operators is not None:
return np.append(ham_sig_vals, dis_sig_vals, axis=-1)
return _numpy_multi_dispatch(ham_sig_vals, dis_sig_vals, path="append", axis=-1)
if self._hamiltonian_operators is not None and self._dissipator_operators is None:
return ham_sig_vals
if self._hamiltonian_operators is None and self._dissipator_operators is not None:
Expand Down
29 changes: 15 additions & 14 deletions qiskit_dynamics/pulse/pulse_to_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
from qiskit import QiskitError

from qiskit_dynamics.array import Array
from qiskit_dynamics import DYNAMICS_NUMPY as unp
from qiskit_dynamics import ArrayLike

from qiskit_dynamics.signals import DiscreteSignal

try:
Expand Down Expand Up @@ -186,12 +189,10 @@ def get_signals(self, schedule: Schedule) -> List[DiscreteSignal]:

# build sample array to append to signal
times = self._dt * (start_sample + np.arange(len(inst_samples)))
samples = inst_samples * np.exp(
Array(
2.0j * np.pi * frequency_shifts[chan] * times
+ 1.0j * phases[chan]
+ 2.0j * np.pi * phase_accumulations[chan]
)
samples = inst_samples * unp.exp(
2.0j * np.pi * frequency_shifts[chan] * times
+ 1.0j * phases[chan]
+ 2.0j * np.pi * phase_accumulations[chan]
)
signals[chan].add_samples(start_sample, samples)

Expand All @@ -202,7 +203,7 @@ def get_signals(self, schedule: Schedule) -> List[DiscreteSignal]:
phases[chan] = inst.phase

if isinstance(inst, ShiftFrequency):
frequency_shifts[chan] = frequency_shifts[chan] + Array(inst.frequency)
frequency_shifts[chan] = frequency_shifts[chan] + inst.frequency
phase_accumulations[chan] = (
phase_accumulations[chan] - inst.frequency * start_sample * self._dt
)
Expand Down Expand Up @@ -273,7 +274,7 @@ def get_awg_signals(
new_freq = sig.carrier_freq + if_modulation

samples_i = sig.samples
samples_q = np.imag(samples_i) - 1.0j * np.real(samples_i)
samples_q = unp.imag(samples_i) - 1.0j * unp.real(samples_i)

sig_i = DiscreteSignal(
sig.dt,
Expand Down Expand Up @@ -325,7 +326,7 @@ def _get_channel(self, channel_name: str):
) from error


def get_samples(pulse: SymbolicPulse) -> np.ndarray:
def get_samples(pulse: SymbolicPulse) -> ArrayLike:
"""Return samples filled according to the formula that the pulse
represents and the parameter values it contains.
Expand All @@ -349,8 +350,8 @@ def get_samples(pulse: SymbolicPulse) -> np.ndarray:
args = []
for symbol in sorted(envelope.free_symbols, key=lambda s: s.name):
if symbol.name == "t":
times = Array(np.arange(0, pulse_params["duration"]) + 1 / 2)
args.insert(0, times.data)
times = unp.arange(0, pulse_params["duration"]) + 1 / 2
args.insert(0, times)
continue
try:
args.append(pulse_params[symbol.name])
Expand Down Expand Up @@ -381,11 +382,11 @@ def _lru_cache_expr(expr: sym.Expr, backend) -> Callable:
return sym.lambdify(params, expr, modules=backend)


def _nyquist_warn(frequency_shift: Array, dt: float, channel: str):
def _nyquist_warn(frequency_shift: ArrayLike, dt: float, channel: str):
"""Raise a warning if the frequency shift is above the Nyquist frequency given by ``dt``."""

if (
Array(frequency_shift).backend != "jax" or not isinstance(jnp.array(0), jax.core.Tracer)
isinstance(frequency_shift, (int, float, list, np.ndarray))
or not isinstance(jnp.array(0), jax.core.Tracer)
) and np.abs(frequency_shift) > 0.5 / dt:
warn(
"Due to SetFrequency and ShiftFrequency instructions, the digital carrier frequency "
Expand Down
Loading

0 comments on commit 7333814

Please sign in to comment.