Skip to content

Commit

Permalink
feat: Allow early qubit binding of observables (#1022)
Browse files Browse the repository at this point in the history
  • Loading branch information
speller26 authored Aug 20, 2024
1 parent ee63777 commit b095710
Show file tree
Hide file tree
Showing 8 changed files with 311 additions and 133 deletions.
10 changes: 5 additions & 5 deletions examples/bell_result_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from braket.circuits import Circuit, Observable
from braket.circuits import Circuit, observables
from braket.devices import LocalSimulator

device = LocalSimulator()
Expand All @@ -24,7 +24,7 @@
.h(0)
.cnot(0, 1)
.probability(target=[0])
.expectation(observable=Observable.Z(), target=[1])
.expectation(observable=observables.Z(1))
.amplitude(state=["00"])
.state_vector()
)
Expand All @@ -45,9 +45,9 @@
Circuit()
.h(0)
.cnot(0, 1)
.expectation(observable=Observable.Y() @ Observable.X(), target=[0, 1])
.variance(observable=Observable.Y() @ Observable.X(), target=[0, 1])
.sample(observable=Observable.Y() @ Observable.X(), target=[0, 1])
.expectation(observable=observables.Y(0) @ observables.X(1))
.variance(observable=observables.Y(0) @ observables.X(1))
.sample(observable=observables.Y(0) @ observables.X(1))
)

# When shots>0 for a simulator, probability, expectation, variance are calculated from measurements
Expand Down
4 changes: 2 additions & 2 deletions examples/hybrid_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# language governing permissions and limitations under the License.

from braket.aws import AwsDevice
from braket.circuits import Circuit, FreeParameter, Observable
from braket.circuits import Circuit, FreeParameter, observables
from braket.devices import Devices
from braket.jobs import get_job_device_arn, hybrid_job
from braket.jobs.metrics import log_metric
Expand All @@ -34,7 +34,7 @@ def run_hybrid_job(num_tasks=1):
circ = Circuit()
circ.rx(0, FreeParameter("theta"))
circ.cnot(0, 1)
circ.expectation(observable=Observable.X(), target=0)
circ.expectation(observable=observables.X(0))

# initial parameter
theta = 0.0
Expand Down
4 changes: 2 additions & 2 deletions examples/hybrid_job_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


from braket.aws import AwsDevice, AwsQuantumJob
from braket.circuits import Circuit, FreeParameter, Observable
from braket.circuits import Circuit, FreeParameter, observables
from braket.devices import Devices
from braket.jobs import get_job_device_arn, save_job_result
from braket.jobs.metrics import log_metric
Expand All @@ -27,7 +27,7 @@ def run_hybrid_job(num_tasks: int):
circ = Circuit()
circ.rx(0, FreeParameter("theta"))
circ.cnot(0, 1)
circ.expectation(observable=Observable.X(), target=0)
circ.expectation(observable=observables.X(0))

# initial parameter
theta = 0.0
Expand Down
4 changes: 2 additions & 2 deletions examples/local_noise_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from braket.circuits import Circuit, Noise
from braket.circuits import Circuit, noises
from braket.devices import LocalSimulator

device = LocalSimulator("braket_dm")
Expand All @@ -23,7 +23,7 @@


circuit = Circuit().x(0).x(1)
noise = Noise.BitFlip(probability=0.1)
noise = noises.BitFlip(probability=0.1)
circuit.apply_gate_noise(noise)
print("Second example: ")
print(circuit)
Expand Down
56 changes: 42 additions & 14 deletions src/braket/circuits/observable.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import numbers
from collections.abc import Sequence
from copy import deepcopy
from typing import Union

import numpy as np

Expand All @@ -27,7 +26,7 @@
OpenQASMSerializationProperties,
SerializationProperties,
)
from braket.registers.qubit_set import QubitSet
from braket.registers import QubitInput, QubitSet, QubitSetInput


