Skip to content

Commit

Permalink
Allow early qubit binding of observables
Browse files Browse the repository at this point in the history
  • Loading branch information
speller26 committed Aug 19, 2024
1 parent ee63777 commit 9b0b479
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 65 deletions.
52 changes: 39 additions & 13 deletions src/braket/circuits/observable.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import numbers
from collections.abc import Sequence
from copy import deepcopy
from typing import Union

import numpy as np

Expand All @@ -27,7 +26,7 @@
OpenQASMSerializationProperties,
SerializationProperties,
)
from braket.registers.qubit_set import QubitSet
from braket.registers import QubitInput, QubitSet, QubitSetInput


class Observable(QuantumOperator):
Expand All @@ -37,31 +36,42 @@ class Observable(QuantumOperator):
`ResultType.Expectation` to specify the measurement basis.
"""

def __init__(self, qubit_count: int, ascii_symbols: Sequence[str]):
def __init__(
self, qubit_count: int, ascii_symbols: Sequence[str], targets: QubitSetInput | None = None
):
super().__init__(qubit_count=qubit_count, ascii_symbols=ascii_symbols)
if targets is not None:
targets = QubitSet(targets)
if (num_targets := len(targets)) != qubit_count:
raise ValueError(
f"Length of target {num_targets} does not match qubit count {qubit_count}"
)
self._targets = targets
else:
self._targets = None
self._coef = 1

def _unscaled(self) -> Observable:
return Observable(qubit_count=self.qubit_count, ascii_symbols=self.ascii_symbols)

def to_ir(
self,
target: QubitSet | None = None,
target: QubitSetInput | None = None,
ir_type: IRType = IRType.JAQCD,
serialization_properties: SerializationProperties | None = None,
) -> Union[str, list[Union[str, list[list[list[float]]]]]]:
) -> str | list[str | list[list[list[float]]]]:
"""Returns the IR representation for the observable
Args:
target (QubitSet | None): target qubit(s). Defaults to None.
target (QubitSetInput | None): target qubit(s). Defaults to None.
ir_type(IRType) : The IRType to use for converting the result type object to its
IR representation. Defaults to IRType.JAQCD.
serialization_properties (SerializationProperties | None): The serialization properties
to use while serializing the object to the IR representation. The serialization
properties supplied must correspond to the supplied `ir_type`. Defaults to None.
Returns:
Union[str, list[Union[str, list[list[list[float]]]]]]: The IR representation for
str | list[str | list[list[list[float]]]]: The IR representation for
the observable.
Raises:
Expand All @@ -84,27 +94,35 @@ def to_ir(
else:
raise ValueError(f"Supplied ir_type {ir_type} is not supported.")

def _to_jaqcd(self) -> list[Union[str, list[list[list[float]]]]]:
def _to_jaqcd(self) -> list[str | list[list[list[float]]]]:
"""Returns the JAQCD representation of the observable."""
raise NotImplementedError("to_jaqcd has not been implemented yet.")

def _to_openqasm(
self,
serialization_properties: OpenQASMSerializationProperties,
target: QubitSet | None = None,
targets: QubitSetInput | None = None,
) -> str:
"""Returns the openqasm string representation of the result type.
Args:
serialization_properties (OpenQASMSerializationProperties): The serialization properties
to use while serializing the object to the IR representation.
target (QubitSet | None): target qubit(s). Defaults to None.
targets (QubitSetInput | None): target qubit(s). Defaults to None.
Returns:
str: Representing the openqasm representation of the result type.
"""
raise NotImplementedError("to_openqasm has not been implemented yet.")

@property
def targets(self) -> QubitSet | None:
"""QubitSet | None: The target qubits of this observable
If `None`, this is provided by the enclosing result type.
"""
return self._targets

@property
def coefficient(self) -> int:
"""The coefficient of the observable.
Expand Down Expand Up @@ -185,7 +203,11 @@ def __sub__(self, other: Observable):
return self + (-1 * other)

def __repr__(self) -> str:
return f"{self.name}('qubit_count': {self.qubit_count})"
return (
f"{self.name}('qubit_count': {self._qubit_count})"
if self._targets is None
else f"{self.name}('qubit_count': {self._qubit_count}, 'target': {self._targets})"
)

def __eq__(self, other: Observable) -> bool:
if isinstance(other, Observable):
Expand All @@ -198,8 +220,12 @@ class StandardObservable(Observable):
eigenvalues of (+1, -1).
"""

def __init__(self, ascii_symbols: Sequence[str]):
super().__init__(qubit_count=1, ascii_symbols=ascii_symbols)
def __init__(self, ascii_symbols: Sequence[str], target: QubitInput | None = None):
super().__init__(
qubit_count=1,
ascii_symbols=ascii_symbols,
targets=[target] if target is not None else None
)
self._eigenvalues = (1.0, -1.0) # immutable

def _unscaled(self) -> StandardObservable:
Expand Down
Loading

0 comments on commit 9b0b479

Please sign in to comment.