Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Further signal arraylias fix #3

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/tutorials/optimizing_pulse_sequence.rst
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,10 @@ example of one.
.. jupyter-execute::

from qiskit_dynamics import DiscreteSignal
from qiskit_dynamics.array import Array
from qiskit_dynamics.signals import Convolution

import jax.numpy as jnp

# define convolution filter
def gaus(t):
sigma = 15
Expand All @@ -128,13 +129,12 @@ example of one.

# define function mapping parameters to signals
def signal_mapping(params):
samples = Array(params)

# map samples into [-1, 1]
bounded_samples = np.arctan(samples) / (np.pi / 2)
bounded_samples = jnp.arctan(params) / (np.pi / 2)

# pad with 0 at beginning
padded_samples = np.append(Array([0], dtype=complex), bounded_samples)
padded_samples = jnp.append(jnp.array([0], dtype=complex), bounded_samples)

# apply filter
output_signal = convolution(DiscreteSignal(dt=1., samples=padded_samples))
Expand Down
6 changes: 3 additions & 3 deletions docs/userguide/how_to_configure_simulations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ further highlight the benefits of the sparse representation.

static_hamiltonian = 2 * np.pi * v * N + np.pi * anharm * N * (N - np.eye(dim))
drive_hamiltonian = 2 * np.pi * r * (a + adag)
drive_signal = Signal(Array(1.), carrier_freq=v)
drive_signal = Signal(1., carrier_freq=v)

y0 = np.zeros(dim, dtype=complex)
y0[1] = 1.
Expand All @@ -266,7 +266,7 @@ amplitude, and just-in-time compile it using JAX.
)

def dense_func(amp):
drive_signal = Signal(Array(amp), carrier_freq=v)
drive_signal = Signal(amp, carrier_freq=v)
res = solver.solve(
t_span=[0., T],
y0=y0,
Expand All @@ -292,7 +292,7 @@ diagonal, but we explicitly highlight the need for this.
evaluation_mode='sparse')

