Skip to content

Commit

Permalink
to_ir test
Browse files Browse the repository at this point in the history
  • Loading branch information
speller26 committed Aug 19, 2024
1 parent a9f4706 commit dc6a99b
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 25 deletions.
4 changes: 3 additions & 1 deletion src/braket/circuits/observable.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def __init__(
self._coef = 1

def _unscaled(self) -> Observable:
return Observable(qubit_count=self.qubit_count, ascii_symbols=self.ascii_symbols)
return Observable(
qubit_count=self.qubit_count, ascii_symbols=self.ascii_symbols, targets=self.targets
)

def to_ir(
self,
Expand Down
26 changes: 16 additions & 10 deletions src/braket/circuits/observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, target: QubitInput | None = None):
super().__init__(ascii_symbols=["H"], target=target)

def _unscaled(self) -> StandardObservable:
return H()
return H(self._targets)

def _to_jaqcd(self) -> list[str]:
if self.coefficient != 1:
Expand Down Expand Up @@ -93,7 +93,7 @@ def __init__(self, target: QubitInput | None = None):
super().__init__(qubit_count=1, ascii_symbols=["I"], targets=target)

def _unscaled(self) -> Observable:
return I()
return I(self._targets)

def _to_jaqcd(self) -> list[str]:
if self.coefficient != 1:
Expand Down Expand Up @@ -148,7 +148,7 @@ def __init__(self, target: QubitInput | None = None):
super().__init__(ascii_symbols=["X"], target=target)

def _unscaled(self) -> StandardObservable:
return X()
return X(self._targets)

def _to_jaqcd(self) -> list[str]:
if self.coefficient != 1:
Expand Down Expand Up @@ -191,7 +191,7 @@ def __init__(self, target: QubitInput | None = None):
super().__init__(ascii_symbols=["Y"], target=target)

def _unscaled(self) -> StandardObservable:
return Y()
return Y(self._targets)

def _to_jaqcd(self) -> list[str]:
if self.coefficient != 1:
Expand Down Expand Up @@ -234,7 +234,7 @@ def __init__(self, target: QubitInput | None = None):
super().__init__(ascii_symbols=["Z"], target=target)

def _unscaled(self) -> StandardObservable:
return Z()
return Z(self._targets)

def _to_jaqcd(self) -> list[str]:
if self.coefficient != 1:
Expand Down Expand Up @@ -315,7 +315,7 @@ def __init__(self, observables: list[Observable]):
if len(merged_targets) != len(flat_targets):
raise ValueError("Cannot have repeated target qubits")

Check warning on line 316 in src/braket/circuits/observables.py

View check run for this annotation

Codecov / codecov/patch

src/braket/circuits/observables.py#L316

Added line #L316 was not covered by tests
else:
raise ValueError("Cannot mix observables with and without targets")
raise ValueError("Cannot mix factors with and without targets")

Check warning on line 318 in src/braket/circuits/observables.py

View check run for this annotation

Codecov / codecov/patch

src/braket/circuits/observables.py#L318

Added line #L318 was not covered by tests

super().__init__(
qubit_count=qubit_count,
Expand Down Expand Up @@ -492,13 +492,13 @@ def __init__(self, observables: list[Observable], display_name: str = "Hamiltoni

self._summands = tuple(flattened_observables)
qubit_count = max(flattened_observables, key=lambda obs: obs.qubit_count).qubit_count
all_targets = [observable for observable in flattened_observables]
all_targets = [observable.targets for observable in flattened_observables]
if all(targets is None for targets in all_targets):
targets = None
elif all(targets is not None for targets in all_targets):
targets = all_targets
else:
raise ValueError("Cannot mix observables with and without targets")
raise ValueError("Cannot mix terms with and without targets")

Check warning on line 501 in src/braket/circuits/observables.py

View check run for this annotation

Codecov / codecov/patch

src/braket/circuits/observables.py#L501

Added line #L501 was not covered by tests
super().__init__(qubit_count=qubit_count, ascii_symbols=[display_name] * qubit_count)
self._targets = targets

Expand All @@ -519,6 +519,7 @@ def _to_openqasm(
serialization_properties: OpenQASMSerializationProperties,
target: list[QubitSetInput] = None,
) -> str:
target = target or self._targets
if len(self.summands) != len(target):
raise ValueError(
f"Invalid target of length {len(target)} for Sum with {len(self.summands)} terms"
Expand Down Expand Up @@ -613,10 +614,14 @@ def __init__(
Gate.Unitary(matrix=eigendecomposition["eigenvectors"].conj().T),
)

super().__init__(qubit_count=qubit_count, ascii_symbols=[display_name] * qubit_count)
super().__init__(
qubit_count=qubit_count, ascii_symbols=[display_name] * qubit_count, targets=targets
)

def _unscaled(self) -> Observable:
return Hermitian(matrix=self._matrix, display_name=self.ascii_symbols[0])
return Hermitian(
matrix=self._matrix, display_name=self.ascii_symbols[0], targets=self._targets
)

def _to_jaqcd(self) -> list[list[list[list[float]]]]:
if self.coefficient != 1:
Expand All @@ -629,6 +634,7 @@ def _to_openqasm(
self, serialization_properties: OpenQASMSerializationProperties, target: QubitSet = None
) -> str:
coef_prefix = f"{self.coefficient} * " if self.coefficient != 1 else ""
target = target or self._targets
if target:
qubit_target = ", ".join(
[serialization_properties.format_target(int(t)) for t in target]
Expand Down
Loading

0 comments on commit dc6a99b

Please sign in to comment.