From c84666010daa2a68322a26e133e583a921b30dbf Mon Sep 17 00:00:00 2001 From: "Daniel J. Egger" <38065505+eggerdj@users.noreply.github.com> Date: Tue, 8 Feb 2022 21:45:13 +0100 Subject: [PATCH] Added channel filtering to pulse converter (#59) --- docs/tutorials/qiskit_pulse.rst | 2 +- qiskit_dynamics/pulse/pulse_to_signals.py | 110 ++++++++++++++---- requirements-dev.txt | 1 + .../dynamics/signals/test_pulse_to_signals.py | 80 ++++++++++++- 4 files changed, 165 insertions(+), 28 deletions(-) diff --git a/docs/tutorials/qiskit_pulse.rst b/docs/tutorials/qiskit_pulse.rst index fe91b2547..c22d48194 100644 --- a/docs/tutorials/qiskit_pulse.rst +++ b/docs/tutorials/qiskit_pulse.rst @@ -85,7 +85,7 @@ virtual ``Z`` gate is applied. plt.rcParams["font.size"] = 16 - converter = InstructionToSignals(dt, carriers=[w]) + converter = InstructionToSignals(dt, carriers={"d0": w}) signals = converter.get_signals(xp) fig, axs = plt.subplots(1, 2, figsize=(14, 4.5)) diff --git a/qiskit_dynamics/pulse/pulse_to_signals.py b/qiskit_dynamics/pulse/pulse_to_signals.py index c2c7abc66..bc6040bda 100644 --- a/qiskit_dynamics/pulse/pulse_to_signals.py +++ b/qiskit_dynamics/pulse/pulse_to_signals.py @@ -14,7 +14,7 @@ Pulse schedule to Signals converter. """ -from typing import List +from typing import Dict, List, Optional import numpy as np from qiskit.pulse import ( @@ -25,8 +25,13 @@ ShiftFrequency, SetFrequency, Waveform, + MeasureChannel, + DriveChannel, + ControlChannel, + AcquireChannel, ) from qiskit import QiskitError + from qiskit_dynamics.signals import DiscreteSignal @@ -35,51 +40,70 @@ class InstructionToSignals: The :class:`InstructionsToSignals` class converts a pulse schedule to a list of signals that can be given to a model. This conversion is done by calling - the :meth:`get_signals` method on a schedule. + the :meth:`get_signals` method on a schedule. The converter applies to instances + of :class:`Schedule`. Instances of :class:`ScheduleBlock` must first be + converted to :class:`Schedule` using the :meth:`block_to_schedule` in + Qiskit pulse. + + The converter can be initialized + with the optional arguments ``carriers`` and ``channels``. These arguments + change the returned signals of :meth:`get_signals`. When ``channels`` is given + then only the signals specified by name in ``channels`` are returned. The + ``carriers`` dictionary allows the user to specify the carrier frequency of + the channels. Here, the keys are the channel name, e.g. ``d12`` for drive channel + number 12, and the values are the corresponding frequency. If a channel is not + present in ``carriers`` it is assumed that the carrier frequency is zero. """ - def __init__(self, dt: float, carriers: List[float] = None): + def __init__( + self, + dt: float, + carriers: Optional[Dict[str, float]] = None, + channels: Optional[List[str]] = None, + ): """Initialize pulse schedule to signals converter. Args: - dt: length of the samples. This is required by the converter as pulse + dt: Length of the samples. This is required by the converter as pulse schedule are specified in units of dt and typically do not carry the value of dt with them. - carriers: a list of carrier frequencies. If it is not None there - must be at least as many carrier frequencies as there are - channels in the schedules that will be converted. + carriers: A dict of carrier frequencies. The keys are the names of the channels + and the values are the corresponding carrier frequency. + channels: A list of channels that the :meth:`get_signals` method should return. + This argument will cause :meth:`get_signals` to return the signals in the + same order as the channels. Channels present in the schedule but absent + from channels will not be included in the returned object. If None is given + (the default) then all channels present in the pulse schedule are returned. """ self._dt = dt - self._carriers = carriers + self._channels = channels + self._carriers = carriers or {} def get_signals(self, schedule: Schedule) -> List[DiscreteSignal]: """ Args: - schedule: The schedule to represent in terms of signals. + schedule: The schedule to represent in terms of signals. Instances of + :class:`ScheduleBlock` must first be converted to :class:`Schedule` + using the :meth:`block_to_schedule` in Qiskit pulse. Returns: a list of piecewise constant signals. - - Raises: - qiskit.QiskitError: if not enough frequencies supplied """ - if self._carriers and len(self._carriers) < len(schedule.channels): - raise QiskitError("Not enough carrier frequencies supplied.") - signals, phases, frequency_shifts = {}, {}, {} - for idx, chan in enumerate(schedule.channels): - if self._carriers: - carrier_freq = self._carriers[idx] - else: - carrier_freq = 0.0 + if self._channels is not None: + schedule = schedule.filter(channels=[self._get_channel(ch) for ch in self._channels]) + for idx, chan in enumerate(schedule.channels): phases[chan.name] = 0.0 frequency_shifts[chan.name] = 0.0 signals[chan.name] = DiscreteSignal( - samples=[], dt=self._dt, name=chan.name, carrier_freq=carrier_freq + samples=[], + dt=self._dt, + name=chan.name, + carrier_freq=self._carriers.get(chan.name, 0.0), ) for start_sample, inst in schedule.instructions: @@ -129,7 +153,19 @@ def get_signals(self, schedule: Schedule) -> List[DiscreteSignal]: samples=np.zeros(max_duration - sig.duration, dtype=complex), ) - return list(signals.values()) + # filter the channels + if self._channels is None: + return list(signals.values()) + + return_signals = [] + for chan_name in self._channels: + signal = signals.get( + chan_name, DiscreteSignal(samples=[], dt=self._dt, name=chan_name, carrier_freq=0.0) + ) + + return_signals.append(signal) + + return return_signals @staticmethod def get_awg_signals( @@ -150,7 +186,7 @@ def get_awg_signals( Args: signals: A list of signals for which to create I and Q. if_modulation: The intermediate frequency with which the AWG modulates the pulse - envelopes. + envelopes. Returns: iq signals: A list of signals which is twice as long as the input list of signals. @@ -184,3 +220,31 @@ def get_awg_signals( new_signals += [sig_i, sig_q] return new_signals + + def _get_channel(self, channel_name: str): + """Return the channel corresponding to the given name.""" + + try: + prefix = channel_name[0] + index = int(channel_name[1:]) + + if prefix == "d": + return DriveChannel(index) + + if prefix == "m": + return MeasureChannel(index) + + if prefix == "u": + return ControlChannel(index) + + if prefix == "a": + return AcquireChannel(index) + + raise QiskitError( + f"Unsupported channel name {channel_name} in {self.__class__.__name__}" + ) + + except (KeyError, IndexError, ValueError) as error: + raise QiskitError( + f"Invalid channel name {channel_name} given to {self.__class__.__name__}." + ) from error diff --git a/requirements-dev.txt b/requirements-dev.txt index e660ea254..e7f79f6cf 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -9,4 +9,5 @@ pygments>=2.4 reno>=3.4.0 nbsphinx qutip +ddt~=1.4.2 matplotlib>=3.3.0 diff --git a/test/dynamics/signals/test_pulse_to_signals.py b/test/dynamics/signals/test_pulse_to_signals.py index 8cc285e29..6661b9d02 100644 --- a/test/dynamics/signals/test_pulse_to_signals.py +++ b/test/dynamics/signals/test_pulse_to_signals.py @@ -13,11 +13,15 @@ Tests to convert from pulse schedules to signals. """ +from ddt import ddt, data, unpack import numpy as np +import qiskit.pulse as pulse from qiskit.pulse import ( Schedule, DriveChannel, + ControlChannel, + MeasureChannel, Play, Drag, ShiftFrequency, @@ -28,6 +32,9 @@ Constant, Waveform, ) +from qiskit.pulse.transforms.canonicalization import block_to_schedule +from qiskit import QiskitError + from qiskit_dynamics.pulse import InstructionToSignals from qiskit_dynamics.signals import DiscreteSignal @@ -82,7 +89,7 @@ def test_carriers_and_dt(self): sched = Schedule(name="Schedule") sched += Play(Gaussian(duration=20, amp=0.5, sigma=4), DriveChannel(0)) - converter = InstructionToSignals(dt=0.222, carriers=[5.5e9]) + converter = InstructionToSignals(dt=0.222, carriers={"d0": 5.5e9}) signals = converter.get_signals(sched) self.assertEqual(signals[0].carrier_freq, 5.5e9) @@ -96,7 +103,7 @@ def test_shift_frequency(self): sched += ShiftFrequency(1.0, DriveChannel(0)) sched += Play(Constant(duration=10, amp=1.0), DriveChannel(0)) - converter = InstructionToSignals(dt=0.222, carriers=[5.0]) + converter = InstructionToSignals(dt=0.222, carriers={"d0": 5.0}) signals = converter.get_signals(sched) for idx in range(10): @@ -109,7 +116,7 @@ def test_set_frequency(self): sched += SetFrequency(4.0, DriveChannel(0)) sched += Play(Constant(duration=10, amp=1.0), DriveChannel(0)) - converter = InstructionToSignals(dt=0.222, carriers=[5.0]) + converter = InstructionToSignals(dt=0.222, carriers={"d0": 5.0}) signals = converter.get_signals(sched) for idx in range(10): @@ -122,7 +129,7 @@ def test_uneven_pulse_length(self): schedule |= Play(Waveform(np.ones(10)), DriveChannel(0)) schedule += Play(Constant(20, 1), DriveChannel(1)) - converter = InstructionToSignals(dt=0.1, carriers=[2.0, 3.0]) + converter = InstructionToSignals(dt=0.1, carriers={"d0": 2.0, "d1": 3.0}) signals = converter.get_signals(schedule) @@ -134,3 +141,68 @@ def test_uneven_pulse_length(self): self.assertTrue(signals[0].carrier_freq == 2.0) self.assertTrue(signals[1].carrier_freq == 3.0) + + +@ddt +class TestPulseToSignalsFiltering(QiskitDynamicsTestCase): + """Test the extraction of signals when specifying channels.""" + + def setUp(self): + """Setup the tests.""" + + super().setUp() + + # Drags on all qubits, then two CRs, then readout all qubits. + with pulse.build(name="test schedule") as schedule: + with pulse.align_sequential(): + with pulse.align_left(): + for chan_idx in [0, 1, 2, 3]: + pulse.play(Drag(160, 0.5, 40, 0.1), DriveChannel(chan_idx)) + + with pulse.align_sequential(): + for chan_idx in [0, 1]: + pulse.play(GaussianSquare(660, 0.2, 40, 500), ControlChannel(chan_idx)) + + with pulse.align_left(): + for chan_idx in [0, 1, 2, 3]: + pulse.play(GaussianSquare(660, 0.2, 40, 500), MeasureChannel(chan_idx)) + + self._schedule = block_to_schedule(schedule) + + @unpack + @data( + ({"d0": 5.0, "d2": 5.1, "u0": 5.0, "u1": 5.1}, ["d0", "d2", "u0", "u1"]), + ({"m0": 5.0, "m1": 5.1, "m2": 5.0, "m3": 5.1}, ["m0", "m1", "m2", "m3"]), + ({"m0": 5.0, "m1": 5.1, "d0": 5.0, "d1": 5.1}, ["m0", "m1", "d0", "d1"]), + ({"d1": 5.0}, ["d1"]), + ({"d123": 5.0}, ["d123"]), + ) + def test_channel_combinations(self, carriers, channels): + """Test that we can filter out channels in the right order and number.""" + + converter = InstructionToSignals(dt=0.222, carriers=carriers, channels=channels) + + signals = converter.get_signals(self._schedule) + + self.assertEqual(len(signals), len(channels)) + for idx, chan_name in enumerate(channels): + self.assertEqual(signals[idx].name, chan_name) + + def test_empty_signal(self): + """Test that requesting a channel that is not in the schedule gives and empty signal.""" + + converter = InstructionToSignals(dt=0.222, carriers={"d123": 1.0}, channels=["d123"]) + + signals = converter.get_signals(self._schedule) + + self.assertEqual(len(signals), 1) + self.assertEqual(signals[0].duration, 0) + + @data("123", "s123", "", "d") + def test_get_channel_raise(self, channel_name): + """Test that getting channel instances works well.""" + + converter = InstructionToSignals(dt=0.222) + + with self.assertRaises(QiskitError): + converter._get_channel(channel_name)