Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Arraylias integration - Signal class - #269

Merged
merged 45 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
02edfce
arraylias to signal class
to24toro Oct 16, 2023
693e378
potential fix for using test_array_backends
DanPuzzuoli Oct 16, 2023
de2596f
Merge pull request #2 from DanPuzzuoli/potential-test-fix
to24toro Oct 17, 2023
055a086
added test_array_backends to TestSignalsJaxTransformations
to24toro Oct 18, 2023
f8a3b62
changed signals.py
to24toro Oct 18, 2023
fcaaa82
remove unsed module
to24toro Oct 18, 2023
bb5422c
pylint: disable=no-member in test_signals.py
to24toro Oct 19, 2023
47a9de4
added docstrings to setup in test_signals.py
to24toro Oct 19, 2023
941adde
transfer_functions.py
to24toro Oct 20, 2023
f374ff1
Merge branch 'main' into arraylias/signal
DanPuzzuoli Oct 24, 2023
46801f7
changing Signal(Array(x)) to Signal(x) for constant signals
DanPuzzuoli Oct 24, 2023
2dd286b
fixing lanczos jax diag test
DanPuzzuoli Oct 24, 2023
44a935e
fixing JAX perturbative solver tests
DanPuzzuoli Oct 24, 2023
42ff34f
remove pdb
to24toro Oct 25, 2023
d3fd3dc
lint
to24toro Oct 25, 2023
79c45ce
change alias to numpy_alias
to24toro Oct 25, 2023
a3509f5
remove type hint included in ArrayLike
to24toro Oct 25, 2023
2849168
add list to Arraylike
to24toro Oct 25, 2023
921126b
remove list from type hint
to24toro Oct 25, 2023
c0b15c5
fixing JAX perturbative solver test
DanPuzzuoli Oct 25, 2023
9717146
getting pulse tests working through addition of _preferred_lib function
DanPuzzuoli Oct 25, 2023
a6043b3
fixing pulse conversion errors
DanPuzzuoli Oct 25, 2023
09ffd37
fixing remaining errors with simple multiple dispatching function
DanPuzzuoli Oct 25, 2023
3a36f1c
change isinstance(sig, Array)
to24toro Oct 25, 2023
591b533
change isinstance(sig, Array)
to24toro Oct 25, 2023
fdb82e3
modify Array to unp to pass tests
to24toro Oct 26, 2023
3fcec23
pulse_to_signals
to24toro Oct 26, 2023
facd132
lint
to24toro Oct 26, 2023
2d2321c
modify if statement in _nyquist_warn
to24toro Oct 26, 2023
54cf514
getting docs to ubild
DanPuzzuoli Oct 26, 2023
adcbc60
Merge branch 'arraylias/signal' into further-signal-arraylias-fix
to24toro Oct 27, 2023
2ca2866
Merge pull request #3 from DanPuzzuoli/further-signal-arraylias-fix
to24toro Oct 27, 2023
787f0f8
black
to24toro Oct 27, 2023
7171fd4
Merge branch 'main' into arraylias/signal
to24toro Oct 27, 2023
de6aff0
reorganize import
to24toro Oct 27, 2023
3a4d2b1
Update qiskit_dynamics/signals/signals.py
to24toro Oct 30, 2023
94480ea
Update qiskit_dynamics/signals/signals.py
to24toro Oct 30, 2023
5a6b474
Update qiskit_dynamics/signals/signals.py
to24toro Oct 30, 2023
dcbd902
Update qiskit_dynamics/signals/signals.py
to24toro Oct 30, 2023
0d77f8b
Update qiskit_dynamics/signals/signals.py
to24toro Oct 30, 2023
2e1436e
reflect on comments
to24toro Oct 30, 2023
e2f6449
Merge branch 'main' into arraylias/signal
to24toro Oct 30, 2023
29ba659
Merge branch 'main' into arraylias/signal
DanPuzzuoli Oct 30, 2023
619581b
Update qiskit_dynamics/signals/transfer_functions.py
DanPuzzuoli Oct 30, 2023
12c36da
fixing formatting
DanPuzzuoli Oct 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/tutorials/optimizing_pulse_sequence.rst
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,10 @@ example of one.
.. jupyter-execute::

