From 66fec48bc641507870768b252e2b9a26046c05af Mon Sep 17 00:00:00 2001 From: "Tim (Yi-Ting)" Date: Mon, 29 Jul 2024 13:12:32 -0400 Subject: [PATCH] feat: support erf_square and swap_phases (#1019) Co-authored-by: Ryan Shaffer <3620100+rmshaffer@users.noreply.github.com> --- src/braket/pulse/__init__.py | 1 + src/braket/pulse/ast/approximation_parser.py | 28 +++ src/braket/pulse/pulse_sequence.py | 20 +++ src/braket/pulse/waveforms.py | 162 ++++++++++++++++++ .../pulse/ast/test_approximation_parser.py | 122 ++++++++++++- .../unit_tests/braket/pulse/test_waveforms.py | 106 +++++++++++- 6 files changed, 437 insertions(+), 2 deletions(-) diff --git a/src/braket/pulse/__init__.py b/src/braket/pulse/__init__.py index 01ef66892..48f7022af 100644 --- a/src/braket/pulse/__init__.py +++ b/src/braket/pulse/__init__.py @@ -18,5 +18,6 @@ ArbitraryWaveform, ConstantWaveform, DragGaussianWaveform, + ErfSquareWaveform, GaussianWaveform, ) diff --git a/src/braket/pulse/ast/approximation_parser.py b/src/braket/pulse/ast/approximation_parser.py index d2dcf65e8..16c9c6726 100644 --- a/src/braket/pulse/ast/approximation_parser.py +++ b/src/braket/pulse/ast/approximation_parser.py @@ -26,6 +26,7 @@ from braket.pulse.waveforms import ( ConstantWaveform, DragGaussianWaveform, + ErfSquareWaveform, GaussianWaveform, Waveform, ) @@ -468,6 +469,20 @@ def set_scale(self, node: ast.FunctionCall, context: _ParseState) -> None: value = self.visit(node.arguments[1], context) context.frame_data[frame].scale = value + def swap_phases(self, node: ast.FunctionCall, context: _ParseState) -> None: + """A 'swap_phases' Function call. + + Args: + node (ast.FunctionCall): The function call node. + context (_ParseState): The parse state. + """ + frame1 = self.visit(node.arguments[0], context) + frame2 = self.visit(node.arguments[1], context) + phase1 = context.frame_data[frame1].phase + phase2 = context.frame_data[frame2].phase + context.frame_data[frame1].phase = phase2 + context.frame_data[frame2].phase = phase1 + def capture_v0(self, node: ast.FunctionCall, context: _ParseState) -> None: """A 'capture_v0' Function call. @@ -549,6 +564,19 @@ def drag_gaussian(self, node: ast.FunctionCall, context: _ParseState) -> Wavefor args = [self.visit(arg, context) for arg in node.arguments] return DragGaussianWaveform(*args) + def erf_square(self, node: ast.FunctionCall, context: _ParseState) -> Waveform: + """A 'erf_square' Waveform Function call. + + Args: + node (ast.FunctionCall): The function call node. + context (_ParseState): The parse state. + + Returns: + Waveform: The waveform object representing the function call. + """ + args = [self.visit(arg, context) for arg in node.arguments] + return ErfSquareWaveform(*args) + def _init_frame_data(frames: dict[str, Frame]) -> dict[str, _FrameState]: frame_states = { diff --git a/src/braket/pulse/pulse_sequence.py b/src/braket/pulse/pulse_sequence.py index 9d43127a0..6c5d60444 100644 --- a/src/braket/pulse/pulse_sequence.py +++ b/src/braket/pulse/pulse_sequence.py @@ -144,6 +144,26 @@ def shift_phase( self._frames[frame.id] = frame return self + def swap_phases( + self, + frame_1: Frame, + frame_2: Frame, + ) -> PulseSequence: + """Adds an instruction to swap the phases between two frames. + + Args: + frame_1 (Frame): First frame for which to swap the phase. + frame_2 (Frame): Second frame for which to swap the phase. + + Returns: + PulseSequence: self, with the instruction added. + """ + _validate_uniqueness(self._frames, [frame_1, frame_2]) + self._program.function_call("swap_phases", [frame_1, frame_2]) + self._frames[frame_1.id] = frame_1 + self._frames[frame_2.id] = frame_2 + return self + def set_scale( self, frame: Frame, scale: Union[float, FreeParameterExpression] ) -> PulseSequence: diff --git a/src/braket/pulse/waveforms.py b/src/braket/pulse/waveforms.py index 915d187a8..582592dc8 100644 --- a/src/braket/pulse/waveforms.py +++ b/src/braket/pulse/waveforms.py @@ -19,6 +19,7 @@ from typing import Optional, Union import numpy as np +import scipy as sp from oqpy import WaveformVar, bool_, complex128, declare_waveform_generator, duration, float64 from oqpy.base import OQPyExpression @@ -501,6 +502,166 @@ def _from_calibration_schema(waveform_json: dict) -> GaussianWaveform: return GaussianWaveform(**waveform_parameters) +class ErfSquareWaveform(Waveform, Parameterizable): + """A square waveform with smoothed edges.""" + + def __init__( + self, + length: Union[float, FreeParameterExpression], + width: Union[float, FreeParameterExpression], + sigma: Union[float, FreeParameterExpression], + amplitude: Union[float, FreeParameterExpression] = 1, + zero_at_edges: bool = False, + id: Optional[str] = None, + ): + r"""Initializes a `ErfSquareWaveform`. + + .. math:: (\text{step}((t-t_1)/sigma) + \text{step}(-(t-t_2)/sigma) - 1) + + where :math:`\text{step}(t)` is the rounded step function defined as + :math:`(erf(t)+1)/2` and :math:`t_1` and :math:`t_2` are the timestamps at the half + height. The waveform is scaled such that its maximum is equal to `amplitude`. + + Args: + length (Union[float, FreeParameterExpression]): Duration (in seconds) from the start + to the end of the waveform. + width (Union[float, FreeParameterExpression]): Duration (in seconds) between the + half height of the two edges. + sigma (Union[float, FreeParameterExpression]): A characteristic time of how quickly + the edges rise and fall. + amplitude (Union[float, FreeParameterExpression]): The amplitude of the waveform + envelope. Defaults to 1. + zero_at_edges (bool): Whether the waveform is scaled such that it has zero value at the + edges. Defaults to False. + id (Optional[str]): The identifier used for declaring this waveform. A random string of + ascii characters is assigned by default. + """ + self.length = length + self.width = width + self.sigma = sigma + self.amplitude = amplitude + self.zero_at_edges = zero_at_edges + self.id = id or _make_identifier_name() + + def __repr__(self) -> str: + return ( + f"ErfSquareWaveform('id': {self.id}, 'length': {self.length}, " + f"'width': {self.width}, 'sigma': {self.sigma}, 'amplitude': {self.amplitude}, " + f"'zero_at_edges': {self.zero_at_edges})" + ) + + @property + def parameters(self) -> list[Union[FreeParameterExpression, FreeParameter, float]]: + """Returns the parameters associated with the object, either unbound free parameter + expressions or bound values. + """ + return [self.length, self.width, self.sigma, self.amplitude] + + def bind_values(self, **kwargs: Union[FreeParameter, str]) -> ErfSquareWaveform: + """Takes in parameters and returns an object with specified parameters + replaced with their values. + + Args: + **kwargs (Union[FreeParameter, str]): Arbitrary keyword arguments. + + Returns: + ErfSquareWaveform: A copy of this waveform with the requested parameters bound. + """ + constructor_kwargs = { + "length": subs_if_free_parameter(self.length, **kwargs), + "width": subs_if_free_parameter(self.width, **kwargs), + "sigma": subs_if_free_parameter(self.sigma, **kwargs), + "amplitude": subs_if_free_parameter(self.amplitude, **kwargs), + "zero_at_edges": self.zero_at_edges, + "id": self.id, + } + return ErfSquareWaveform(**constructor_kwargs) + + def __eq__(self, other: ErfSquareWaveform): + return isinstance(other, ErfSquareWaveform) and ( + self.length, + self.width, + self.sigma, + self.amplitude, + self.zero_at_edges, + self.id, + ) == ( + other.length, + other.width, + other.sigma, + other.amplitude, + other.zero_at_edges, + other.id, + ) + + def _to_oqpy_expression(self) -> OQPyExpression: + """Returns an OQPyExpression defining this waveform. + + Returns: + OQPyExpression: The OQPyExpression. + """ + erf_square_generator = declare_waveform_generator( + "erf_square", + [ + ("length", duration), + ("width", duration), + ("sigma", duration), + ("amplitude", float64), + ("zero_at_edges", bool_), + ], + ) + return WaveformVar( + init_expression=erf_square_generator( + self.length, + self.width, + self.sigma, + self.amplitude, + self.zero_at_edges, + ), + name=self.id, + ) + + def sample(self, dt: float) -> np.ndarray: + """Generates a sample of amplitudes for this Waveform based on the given time resolution. + + Args: + dt (float): The time resolution. + + Returns: + np.ndarray: The sample amplitudes for this waveform. + """ + sample_range = np.arange(0, self.length, dt) + t1 = (self.length - self.width) / 2 + t2 = (self.length + self.width) / 2 + samples = ( + sp.special.erf((sample_range - t1) / self.sigma) + + sp.special.erf(-(sample_range - t2) / self.sigma) + ) / 2 + + mid_waveform_height = sp.special.erf((self.width / 2) / self.sigma) + waveform_bottom = (sp.special.erf(-t1 / self.sigma) + sp.special.erf(t2 / self.sigma)) / 2 + + if self.zero_at_edges: + return ( + (samples - waveform_bottom) + / (mid_waveform_height - waveform_bottom) + * self.amplitude + ) + else: + return samples * self.amplitude / mid_waveform_height + + @staticmethod + def _from_calibration_schema(waveform_json: dict) -> ErfSquareWaveform: + waveform_parameters = {"id": waveform_json["waveformId"]} + for val in waveform_json["arguments"]: + waveform_parameters[val["name"]] = ( + float(val["value"]) + if val["type"] == "float" + else FreeParameterExpression(val["value"]) + ) + return ErfSquareWaveform(**waveform_parameters) + + def _make_identifier_name() -> str: return "".join([random.choice(string.ascii_letters) for _ in range(10)]) # noqa S311 @@ -511,6 +672,7 @@ def _parse_waveform_from_calibration_schema(waveform: dict) -> Waveform: "drag_gaussian": DragGaussianWaveform._from_calibration_schema, "gaussian": GaussianWaveform._from_calibration_schema, "constant": ConstantWaveform._from_calibration_schema, + "erf_square": ErfSquareWaveform._from_calibration_schema, } if "amplitudes" in waveform: waveform["name"] = "arbitrary" diff --git a/test/unit_tests/braket/pulse/ast/test_approximation_parser.py b/test/unit_tests/braket/pulse/ast/test_approximation_parser.py index 56f02aa12..479a3b53e 100644 --- a/test/unit_tests/braket/pulse/ast/test_approximation_parser.py +++ b/test/unit_tests/braket/pulse/ast/test_approximation_parser.py @@ -19,7 +19,13 @@ from openpulse import ast from oqpy import IntVar -from braket.pulse import ArbitraryWaveform, ConstantWaveform, DragGaussianWaveform, GaussianWaveform +from braket.pulse import ( + ArbitraryWaveform, + ConstantWaveform, + DragGaussianWaveform, + ErfSquareWaveform, + GaussianWaveform, +) from braket.pulse.ast.approximation_parser import _ApproximationParser from braket.pulse.frame import Frame from braket.pulse.port import Port @@ -352,6 +358,46 @@ def test_set_shift_phase_beyond_2_pi(port): verify_results(parser, expected_amplitudes, expected_frequencies, expected_phases) +def test_swap_phases(port): + phase1 = 0.12 + phase2 = 0.34 + frequency = 1e8 + frame1 = Frame( + frame_id="frame1", port=port, frequency=frequency, phase=phase1, is_predefined=False + ) + frame2 = Frame( + frame_id="frame2", port=port, frequency=frequency, phase=phase2, is_predefined=False + ) + pulse_seq = PulseSequence().delay([], 10e-9).swap_phases(frame1, frame2).delay([], 10e-9) + expected_amplitudes = {"frame1": TimeSeries(), "frame2": TimeSeries()} + expected_frequencies = {"frame1": TimeSeries(), "frame2": TimeSeries()} + expected_phases = {"frame1": TimeSeries(), "frame2": TimeSeries()} + + # properties of frame1 before swap + expected_amplitudes["frame1"].put(0, 0).put(9e-9, 0) + expected_frequencies["frame1"].put(0, frequency).put(9e-9, frequency) + expected_phases["frame1"].put(0, phase1).put(9e-9, phase1) + + # properties of frame1 after swap + expected_amplitudes["frame1"].put(10e-9, 0).put(19e-9, 0) + expected_frequencies["frame1"].put(10e-9, frequency).put(19e-9, frequency) + expected_phases["frame1"].put(10e-9, phase2).put(19e-9, phase2) + + # properties of frame2 before swap + expected_amplitudes["frame2"].put(0, 0).put(9e-9, 0) + expected_frequencies["frame2"].put(0, frequency).put(9e-9, frequency) + expected_phases["frame2"].put(0, phase2).put(9e-9, phase2) + + # properties of frame2 after swap + expected_amplitudes["frame2"].put(10e-9, 0).put(19e-9, 0) + expected_frequencies["frame2"].put(10e-9, frequency).put(19e-9, frequency) + expected_phases["frame2"].put(10e-9, phase1).put(19e-9, phase1) + + parser = _ApproximationParser(program=pulse_seq._program, frames=to_dict([frame1, frame2])) + + verify_results(parser, expected_amplitudes, expected_frequencies, expected_phases) + + def test_set_shift_frequency(port): frame = Frame(frame_id="frame1", port=port, frequency=1e8, phase=0, is_predefined=False) pulse_seq = ( @@ -647,6 +693,80 @@ def test_play_drag_gaussian_waveforms(port): verify_results(parser, expected_amplitudes, expected_frequencies, expected_phases) +def test_play_erf_square_waveforms(port): + frame1 = Frame(frame_id="frame1", port=port, frequency=1e8, phase=0, is_predefined=False) + erf_square_wf_ZaE_False = ErfSquareWaveform( + length=1e-8, width=8e-9, sigma=1e-9, amplitude=0.8, zero_at_edges=False + ) + pulse_seq = PulseSequence().play(frame1, erf_square_wf_ZaE_False) + + times = np.arange(0, 1e-8, port.dt) + values = np.array( + [ + complex(0.06291968379016318), + complex(0.4000000061669033), + complex(0.7370803285436436), + complex(0.7981289183125405), + complex(0.7999911761342559), + complex(0.8), + complex(0.7999911761342559), + complex(0.7981289183125405), + complex(0.7370803285436436), + complex(0.4000000061669033), + ], + dtype=np.complex128, + ) + + expected_amplitudes = {"frame1": TimeSeries()} + expected_frequencies = {"frame1": TimeSeries()} + expected_phases = {"frame1": TimeSeries()} + + for t, v in zip(times, values): + expected_amplitudes["frame1"].put(t, v) + expected_frequencies["frame1"].put(t, 1e8) + expected_phases["frame1"].put(t, 0) + + parser = _ApproximationParser(program=pulse_seq._program, frames=to_dict(frame1)) + verify_results(parser, expected_amplitudes, expected_frequencies, expected_phases) + + +def test_play_erf_square_waveforms_zero_at_edges(port): + frame1 = Frame(frame_id="frame1", port=port, frequency=1e8, phase=0, is_predefined=False) + erf_square_wf_ZaE_True = ErfSquareWaveform( + length=1e-8, width=8e-9, sigma=1e-9, amplitude=0.8, zero_at_edges=True + ) + pulse_seq = PulseSequence().play(frame1, erf_square_wf_ZaE_True) + + times = np.arange(0, 1e-8, port.dt) + values = np.array( + [ + complex(4.819981832973268e-17), + complex(0.36585464564844294), + complex(0.731709291296886), + complex(0.7979691964131336), + complex(0.7999904228990518), + complex(0.8), + complex(0.7999904228990518), + complex(0.7979691964131336), + complex(0.731709291296886), + complex(0.36585464564844294), + ], + dtype=np.complex128, + ) + + expected_amplitudes = {"frame1": TimeSeries()} + expected_frequencies = {"frame1": TimeSeries()} + expected_phases = {"frame1": TimeSeries()} + + for t, v in zip(times, values): + expected_amplitudes["frame1"].put(t, v) + expected_frequencies["frame1"].put(t, 1e8) + expected_phases["frame1"].put(t, 0) + + parser = _ApproximationParser(program=pulse_seq._program, frames=to_dict(frame1)) + verify_results(parser, expected_amplitudes, expected_frequencies, expected_phases) + + def test_barrier_same_dt(port): frame1 = Frame(frame_id="frame1", port=port, frequency=1e8, phase=0, is_predefined=False) frame2 = Frame(frame_id="frame2", port=port, frequency=1e8, phase=0, is_predefined=False) diff --git a/test/unit_tests/braket/pulse/test_waveforms.py b/test/unit_tests/braket/pulse/test_waveforms.py index 34f989c0b..636cb27af 100644 --- a/test/unit_tests/braket/pulse/test_waveforms.py +++ b/test/unit_tests/braket/pulse/test_waveforms.py @@ -20,7 +20,13 @@ from oqpy import Program from braket.circuits.free_parameter import FreeParameter -from braket.pulse import ArbitraryWaveform, ConstantWaveform, DragGaussianWaveform, GaussianWaveform +from braket.pulse import ( + ArbitraryWaveform, + ConstantWaveform, + DragGaussianWaveform, + ErfSquareWaveform, + GaussianWaveform, +) from braket.pulse.ast.qasm_parser import ast_to_qasm from braket.pulse.waveforms import _parse_waveform_from_calibration_schema @@ -294,6 +300,85 @@ def test_gaussian_wf_free_params(): _assert_wf_qasm(wf_3, "waveform gauss_wf = gaussian(600.0ms, 300.0ms, 0.1, false);") +def test_erf_square_waveform(): + length = 4e-9 + width = 0.3 + sigma = 0.2 + amplitude = 0.4 + zero_at_edges = False + id = "erf_square_wf" + wf = ErfSquareWaveform(length, width, sigma, amplitude, zero_at_edges, id) + assert wf.id == id + assert wf.zero_at_edges == zero_at_edges + assert wf.amplitude == amplitude + assert wf.width == width + assert wf.sigma == sigma + assert wf.length == length + + +def test_erf_square_waveform_repr(): + length = 4e-9 + width = 0.3 + sigma = 0.2 + amplitude = 0.4 + zero_at_edges = False + id = "erf_square_wf" + wf = ErfSquareWaveform(length, width, sigma, amplitude, zero_at_edges, id) + repr(wf) + + +def test_erf_square_waveform_default_params(): + length = 4e-9 + width = 0.3 + sigma = 0.2 + wf = ErfSquareWaveform(length, width, sigma) + assert re.match(r"[A-Za-z]{10}", wf.id) + assert wf.zero_at_edges is False + assert wf.amplitude == 1 + assert wf.width == width + assert wf.sigma == sigma + assert wf.length == length + + +def test_erf_square_wf_eq(): + wf = ErfSquareWaveform(4e-9, 0.3, 0.2, 0.7, True, "wf_es") + wf_2 = ErfSquareWaveform(wf.length, wf.width, wf.sigma, wf.amplitude, wf.zero_at_edges, wf.id) + assert wf_2 == wf + for att in ["length", "width", "sigma", "amplitude", "zero_at_edges", "id"]: + wfc = deepcopy(wf_2) + setattr(wfc, att, "wrong_value") + assert wf != wfc + + +def test_erf_square_wf_free_params(): + wf = ErfSquareWaveform( + FreeParameter("length_v"), + FreeParameter("width_x"), + FreeParameter("sigma_y"), + FreeParameter("amp_z"), + id="erf_square_wf", + ) + assert wf.parameters == [ + FreeParameter("length_v"), + FreeParameter("width_x"), + FreeParameter("sigma_y"), + FreeParameter("amp_z"), + ] + + wf_2 = wf.bind_values(length_v=0.6, width_x=0.4) + assert wf_2.parameters == [0.6, 0.4, FreeParameter("sigma_y"), FreeParameter("amp_z")] + _assert_wf_qasm( + wf_2, + "waveform erf_square_wf = erf_square(600.0ms, 400.0ms, sigma_y * 1s, amp_z, false);", + ) + + wf_3 = wf.bind_values(length_v=0.6, width_x=0.3, sigma_y=0.1) + assert wf_3.parameters == [0.6, 0.3, 0.1, FreeParameter("amp_z")] + _assert_wf_qasm( + wf_3, "waveform erf_square_wf = erf_square(600.0ms, 300.0ms, 100.0ms, amp_z, false);" + ) + + def _assert_wf_qasm(waveform, expected_qasm): p = Program(None) p.declare(waveform._to_oqpy_expression()) @@ -357,6 +442,25 @@ def _assert_wf_qasm(waveform, expected_qasm): }, ConstantWaveform(id="wf_constant", length=2.1, iq=0.23), ), + ( + { + "waveformId": "wf_erf_square_0", + "name": "erf_square", + "arguments": [ + {"name": "length", "value": 6.000000000000001e-8, "type": "float"}, + {"name": "width", "value": 3.000000000000000e-8, "type": "float"}, + {"name": "sigma", "value": 5.000000000060144e-9, "type": "float"}, + {"name": "amplitude", "value": 0.4549282253548838, "type": "float"}, + ], + }, + ErfSquareWaveform( + id="wf_erf_square_0", + length=6.000000000000001e-8, + width=3.000000000000000e-8, + sigma=5.000000000060144e-9, + amplitude=0.4549282253548838, + ), + ), ], ) def test_parse_waveform_from_calibration_schema(waveform_json, waveform):