class Observable(QuantumOperator):
Expand All @@ -37,31 +36,44 @@ class Observable(QuantumOperator):
`ResultType.Expectation` to specify the measurement basis.
"""

def __init__(self, qubit_count: int, ascii_symbols: Sequence[str]):
def __init__(
self, qubit_count: int, ascii_symbols: Sequence[str], targets: QubitSetInput | None = None
):
super().__init__(qubit_count=qubit_count, ascii_symbols=ascii_symbols)
if targets is not None:
targets = QubitSet(targets)
if (num_targets := len(targets)) != qubit_count:
raise ValueError(
f"Length of target {num_targets} does not match qubit count {qubit_count}"
)
self._targets = targets
else:
self._targets = None
self._coef = 1

def _unscaled(self) -> Observable:
return Observable(qubit_count=self.qubit_count, ascii_symbols=self.ascii_symbols)
return Observable(
qubit_count=self.qubit_count, ascii_symbols=self.ascii_symbols, targets=self.targets
)

def to_ir(
self,
target: QubitSet | None = None,
target: QubitSetInput | None = None,
ir_type: IRType = IRType.JAQCD,
serialization_properties: SerializationProperties | None = None,
) -> Union[str, list[Union[str, list[list[list[float]]]]]]:
) -> str | list[str | list[list[list[float]]]]:
"""Returns the IR representation for the observable
Args:
target (QubitSet | None): target qubit(s). Defaults to None.
target (QubitSetInput | None): target qubit(s). Defaults to None.
ir_type(IRType) : The IRType to use for converting the result type object to its
IR representation. Defaults to IRType.JAQCD.
serialization_properties (SerializationProperties | None): The serialization properties
to use while serializing the object to the IR representation. The serialization
properties supplied must correspond to the supplied `ir_type`. Defaults to None.
Returns:
Union[str, list[Union[str, list[list[list[float]]]]]]: The IR representation for
str | list[str | list[list[list[float]]]]: The IR representation for
the observable.
Raises:
Expand All @@ -84,27 +96,35 @@ def to_ir(
else:
raise ValueError(f"Supplied ir_type {ir_type} is not supported.")

def _to_jaqcd(self) -> list[Union[str, list[list[list[float]]]]]:
def _to_jaqcd(self) -> list[str | list[list[list[float]]]]:
"""Returns the JAQCD representation of the observable."""
raise NotImplementedError("to_jaqcd has not been implemented yet.")

def _to_openqasm(
self,
serialization_properties: OpenQASMSerializationProperties,
target: QubitSet | None = None,
targets: QubitSetInput | None = None,
) -> str:
"""Returns the openqasm string representation of the result type.
Args:
serialization_properties (OpenQASMSerializationProperties): The serialization properties
to use while serializing the object to the IR representation.
target (QubitSet | None): target qubit(s). Defaults to None.
targets (QubitSetInput | None): target qubit(s). Defaults to None.
Returns:
str: Representing the openqasm representation of the result type.
"""
raise NotImplementedError("to_openqasm has not been implemented yet.")

@property
def targets(self) -> QubitSet | None:
"""QubitSet | None: The target qubits of this observable
If `None`, this is provided by the enclosing result type.
"""
return self._targets

@property
def coefficient(self) -> int:
"""The coefficient of the observable.
Expand Down Expand Up @@ -185,7 +205,11 @@ def __sub__(self, other: Observable):
return self + (-1 * other)

def __repr__(self) -> str:
return f"{self.name}('qubit_count': {self.qubit_count})"
return (
f"{self.name}('qubit_count': {self._qubit_count})"
if self._targets is None
else f"{self.name}('qubit_count': {self._qubit_count}, 'target': {self._targets})"
)

def __eq__(self, other: Observable) -> bool:
if isinstance(other, Observable):
Expand All @@ -198,8 +222,12 @@ class StandardObservable(Observable):
eigenvalues of (+1, -1).
"""

def __init__(self, ascii_symbols: Sequence[str]):
super().__init__(qubit_count=1, ascii_symbols=ascii_symbols)
def __init__(self, ascii_symbols: Sequence[str], target: QubitInput | None = None):
super().__init__(
qubit_count=1,
ascii_symbols=ascii_symbols,
targets=[target] if target is not None else None,
)
self._eigenvalues = (1.0, -1.0) # immutable

def _unscaled(self) -> StandardObservable:
Expand Down
Loading

0 comments on commit b095710

Please sign in to comment.