from qiskit_dynamics import DiscreteSignal
from qiskit_dynamics.array import Array
from qiskit_dynamics.signals import Convolution

import jax.numpy as jnp

# define convolution filter
def gaus(t):
sigma = 15
Expand All @@ -128,13 +129,12 @@ example of one.

# define function mapping parameters to signals
def signal_mapping(params):
samples = Array(params)

# map samples into [-1, 1]
bounded_samples = np.arctan(samples) / (np.pi / 2)
bounded_samples = jnp.arctan(params) / (np.pi / 2)

# pad with 0 at beginning
padded_samples = np.append(Array([0], dtype=complex), bounded_samples)
padded_samples = jnp.append(jnp.array([0], dtype=complex), bounded_samples)

# apply filter
output_signal = convolution(DiscreteSignal(dt=1., samples=padded_samples))
Expand Down
6 changes: 3 additions & 3 deletions docs/userguide/how_to_configure_simulations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ further highlight the benefits of the sparse representation.

static_hamiltonian = 2 * np.pi * v * N + np.pi * anharm * N * (N - np.eye(dim))
drive_hamiltonian = 2 * np.pi * r * (a + adag)
drive_signal = Signal(Array(1.), carrier_freq=v)
drive_signal = Signal(1., carrier_freq=v)

y0 = np.zeros(dim, dtype=complex)
y0[1] = 1.
Expand All @@ -266,7 +266,7 @@ amplitude, and just-in-time compile it using JAX.
)

def dense_func(amp):
drive_signal = Signal(Array(amp), carrier_freq=v)
drive_signal = Signal(amp, carrier_freq=v)
res = solver.solve(
t_span=[0., T],
y0=y0,
Expand All @@ -292,7 +292,7 @@ diagonal, but we explicitly highlight the need for this.
evaluation_mode='sparse')

