diff --git a/src/braket/circuits/result_type.py b/src/braket/circuits/result_type.py index 3e9f0dfad..a7ce2436c 100644 --- a/src/braket/circuits/result_type.py +++ b/src/braket/circuits/result_type.py @@ -209,32 +209,32 @@ def __init__( super().__init__(ascii_symbols) self._observable = observable self._target = QubitSet(target) - if not self._target: - if self._observable.qubit_count != 1: - raise ValueError( - f"Observable {self._observable} must only operate on 1 qubit for target=None" - ) - elif isinstance(observable, Sum): # nested target - if len(target) != len(observable.summands): - raise ValueError( - "Sum observable's target shape must be a nested list where each term's " - "target length is equal to the observable term's qubits count." - ) - self._target = [QubitSet(term_target) for term_target in target] - for term_target, obs in zip(self._target, observable.summands): - if obs.qubit_count != len(term_target): + if self._target: + if isinstance(observable, Sum): # nested target + if len(target) != len(observable.summands): raise ValueError( "Sum observable's target shape must be a nested list where each term's " "target length is equal to the observable term's qubits count." ) - elif self._observable.qubit_count != len(self._target): - raise ValueError( - f"Observable's qubit count {self._observable.qubit_count} and " - f"the size of the target qubit set {self._target} must be equal" - ) - elif self._observable.qubit_count != len(self.ascii_symbols): + self._target = [QubitSet(term_target) for term_target in target] + for term_target, obs in zip(self._target, observable.summands): + if obs.qubit_count != len(term_target): + raise ValueError( + "Sum observable's target shape must be a nested list where each term's " + "target length is equal to the observable term's qubits count." + ) + elif self._observable.qubit_count != len(self._target): + raise ValueError( + f"Observable's qubit count {self._observable.qubit_count} and " + f"the size of the target qubit set {self._target} must be equal" + ) + elif self._observable.qubit_count != len(self.ascii_symbols): + raise ValueError( + "Observable's qubit count and the number of ASCII symbols must be equal" + ) + elif (not self._observable.targets) and self._observable.qubit_count != 1: raise ValueError( - "Observable's qubit count and the number of ASCII symbols must be equal" + f"Observable {self._observable} must only operate on 1 qubit for target=None" ) @property diff --git a/src/braket/circuits/result_types.py b/src/braket/circuits/result_types.py index f682b9ba1..8af36ab4a 100644 --- a/src/braket/circuits/result_types.py +++ b/src/braket/circuits/result_types.py @@ -14,14 +14,12 @@ from __future__ import annotations import re -from functools import reduce from typing import Union import braket.ir.jaqcd as ir from braket.circuits import circuit from braket.circuits.free_parameter import FreeParameter from braket.circuits.observable import Observable -from braket.circuits.observables import Sum from braket.circuits.result_type import ( ObservableParameterResultType, ObservableResultType, @@ -210,11 +208,7 @@ def __init__( >>> parameters=["alpha", "beta"], >>> ) """ - if isinstance(observable, Sum): - target_qubits = reduce(QubitSet.union, map(QubitSet, target), QubitSet()) - else: - target_qubits = QubitSet(target) - + target_qubits = QubitSet(target if target is not None else observable.targets) super().__init__( ascii_symbols=[f"AdjointGradient({observable.ascii_symbols[0]})"] * len(target_qubits), observable=observable, diff --git a/test/unit_tests/braket/circuits/test_result_types.py b/test/unit_tests/braket/circuits/test_result_types.py index 5f76eeaf9..422fed7a4 100644 --- a/test/unit_tests/braket/circuits/test_result_types.py +++ b/test/unit_tests/braket/circuits/test_result_types.py @@ -272,6 +272,25 @@ def test_ir_result_level(testclass, subroutine_name, irclass, input, ir_input): "#pragma braket result adjoint_gradient expectation(hermitian([[1+0im, 0im], " "[0im, 1+0im]]) q[0]) all", ), + ( + ResultType.AdjointGradient( + Observable.H(0) @ Observable.I(1) + 2 * Observable.Z(2), + parameters=[FreeParameter("alpha"), "beta", FreeParameter("gamma")], + ), + OpenQASMSerializationProperties(qubit_reference_type=QubitReferenceType.VIRTUAL), + "#pragma braket result adjoint_gradient expectation(h(q[0]) @ i(q[1]) + 2 * z(q[2])) " + "alpha, beta, gamma", + ), + ( + ResultType.AdjointGradient( + Observable.H(0) @ Observable.I(1) + 2 * Observable.Z(2), + target=[[3, 4], [5]], + parameters=[FreeParameter("alpha"), "beta", FreeParameter("gamma")], + ), + OpenQASMSerializationProperties(qubit_reference_type=QubitReferenceType.VIRTUAL), + "#pragma braket result adjoint_gradient expectation(h(q[3]) @ i(q[4]) + 2 * z(q[5])) " + "alpha, beta, gamma", + ), ], ) def test_result_to_ir_openqasm(result_type, serialization_properties, expected_ir):