Skip to content

Commit

Permalink
add __repr__ to waveforms
Browse files Browse the repository at this point in the history
jcjaskula-aws committed Sep 28, 2023
1 parent cc5cdf3 commit cebf58e
Showing 3 changed files with 85 additions and 19 deletions.
7 changes: 4 additions & 3 deletions src/braket/circuits/circuit.py
Original file line number Diff line number Diff line change
@@ -57,9 +57,10 @@
from braket.ir.jaqcd import Program as JaqcdProgram
from braket.ir.openqasm import Program as OpenQasmProgram
from braket.ir.openqasm.program_v1 import io_type
from braket.pulse import ArbitraryWaveform, Frame
from braket.pulse.ast.qasm_parser import ast_to_qasm
from braket.pulse.frame import Frame
from braket.pulse.pulse_sequence import PulseSequence, _validate_uniqueness
from braket.pulse.waveforms import Waveform

SubroutineReturn = TypeVar(
"SubroutineReturn", Iterable[Instruction], Instruction, ResultType, Iterable[ResultType]
@@ -1245,7 +1246,7 @@ def _validate_gate_calbrations_uniqueness(
self,
gate_definitions: Dict[Tuple[Gate, QubitSet], PulseSequence],
frames: Dict[Frame],
waveforms: Dict[ArbitraryWaveform],
waveforms: Dict[Waveform],
) -> None:
for key, calibration in gate_definitions.items():
for frame in calibration._frames.values():
@@ -1303,7 +1304,7 @@ def _generate_frame_wf_defcal_declarations(

def _get_frames_waveforms_from_instrs(
self, gate_definitions: Optional[Dict[Tuple[Gate, QubitSet], PulseSequence]]
) -> Tuple[Dict[Frame], Dict[ArbitraryWaveform]]:
) -> Tuple[Dict[Frame], Dict[Waveform]]:
from braket.circuits.gates import PulseGate

frames = {}
19 changes: 19 additions & 0 deletions src/braket/pulse/waveforms.py
Original file line number Diff line number Diff line change
@@ -83,6 +83,9 @@ def __init__(self, amplitudes: List[complex], id: Optional[str] = None):
self.amplitudes = list(amplitudes)
self.id = id or _make_identifier_name()

def __repr__(self) -> str:
return f"ArbitraryWaveform('id': {self.id}, 'amplitudes': {self.amplitudes})"

def __eq__(self, other):
return isinstance(other, ArbitraryWaveform) and (self.amplitudes, self.id) == (
other.amplitudes,
@@ -131,6 +134,9 @@ def __init__(
self.iq = iq
self.id = id or _make_identifier_name()

def __repr__(self) -> str:
return f"ConstantWaveform('id': {self.id}, 'length': {self.length}, 'iq': {self.iq})"

@property
def parameters(self) -> List[Union[FreeParameterExpression, FreeParameter, float]]:
"""Returns the parameters associated with the object, either unbound free parameter
@@ -236,6 +242,13 @@ def __init__(
self.zero_at_edges = zero_at_edges
self.id = id or _make_identifier_name()

def __repr__(self) -> str:
return (
f"DragGaussianWaveform('id': {self.id}, 'length': {self.length}, "
f"'sigma': {self.sigma}, 'beta': {self.beta}, 'amplitude': {self.amplitude}, "
f"'zero_at_edges': {self.zero_at_edges})"
)

@property
def parameters(self) -> List[Union[FreeParameterExpression, FreeParameter, float]]:
"""Returns the parameters associated with the object, either unbound free parameter
@@ -360,6 +373,12 @@ def __init__(
self.zero_at_edges = zero_at_edges
self.id = id or _make_identifier_name()

def __repr__(self) -> str:
return (
f"GaussianWaveform('id': {self.id}, 'length': {self.length}, 'sigma': {self.sigma}, "
f"'amplitude': {self.amplitude}, 'zero_at_edges': {self.zero_at_edges})"
)

@property
def parameters(self) -> List[Union[FreeParameterExpression, FreeParameter, float]]:
"""Returns the parameters associated with the object, either unbound free parameter
78 changes: 62 additions & 16 deletions test/unit_tests/braket/pulse/test_waveforms.py
Original file line number Diff line number Diff line change
@@ -42,6 +42,14 @@ def test_arbitrary_waveform(amps):
assert oq_exp.name == wf.id


def test_arbitrary_waveform_repr():
amps = [1, 4, 5]
id = "arb_wf_x"
wf = ArbitraryWaveform(amps, id)
expected = f"ArbitraryWaveform('id': {wf.id}, 'amplitudes': {wf.amplitudes})"
assert repr(wf) == expected


def test_arbitrary_waveform_default_params():
amps = [1, 4, 5]
wf = ArbitraryWaveform(amps)
@@ -77,6 +85,15 @@ def test_constant_waveform():
_assert_wf_qasm(wf, "waveform const_wf_x = constant(4.0ms, 4);")


def test_constant_waveform_repr():
length = 4e-3
iq = 4
id = "const_wf_x"
wf = ConstantWaveform(length, iq, id)
expected = f"ConstantWaveform('id': {wf.id}, 'length': {wf.length}, 'iq': {wf.iq})"
assert repr(wf) == expected


def test_constant_waveform_default_params():
amps = [1, 4, 5]
wf = ArbitraryWaveform(amps)
@@ -128,6 +145,21 @@ def test_drag_gaussian_waveform():
_assert_wf_qasm(wf, "waveform drag_gauss_wf = drag_gaussian(4.0ns, 300.0ms, 0.6, 0.4, false);")


def test_drag_gaussian_waveform_repr():
length = 4e-9
sigma = 0.3
beta = 0.6
amplitude = 0.4
zero_at_edges = False
id = "drag_gauss_wf"
wf = DragGaussianWaveform(length, sigma, beta, amplitude, zero_at_edges, id)
expected = (
f"DragGaussianWaveform('id': {wf.id}, 'length': {wf.length}, 'sigma': {wf.sigma}, "
f"'beta': {wf.beta}, 'amplitude': {wf.amplitude}, 'zero_at_edges': {wf.zero_at_edges})"
)
assert repr(wf) == expected


def test_drag_gaussian_waveform_default_params():
length = 4e-9
sigma = 0.3
@@ -151,22 +183,6 @@ def test_drag_gaussian_wf_eq():
assert wf != wfc


def test_gaussian_waveform():
length = 4e-9
sigma = 0.3
amplitude = 0.4
zero_at_edges = False
id = "gauss_wf"
wf = GaussianWaveform(length, sigma, amplitude, zero_at_edges, id)
assert wf.id == id
assert wf.zero_at_edges == zero_at_edges
assert wf.amplitude == amplitude
assert wf.sigma == sigma
assert wf.length == length

_assert_wf_qasm(wf, "waveform gauss_wf = gaussian(4.0ns, 300.0ms, 0.4, false);")


def test_drag_gaussian_wf_free_params():
wf = DragGaussianWaveform(
FreeParameter("length_v"),
@@ -205,6 +221,36 @@ def test_drag_gaussian_wf_free_params():
_assert_wf_qasm(wf_3, "waveform d_gauss_wf = drag_gaussian(600.0ms, 400.0ms, 0.2, 0.1, false);")


def test_gaussian_waveform():
length = 4e-9
sigma = 0.3
amplitude = 0.4
zero_at_edges = False
id = "gauss_wf"
wf = GaussianWaveform(length, sigma, amplitude, zero_at_edges, id)
assert wf.id == id
assert wf.zero_at_edges == zero_at_edges
assert wf.amplitude == amplitude
assert wf.sigma == sigma
assert wf.length == length

_assert_wf_qasm(wf, "waveform gauss_wf = gaussian(4.0ns, 300.0ms, 0.4, false);")


def test_gaussian_waveform_repr():
length = 4e-9
sigma = 0.3
amplitude = 0.4
zero_at_edges = False
id = "gauss_wf"
wf = GaussianWaveform(length, sigma, amplitude, zero_at_edges, id)
expected = (
f"GaussianWaveform('id': {wf.id}, 'length': {wf.length}, 'sigma': {wf.sigma}, "
f"'amplitude': {wf.amplitude}, 'zero_at_edges': {wf.zero_at_edges})"
)
assert repr(wf) == expected


def test_gaussian_waveform_default_params():
length = 4e-9
sigma = 0.3

0 comments on commit cebf58e

Please sign in to comment.