Skip to content

Commit

Permalink
feat: add off_center to erf_square (#1023)
Browse files Browse the repository at this point in the history
  • Loading branch information
yitchen-tim authored Aug 23, 2024
1 parent f87be27 commit a79dccc
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 20 deletions.
20 changes: 15 additions & 5 deletions src/braket/pulse/waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ def __init__(
length: Union[float, FreeParameterExpression],
width: Union[float, FreeParameterExpression],
sigma: Union[float, FreeParameterExpression],
off_center: Union[float, FreeParameterExpression] = 0,
amplitude: Union[float, FreeParameterExpression] = 1,
zero_at_edges: bool = False,
id: Optional[str] = None,
Expand All @@ -529,6 +530,9 @@ def __init__(
half height of the two edges.
sigma (Union[float, FreeParameterExpression]): A characteristic time of how quickly
the edges rise and fall.
off_center (Union[float, FreeParameterExpression]): Shift the smoothed square waveform
earlier or later in time. When positive, the smoothed square is shifted later
(to the right), otherwise earlier (to the left). Defaults to 0.
amplitude (Union[float, FreeParameterExpression]): The amplitude of the waveform
envelope. Defaults to 1.
zero_at_edges (bool): Whether the waveform is scaled such that it has zero value at the
Expand All @@ -539,23 +543,24 @@ def __init__(
self.length = length
self.width = width
self.sigma = sigma
self.off_center = off_center
self.amplitude = amplitude
self.zero_at_edges = zero_at_edges
self.id = id or _make_identifier_name()

def __repr__(self) -> str:
return (
f"ErfSquareWaveform('id': {self.id}, 'length': {self.length}, "
f"'width': {self.width}, 'sigma': {self.sigma}, 'amplitude': {self.amplitude}, "
f"'zero_at_edges': {self.zero_at_edges})"
f"'width': {self.width}, 'sigma': {self.sigma}, 'off_center': {self.off_center}, "
f"'amplitude': {self.amplitude}, 'zero_at_edges': {self.zero_at_edges})"
)

@property
def parameters(self) -> list[Union[FreeParameterExpression, FreeParameter, float]]:
"""Returns the parameters associated with the object, either unbound free parameter
expressions or bound values.
"""
return [self.length, self.width, self.sigma, self.amplitude]
return [self.length, self.width, self.sigma, self.off_center, self.amplitude]

def bind_values(self, **kwargs: Union[FreeParameter, str]) -> ErfSquareWaveform:
"""Takes in parameters and returns an object with specified parameters
Expand All @@ -571,6 +576,7 @@ def bind_values(self, **kwargs: Union[FreeParameter, str]) -> ErfSquareWaveform:
"length": subs_if_free_parameter(self.length, **kwargs),
"width": subs_if_free_parameter(self.width, **kwargs),
"sigma": subs_if_free_parameter(self.sigma, **kwargs),
"off_center": subs_if_free_parameter(self.off_center, **kwargs),
"amplitude": subs_if_free_parameter(self.amplitude, **kwargs),
"zero_at_edges": self.zero_at_edges,
"id": self.id,
Expand All @@ -582,13 +588,15 @@ def __eq__(self, other: ErfSquareWaveform):
self.length,
self.width,
self.sigma,
self.off_center,
self.amplitude,
self.zero_at_edges,
self.id,
) == (
other.length,
other.width,
other.sigma,
other.off_center,
other.amplitude,
other.zero_at_edges,
other.id,
Expand All @@ -606,6 +614,7 @@ def _to_oqpy_expression(self) -> OQPyExpression:
("length", duration),
("width", duration),
("sigma", duration),
("off_center", duration),
("amplitude", float64),
("zero_at_edges", bool_),
],
Expand All @@ -615,6 +624,7 @@ def _to_oqpy_expression(self) -> OQPyExpression:
self.length,
self.width,
self.sigma,
self.off_center,
self.amplitude,
self.zero_at_edges,
),
Expand All @@ -631,8 +641,8 @@ def sample(self, dt: float) -> np.ndarray:
np.ndarray: The sample amplitudes for this waveform.
"""
sample_range = np.arange(0, self.length, dt)
t1 = (self.length - self.width) / 2
t2 = (self.length + self.width) / 2
t1 = (self.length - self.width) / 2 + self.off_center
t2 = (self.length + self.width) / 2 + self.off_center
samples = (
sp.special.erf((sample_range - t1) / self.sigma)
+ sp.special.erf(-(sample_range - t2) / self.sigma)
Expand Down
41 changes: 39 additions & 2 deletions test/unit_tests/braket/pulse/ast/test_approximation_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ def test_play_drag_gaussian_waveforms(port):
def test_play_erf_square_waveforms(port):
frame1 = Frame(frame_id="frame1", port=port, frequency=1e8, phase=0, is_predefined=False)
erf_square_wf_ZaE_False = ErfSquareWaveform(
length=1e-8, width=8e-9, sigma=1e-9, amplitude=0.8, zero_at_edges=False
length=1e-8, width=8e-9, sigma=1e-9, off_center=0.0, amplitude=0.8, zero_at_edges=False
)
pulse_seq = PulseSequence().play(frame1, erf_square_wf_ZaE_False)

Expand Down Expand Up @@ -733,7 +733,7 @@ def test_play_erf_square_waveforms(port):
def test_play_erf_square_waveforms_zero_at_edges(port):
frame1 = Frame(frame_id="frame1", port=port, frequency=1e8, phase=0, is_predefined=False)
erf_square_wf_ZaE_True = ErfSquareWaveform(
length=1e-8, width=8e-9, sigma=1e-9, amplitude=0.8, zero_at_edges=True
length=1e-8, width=8e-9, sigma=1e-9, off_center=0.0, amplitude=0.8, zero_at_edges=True
)
pulse_seq = PulseSequence().play(frame1, erf_square_wf_ZaE_True)

Expand Down Expand Up @@ -767,6 +767,43 @@ def test_play_erf_square_waveforms_zero_at_edges(port):
verify_results(parser, expected_amplitudes, expected_frequencies, expected_phases)


def test_play_erf_square_waveforms_off_center(port):
frame1 = Frame(frame_id="frame1", port=port, frequency=1e8, phase=0, is_predefined=False)
erf_square_wf_ZaE_False = ErfSquareWaveform(
length=1e-8, width=8e-9, sigma=1e-9, off_center=-2e-9, amplitude=0.8, zero_at_edges=False
)
pulse_seq = PulseSequence().play(frame1, erf_square_wf_ZaE_False)

times = np.arange(0, 1e-8, port.dt)
values = np.array(
[
complex(0.7370803285436436),
complex(0.7981289183125405),
complex(0.7999911761342559),
complex(0.8),
complex(0.7999911761342559),
complex(0.7981289183125405),
complex(0.7370803285436436),
complex(0.4000000061669036),
complex(0.0629196837901632),
complex(0.0018710940212660543),
],
dtype=np.complex128,
)

expected_amplitudes = {"frame1": TimeSeries()}
expected_frequencies = {"frame1": TimeSeries()}
expected_phases = {"frame1": TimeSeries()}

for t, v in zip(times, values):
expected_amplitudes["frame1"].put(t, v)
expected_frequencies["frame1"].put(t, 1e8)
expected_phases["frame1"].put(t, 0)

parser = _ApproximationParser(program=pulse_seq._program, frames=to_dict(frame1))
verify_results(parser, expected_amplitudes, expected_frequencies, expected_phases)


def test_barrier_same_dt(port):
frame1 = Frame(frame_id="frame1", port=port, frequency=1e8, phase=0, is_predefined=False)
frame2 = Frame(frame_id="frame2", port=port, frequency=1e8, phase=0, is_predefined=False)
Expand Down
52 changes: 50 additions & 2 deletions test/unit_tests/braket/pulse/test_pulse_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ArbitraryWaveform,
ConstantWaveform,
DragGaussianWaveform,
ErfSquareWaveform,
Frame,
GaussianWaveform,
Port,
Expand Down Expand Up @@ -119,6 +120,16 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined
predefined_frame_2,
ArbitraryWaveform([complex(1, 0.4), 0, 0.3, complex(0.1, 0.2)], id="arb_wf"),
)
.play(
predefined_frame_1,
ErfSquareWaveform(
length=FreeParameter("length_es"),
width=FreeParameter("width_es"),
sigma=2e-9,
off_center=8e-9,
id="erf_square_wf",
),
)
.capture_v0(predefined_frame_2)
)
expected_str_unbound = "\n".join(
Expand All @@ -135,6 +146,8 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined
" sigma_dg * 1s, 0.2, 1, false);",
" waveform constant_wf = constant(length_c * 1s, 2.0 + 0.3im);",
" waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};",
" waveform erf_square_wf = erf_square(length_es * 1s, width_es * 1s, 2.0ns,"
" 8.0ns, 1, false);",
" set_frequency(predefined_frame_1, a + 2.0 * c);",
" shift_frequency(predefined_frame_1, a + 2.0 * c);",
" set_phase(predefined_frame_1, a + 2.0 * c);",
Expand All @@ -149,6 +162,7 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined
" play(predefined_frame_2, drag_gauss_wf);",
" play(predefined_frame_1, constant_wf);",
" play(predefined_frame_2, arb_wf);",
" play(predefined_frame_1, erf_square_wf);",
" psb[1] = capture_v0(predefined_frame_2);",
"}",
]
Expand All @@ -162,11 +176,29 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined
FreeParameter("sigma_g"),
FreeParameter("sigma_dg"),
FreeParameter("length_c"),
FreeParameter("width_es"),
FreeParameter("length_es"),
}
b_bound = pulse_sequence.make_bound_pulse_sequence(
{"c": 2, "length_g": 1e-3, "length_dg": 3e-3, "sigma_dg": 0.4, "length_c": 4e-3}
{
"c": 2,
"length_g": 1e-3,
"length_dg": 3e-3,
"sigma_dg": 0.4,
"length_c": 4e-3,
"length_es": 20e-9,
"width_es": 12e-9,
}
)
b_bound_call = pulse_sequence(
c=2,
length_g=1e-3,
length_dg=3e-3,
sigma_dg=0.4,
length_c=4e-3,
length_es=20e-9,
width_es=12e-9,
)
b_bound_call = pulse_sequence(c=2, length_g=1e-3, length_dg=3e-3, sigma_dg=0.4, length_c=4e-3)
expected_str_b_bound = "\n".join(
[
"OPENQASM 3.0;",
Expand All @@ -177,6 +209,7 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined
" waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);",
" waveform constant_wf = constant(4.0ms, 2.0 + 0.3im);",
" waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};",
" waveform erf_square_wf = erf_square(20.0ns, 12.0ns, 2.0ns, 8.0ns, 1, false);",
" set_frequency(predefined_frame_1, a + 4.0);",
" shift_frequency(predefined_frame_1, a + 4.0);",
" set_phase(predefined_frame_1, a + 4.0);",
Expand All @@ -191,6 +224,7 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined
" play(predefined_frame_2, drag_gauss_wf);",
" play(predefined_frame_1, constant_wf);",
" play(predefined_frame_2, arb_wf);",
" play(predefined_frame_1, erf_square_wf);",
" psb[1] = capture_v0(predefined_frame_2);",
"}",
]
Expand All @@ -209,6 +243,7 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined
" waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);",
" waveform constant_wf = constant(4.0ms, 2.0 + 0.3im);",
" waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};",
" waveform erf_square_wf = erf_square(20.0ns, 12.0ns, 2.0ns, 8.0ns, 1, false);",
" set_frequency(predefined_frame_1, 5.0);",
" shift_frequency(predefined_frame_1, 5.0);",
" set_phase(predefined_frame_1, 5.0);",
Expand All @@ -223,6 +258,7 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined
" play(predefined_frame_2, drag_gauss_wf);",
" play(predefined_frame_1, constant_wf);",
" play(predefined_frame_2, arb_wf);",
" play(predefined_frame_1, erf_square_wf);",
" psb[1] = capture_v0(predefined_frame_2);",
"}",
]
Expand Down Expand Up @@ -302,6 +338,16 @@ def test_pulse_sequence_to_ir(predefined_frame_1, predefined_frame_2):
predefined_frame_2,
ArbitraryWaveform([complex(1, 0.4), 0, 0.3, complex(0.1, 0.2)], id="arb_wf"),
)
.play(
predefined_frame_1,
ErfSquareWaveform(
length=32e-9,
width=20e-9,
sigma=2e-9,
off_center=8e-9,
id="erf_square_wf",
),
)
.capture_v0(predefined_frame_2)
)
expected_str = "\n".join(
Expand All @@ -313,6 +359,7 @@ def test_pulse_sequence_to_ir(predefined_frame_1, predefined_frame_2):
" waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);",
" waveform constant_wf = constant(4.0ms, 2.0 + 0.3im);",
" waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};",
" waveform erf_square_wf = erf_square(32.0ns, 20.0ns, 2.0ns, 8.0ns, 1, false);",
" set_frequency(predefined_frame_1, 3000000000.0);",
" shift_frequency(predefined_frame_1, 1000000000.0);",
" set_phase(predefined_frame_1, -0.5);",
Expand All @@ -328,6 +375,7 @@ def test_pulse_sequence_to_ir(predefined_frame_1, predefined_frame_2):
" play(predefined_frame_2, drag_gauss_wf);",
" play(predefined_frame_1, constant_wf);",
" play(predefined_frame_2, arb_wf);",
" play(predefined_frame_1, erf_square_wf);",
" psb[1] = capture_v0(predefined_frame_2);",
"}",
]
Expand Down
Loading

0 comments on commit a79dccc

Please sign in to comment.