def sparse_func(amp):
drive_signal = Signal(Array(amp), carrier_freq=v)
drive_signal = Signal(amp, carrier_freq=v)
res = sparse_solver.solve(
t_span=[0., T],
y0=y0,
Expand Down
1 change: 0 additions & 1 deletion docs/userguide/how_to_use_jax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ before setting the signals, to ensure the simulation function remains pure.
def sim_function(amp):

# define a constant signal
amp = Array(amp)
signals = [Signal(amp, carrier_freq=w)]

# simulate and return results
Expand Down
2 changes: 1 addition & 1 deletion docs/userguide/how_to_use_pulse_schedule_for_jax_jit.rst
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,6 @@ JAX-compiled (or more generally, JAX-transformed).
# convert from a pulse schedule to a list of signals
converter = InstructionToSignals(dt, carriers={"d0": w})

return converter.get_signals(schedule)[0].samples.data
return converter.get_signals(schedule)[0].samples

jax.jit(jit_func)(0.4)
4 changes: 2 additions & 2 deletions docs/userguide/perturbative_solvers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ these functions gives a sense of the speeds attainable by these solvers.
"""For a given envelope amplitude, simulate the final unitary using the
Dyson solver.
"""
drive_signal = Signal(lambda t: Array(amp) * envelope_func(t), carrier_freq=v)
drive_signal = Signal(lambda t: amp * envelope_func(t), carrier_freq=v)
return dyson_solver.solve(
signals=[drive_signal],
y0=np.eye(dim, dtype=complex),
Expand Down Expand Up @@ -220,7 +220,7 @@ accuracy and simulation speed.

# specify tolerance as an argument to run the simulation at different tolerances
def ode_sim(amp, tol):
drive_signal = Signal(lambda t: Array(amp) * envelope_func(t), carrier_freq=v)
drive_signal = Signal(lambda t: amp * envelope_func(t), carrier_freq=v)
res = solver.solve(
t_span=[0., int(T // dt) * dt],
y0=np.eye(dim, dtype=complex),
Expand Down
49 changes: 49 additions & 0 deletions qiskit_dynamics/arraylias/alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

from arraylias import numpy_alias, scipy_alias

from qiskit import QiskitError

from qiskit_dynamics.array import Array

# global NumPy and SciPy aliases
Expand All @@ -35,3 +37,50 @@


ArrayLike = Union[Union[DYNAMICS_NUMPY_ALIAS.registered_types()], list]


def _preferred_lib(*args, **kwargs):
"""Given a list of args and kwargs with potentially mixed array types, determine the appropriate
library to dispatch to.

For each argument, DYNAMICS_NUMPY_ALIAS.infer_libs is called to infer the library. If all are
"numpy", then it returns "numpy", and if any are "jax", it returns "jax".

Args:
*args: Positional arguments.
**kwargs: Keyword arguments.
Returns:
str
Raises:
QiskitError if none of the rules apply.
"""
args = list(args) + list(kwargs.values())
if len(args) == 1:
return DYNAMICS_NUMPY_ALIAS.infer_libs(args[0])

lib0 = DYNAMICS_NUMPY_ALIAS.infer_libs(args[0])[0]
lib1 = _preferred_lib(args[1:])[0]

if lib0 == "numpy" and lib1 == "numpy":
return "numpy"
elif lib0 == "jax" or lib1 == "jax":
return "jax"

raise QiskitError("_preferred_lib could not resolve preferred library.")


def _numpy_multi_dispatch(*args, path, **kwargs):
"""Multiple dispatching for NumPy.

Given *args and **kwargs, dispatch the function specified by path, to the array library
specified by _preferred_lib.

Args:
*args: Positional arguments to pass to function specified by path.
path: Path in numpy module structure.
**kwargs: Keyword arguments to pass to function specified by path.
Returns:
Result of evaluating the function at path on the arguments using the preferred library.
"""
lib = _preferred_lib(*args, **kwargs)
return DYNAMICS_NUMPY_ALIAS(like=lib, path=path)(*args, **kwargs)
4 changes: 3 additions & 1 deletion qiskit_dynamics/models/operator_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

from qiskit import QiskitError
from qiskit.quantum_info.operators.operator import Operator
from qiskit_dynamics.arraylias import DYNAMICS_NUMPY_ALIAS as numpy_alias
from qiskit_dynamics.arraylias.alias import _numpy_multi_dispatch
from qiskit_dynamics.array import Array, wrap
from qiskit_dynamics.type_utils import to_array, to_csr, to_BCOO, vec_commutator, vec_dissipator

Expand Down Expand Up @@ -1357,7 +1359,7 @@ def concatenate_signals(
) -> Array:
"""Concatenate hamiltonian and linblad signals."""
if self._hamiltonian_operators is not None and self._dissipator_operators is not None:
return np.append(ham_sig_vals, dis_sig_vals, axis=-1)
return _numpy_multi_dispatch(ham_sig_vals, dis_sig_vals, path="append", axis=-1)
if self._hamiltonian_operators is not None and self._dissipator_operators is None:
return ham_sig_vals
if self._hamiltonian_operators is None and self._dissipator_operators is not None:
Expand Down
17 changes: 8 additions & 9 deletions qiskit_dynamics/pulse/pulse_to_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@
from qiskit.pulse.exceptions import PulseError
from qiskit.pulse.library import SymbolicPulse
from qiskit import QiskitError

from qiskit_dynamics.array import Array
from qiskit_dynamics.arraylias import ArrayLike
from qiskit_dynamics.arraylias import DYNAMICS_NUMPY as unp
from qiskit_dynamics import ArrayLike
from qiskit_dynamics import DYNAMICS_NUMPY as unp

from qiskit_dynamics.signals import DiscreteSignal

Expand Down Expand Up @@ -187,13 +188,11 @@ def get_signals(self, schedule: Schedule) -> List[DiscreteSignal]:
inst_samples = get_samples(inst.pulse)

# build sample array to append to signal
times = self._dt * (start_sample + unp.arange(len(inst_samples)))
times = self._dt * (start_sample + np.arange(len(inst_samples)))
samples = inst_samples * unp.exp(
unp.asarray(
2.0j * np.pi * frequency_shifts[chan] * times
+ 1.0j * phases[chan]
+ 2.0j * np.pi * phase_accumulations[chan]
)
2.0j * np.pi * frequency_shifts[chan] * times
+ 1.0j * phases[chan]
+ 2.0j * np.pi * phase_accumulations[chan]
)
signals[chan].add_samples(start_sample, samples)

Expand All @@ -204,7 +203,7 @@ def get_signals(self, schedule: Schedule) -> List[DiscreteSignal]:
phases[chan] = inst.phase

if isinstance(inst, ShiftFrequency):
frequency_shifts[chan] = frequency_shifts[chan] + unp.asarray(inst.frequency)
frequency_shifts[chan] = frequency_shifts[chan] + inst.frequency
phase_accumulations[chan] = (
phase_accumulations[chan] - inst.frequency * start_sample * self._dt
)
Expand Down
5 changes: 3 additions & 2 deletions qiskit_dynamics/signals/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from qiskit_dynamics.arraylias import ArrayLike
from qiskit_dynamics.arraylias import DYNAMICS_NUMPY_ALIAS as numpy_alias
from qiskit_dynamics.arraylias import DYNAMICS_NUMPY as unp
from qiskit_dynamics.arraylias.alias import _numpy_multi_dispatch, _preferred_lib


class Signal:
Expand Down Expand Up @@ -307,7 +308,7 @@ def envelope(t):
-1,
len(self.samples),
)
return numpy_alias(like=idx).asarray(self._padded_samples)[idx]
return numpy_alias(like=_preferred_lib(self._padded_samples, idx)).asarray(self._padded_samples)[idx]

Signal.__init__(self, envelope=envelope, carrier_freq=carrier_freq, phase=phase, name=name)

Expand Down Expand Up @@ -434,7 +435,7 @@ def add_samples(self, start_sample: int, samples: List):
new_samples, unp.repeat(zero_pad, start_sample - len(self.samples))
)

new_samples = unp.append(new_samples, samples)
new_samples = _numpy_multi_dispatch(new_samples, samples, path="append")
self._padded_samples = unp.append(new_samples, zero_pad, axis=0)

def __str__(self) -> str:
Expand Down
5 changes: 3 additions & 2 deletions qiskit_dynamics/signals/transfer_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from qiskit import QiskitError
from qiskit_dynamics.arraylias import DYNAMICS_NUMPY as unp
from qiskit_dynamics.arraylias.alias import _numpy_multi_dispatch

from .signals import Signal, DiscreteSignal

Expand Down Expand Up @@ -118,10 +119,10 @@ def _apply(self, signal: Signal) -> Signal:
# Perform a discrete time convolution.
dt = signal.dt
func_samples = unp.asarray([self._func(dt * i) for i in range(signal.duration)])
func_samples = func_samples / sum(func_samples)
func_samples = func_samples / unp.sum(func_samples)
sig_samples = signal(dt * unp.arange(signal.duration))

convoluted_samples = list(unp.convolve(func_samples, sig_samples))
convoluted_samples = _numpy_multi_dispatch(func_samples, sig_samples, path="convolve")#unp.convolve(func_samples, sig_samples)

return DiscreteSignal(dt, convoluted_samples, carrier_freq=0.0, phase=0.0)
else:
Expand Down
6 changes: 2 additions & 4 deletions test/dynamics/solvers/test_dyson_magnus_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from qiskit_dynamics import Signal, Solver, DysonSolver, MagnusSolver
from qiskit_dynamics.array import Array
from qiskit_dynamics.arraylias import DYNAMICS_NUMPY as unp
from qiskit_dynamics import DYNAMICS_NUMPY as unp

from qiskit_dynamics.solvers.perturbative_solvers.expansion_model import (
_construct_DCT,
Expand Down Expand Up @@ -95,9 +95,7 @@ def gaussian(amp, sig, t0, t):
T = 7 * sig # end of signal

# Function to define gaussian envelope, using gaussian wave function
gaussian_envelope = lambda t: gaussian(
unp.asarray(amp), unp.asarray(sig), unp.asarray(t0), unp.asarray(t)
)
gaussian_envelope = lambda t: gaussian(amp, sig, t0, t)

obj.gauss_signal = Signal(gaussian_envelope, carrier_freq=5.0)

Expand Down