Skip to content

Commit

Permalink
fix: Use observable targets for targetless results (#1025)
Browse files Browse the repository at this point in the history
  • Loading branch information
speller26 authored Aug 28, 2024
1 parent 93f450b commit 8f4e88f
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 28 deletions.
42 changes: 21 additions & 21 deletions src/braket/circuits/result_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,32 +209,32 @@ def __init__(
super().__init__(ascii_symbols)
self._observable = observable
self._target = QubitSet(target)
if not self._target:
if self._observable.qubit_count != 1:
raise ValueError(
f"Observable {self._observable} must only operate on 1 qubit for target=None"
)
elif isinstance(observable, Sum): # nested target
if len(target) != len(observable.summands):
raise ValueError(
"Sum observable's target shape must be a nested list where each term's "
"target length is equal to the observable term's qubits count."
)
self._target = [QubitSet(term_target) for term_target in target]
for term_target, obs in zip(self._target, observable.summands):
if obs.qubit_count != len(term_target):
if self._target:
if isinstance(observable, Sum): # nested target
if len(target) != len(observable.summands):
raise ValueError(
"Sum observable's target shape must be a nested list where each term's "
"target length is equal to the observable term's qubits count."
)
elif self._observable.qubit_count != len(self._target):
raise ValueError(
f"Observable's qubit count {self._observable.qubit_count} and "
f"the size of the target qubit set {self._target} must be equal"
)
elif self._observable.qubit_count != len(self.ascii_symbols):
self._target = [QubitSet(term_target) for term_target in target]
for term_target, obs in zip(self._target, observable.summands):
if obs.qubit_count != len(term_target):
raise ValueError(
"Sum observable's target shape must be a nested list where each term's "
"target length is equal to the observable term's qubits count."
)
elif self._observable.qubit_count != len(self._target):
raise ValueError(
f"Observable's qubit count {self._observable.qubit_count} and "
f"the size of the target qubit set {self._target} must be equal"
)
elif self._observable.qubit_count != len(self.ascii_symbols):
raise ValueError(
"Observable's qubit count and the number of ASCII symbols must be equal"
)
elif (not self._observable.targets) and self._observable.qubit_count != 1:
raise ValueError(
"Observable's qubit count and the number of ASCII symbols must be equal"
f"Observable {self._observable} must only operate on 1 qubit for target=None"
)

@property
Expand Down
8 changes: 1 addition & 7 deletions src/braket/circuits/result_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@
from __future__ import annotations

import re
from functools import reduce
from typing import Union

import braket.ir.jaqcd as ir
from braket.circuits import circuit
from braket.circuits.free_parameter import FreeParameter
from braket.circuits.observable import Observable
from braket.circuits.observables import Sum
from braket.circuits.result_type import (
ObservableParameterResultType,
ObservableResultType,
Expand Down Expand Up @@ -210,11 +208,7 @@ def __init__(
>>> parameters=["alpha", "beta"],
>>> )
"""
if isinstance(observable, Sum):
target_qubits = reduce(QubitSet.union, map(QubitSet, target), QubitSet())
else:
target_qubits = QubitSet(target)

target_qubits = QubitSet(target if target is not None else observable.targets)
super().__init__(
ascii_symbols=[f"AdjointGradient({observable.ascii_symbols[0]})"] * len(target_qubits),
observable=observable,
Expand Down
19 changes: 19 additions & 0 deletions test/unit_tests/braket/circuits/test_result_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,25 @@ def test_ir_result_level(testclass, subroutine_name, irclass, input, ir_input):
"#pragma braket result adjoint_gradient expectation(hermitian([[1+0im, 0im], "
"[0im, 1+0im]]) q[0]) all",
),
(
ResultType.AdjointGradient(
Observable.H(0) @ Observable.I(1) + 2 * Observable.Z(2),
parameters=[FreeParameter("alpha"), "beta", FreeParameter("gamma")],
),
OpenQASMSerializationProperties(qubit_reference_type=QubitReferenceType.VIRTUAL),
"#pragma braket result adjoint_gradient expectation(h(q[0]) @ i(q[1]) + 2 * z(q[2])) "
"alpha, beta, gamma",
),
(
ResultType.AdjointGradient(
Observable.H(0) @ Observable.I(1) + 2 * Observable.Z(2),
target=[[3, 4], [5]],
parameters=[FreeParameter("alpha"), "beta", FreeParameter("gamma")],
),
OpenQASMSerializationProperties(qubit_reference_type=QubitReferenceType.VIRTUAL),
"#pragma braket result adjoint_gradient expectation(h(q[3]) @ i(q[4]) + 2 * z(q[5])) "
"alpha, beta, gamma",
),
],
)
def test_result_to_ir_openqasm(result_type, serialization_properties, expected_ir):
Expand Down

0 comments on commit 8f4e88f

Please sign in to comment.