def sparse_func(amp):
drive_signal = Signal(Array(amp), carrier_freq=v)
drive_signal = Signal(amp, carrier_freq=v)
res = sparse_solver.solve(
t_span=[0., T],
y0=y0,
Expand Down
1 change: 0 additions & 1 deletion docs/userguide/how_to_use_jax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ before setting the signals, to ensure the simulation function remains pure.
def sim_function(amp):

# define a constant signal
amp = Array(amp)
signals = [Signal(amp, carrier_freq=w)]

# simulate and return results
Expand Down
2 changes: 1 addition & 1 deletion docs/userguide/how_to_use_pulse_schedule_for_jax_jit.rst
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,6 @@ JAX-compiled (or more generally, JAX-transformed).
# convert from a pulse schedule to a list of signals
converter = InstructionToSignals(dt, carriers={"d0": w})

return converter.get_signals(schedule)[0].samples.data
return converter.get_signals(schedule)[0].samples

jax.jit(jit_func)(0.4)
4 changes: 2 additions & 2 deletions docs/userguide/perturbative_solvers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ these functions gives a sense of the speeds attainable by these solvers.
"""For a given envelope amplitude, simulate the final unitary using the
Dyson solver.
"""
drive_signal = Signal(lambda t: Array(amp) * envelope_func(t), carrier_freq=v)
drive_signal = Signal(lambda t: amp * envelope_func(t), carrier_freq=v)
return dyson_solver.solve(
signals=[drive_signal],
y0=np.eye(dim, dtype=complex),
Expand Down Expand Up @@ -220,7 +220,7 @@ accuracy and simulation speed.

# specify tolerance as an argument to run the simulation at different tolerances
def ode_sim(amp, tol):
drive_signal = Signal(lambda t: Array(amp) * envelope_func(t), carrier_freq=v)
drive_signal = Signal(lambda t: amp * envelope_func(t), carrier_freq=v)
res = solver.solve(
t_span=[0., int(T // dt) * dt],
y0=np.eye(dim, dtype=complex),
Expand Down
1 change: 1 addition & 0 deletions qiskit_dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
DYNAMICS_SCIPY_ALIAS,
DYNAMICS_NUMPY,
DYNAMICS_SCIPY,
ArrayLike,
)

from .models.rotating_frame import RotatingFrame
Expand Down
8 changes: 7 additions & 1 deletion qiskit_dynamics/arraylias/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,10 @@
Module for Qiskit Dynamics global NumPy and SciPy aliases.
"""

from .alias import DYNAMICS_NUMPY_ALIAS, DYNAMICS_SCIPY_ALIAS, DYNAMICS_NUMPY, DYNAMICS_SCIPY
from .alias import (
DYNAMICS_NUMPY_ALIAS,
DYNAMICS_SCIPY_ALIAS,
DYNAMICS_NUMPY,
DYNAMICS_SCIPY,
ArrayLike,
)
51 changes: 50 additions & 1 deletion qiskit_dynamics/arraylias/alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

from arraylias import numpy_alias, scipy_alias

from qiskit import QiskitError

from qiskit_dynamics.array import Array

# global NumPy and SciPy aliases
Expand All @@ -34,4 +36,51 @@
DYNAMICS_SCIPY = DYNAMICS_SCIPY_ALIAS()


ArrayLike = Union[DYNAMICS_NUMPY_ALIAS.registered_types()]
ArrayLike = Union[Union[DYNAMICS_NUMPY_ALIAS.registered_types()], list]


def _preferred_lib(*args, **kwargs):
"""Given a list of args and kwargs with potentially mixed array types, determine the appropriate
library to dispatch to.

For each argument, DYNAMICS_NUMPY_ALIAS.infer_libs is called to infer the library. If all are
"numpy", then it returns "numpy", and if any are "jax", it returns "jax".

Args:
*args: Positional arguments.
**kwargs: Keyword arguments.
Returns:
str
Raises:
QiskitError: if none of the rules apply.
"""
args = list(args) + list(kwargs.values())
if len(args) == 1:
return DYNAMICS_NUMPY_ALIAS.infer_libs(args[0])

lib0 = DYNAMICS_NUMPY_ALIAS.infer_libs(args[0])[0]
lib1 = _preferred_lib(args[1:])[0]

if lib0 == "numpy" and lib1 == "numpy":
return "numpy"
elif lib0 == "jax" or lib1 == "jax":
return "jax"

raise QiskitError("_preferred_lib could not resolve preferred library.")


def _numpy_multi_dispatch(*args, path, **kwargs):
"""Multiple dispatching for NumPy.

Given *args and **kwargs, dispatch the function specified by path, to the array library
specified by _preferred_lib.

Args:
*args: Positional arguments to pass to function specified by path.
path: Path in numpy module structure.
**kwargs: Keyword arguments to pass to function specified by path.
Returns:
Result of evaluating the function at path on the arguments using the preferred library.
"""
lib = _preferred_lib(*args, **kwargs)
return DYNAMICS_NUMPY_ALIAS(like=lib, path=path)(*args, **kwargs)
3 changes: 2 additions & 1 deletion qiskit_dynamics/models/operator_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from qiskit import QiskitError
from qiskit.quantum_info.operators.operator import Operator
from qiskit_dynamics.arraylias.alias import _numpy_multi_dispatch
from qiskit_dynamics.array import Array, wrap
from qiskit_dynamics.type_utils import to_array, to_csr, to_BCOO, vec_commutator, vec_dissipator

Expand Down Expand Up @@ -1357,7 +1358,7 @@ def concatenate_signals(
) -> Array:
"""Concatenate hamiltonian and linblad signals."""
if self._hamiltonian_operators is not None and self._dissipator_operators is not None:
return np.append(ham_sig_vals, dis_sig_vals, axis=-1)
return _numpy_multi_dispatch(ham_sig_vals, dis_sig_vals, path="append", axis=-1)
if self._hamiltonian_operators is not None and self._dissipator_operators is None:
return ham_sig_vals
if self._hamiltonian_operators is None and self._dissipator_operators is not None:
Expand Down
31 changes: 16 additions & 15 deletions qiskit_dynamics/pulse/pulse_to_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
from qiskit import QiskitError

from qiskit_dynamics.array import Array
from qiskit_dynamics import DYNAMICS_NUMPY as unp
from qiskit_dynamics import ArrayLike

from qiskit_dynamics.signals import DiscreteSignal

try:
Expand Down Expand Up @@ -186,12 +189,10 @@ def get_signals(self, schedule: Schedule) -> List[DiscreteSignal]:

# build sample array to append to signal
times = self._dt * (start_sample + np.arange(len(inst_samples)))
samples = inst_samples * np.exp(
Array(
2.0j * np.pi * frequency_shifts[chan] * times
+ 1.0j * phases[chan]
+ 2.0j * np.pi * phase_accumulations[chan]
)
samples = inst_samples * unp.exp(
2.0j * np.pi * frequency_shifts[chan] * times
+ 1.0j * phases[chan]
+ 2.0j * np.pi * phase_accumulations[chan]
)
signals[chan].add_samples(start_sample, samples)

Expand All @@ -202,7 +203,7 @@ def get_signals(self, schedule: Schedule) -> List[DiscreteSignal]:
phases[chan] = inst.phase

if isinstance(inst, ShiftFrequency):
frequency_shifts[chan] = frequency_shifts[chan] + Array(inst.frequency)
frequency_shifts[chan] = frequency_shifts[chan] + inst.frequency
phase_accumulations[chan] = (
phase_accumulations[chan] - inst.frequency * start_sample * self._dt
)
Expand All @@ -226,7 +227,7 @@ def get_signals(self, schedule: Schedule) -> List[DiscreteSignal]:
if sig.duration < max_duration:
sig.add_samples(
start_sample=sig.duration,
samples=np.zeros(max_duration - sig.duration, dtype=complex),
samples=unp.zeros(max_duration - sig.duration, dtype=complex),
to24toro marked this conversation as resolved.
Show resolved Hide resolved
)

# filter the channels
Expand Down Expand Up @@ -273,7 +274,7 @@ def get_awg_signals(
new_freq = sig.carrier_freq + if_modulation

samples_i = sig.samples
samples_q = np.imag(samples_i) - 1.0j * np.real(samples_i)
samples_q = unp.imag(samples_i) - 1.0j * unp.real(samples_i)

sig_i = DiscreteSignal(
sig.dt,
Expand Down Expand Up @@ -325,7 +326,7 @@ def _get_channel(self, channel_name: str):
) from error


def get_samples(pulse: SymbolicPulse) -> np.ndarray:
def get_samples(pulse: SymbolicPulse) -> ArrayLike:
"""Return samples filled according to the formula that the pulse
represents and the parameter values it contains.

Expand All @@ -349,8 +350,8 @@ def get_samples(pulse: SymbolicPulse) -> np.ndarray:
args = []
for symbol in sorted(envelope.free_symbols, key=lambda s: s.name):
if symbol.name == "t":
times = Array(np.arange(0, pulse_params["duration"]) + 1 / 2)
args.insert(0, times.data)
times = unp.arange(0, pulse_params["duration"]) + 1 / 2
args.insert(0, times)
continue
try:
args.append(pulse_params[symbol.name])
Expand Down Expand Up @@ -381,11 +382,11 @@ def _lru_cache_expr(expr: sym.Expr, backend) -> Callable:
return sym.lambdify(params, expr, modules=backend)


def _nyquist_warn(frequency_shift: Array, dt: float, channel: str):
def _nyquist_warn(frequency_shift: ArrayLike, dt: float, channel: str):
"""Raise a warning if the frequency shift is above the Nyquist frequency given by ``dt``."""

if (
Array(frequency_shift).backend != "jax" or not isinstance(jnp.array(0), jax.core.Tracer)
isinstance(frequency_shift, (int, float, list, np.ndarray))
or not isinstance(jnp.array(0), jax.core.Tracer)
) and np.abs(frequency_shift) > 0.5 / dt:
warn(
"Due to SetFrequency and ShiftFrequency instructions, the digital carrier frequency "
Expand Down
Loading
Loading