Skip to content

Commit

Permalink
feature: support LinearCombination as observable (#246)
Browse files Browse the repository at this point in the history
* feature: support LinearCombination as observable

* use qml.Hamiltonian instead of LinearCombination

* use Hamiltonian import
  • Loading branch information
ashlhans authored May 3, 2024
1 parent 184745c commit 012a3e2
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 9 deletions.
3 changes: 2 additions & 1 deletion src/braket/pennylane_plugin/ahs_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from typing import Optional, Union

import numpy as np
import pennylane as qml
from braket.ahs.analog_hamiltonian_simulation import AnalogHamiltonianSimulation
from braket.aws import AwsDevice, AwsQuantumTask, AwsSession
from braket.devices import Device, LocalSimulator
Expand Down Expand Up @@ -308,7 +309,7 @@ def _validate_measurement_basis(self, observable):
if isinstance(observable, CompositeOp):
for op in observable.operands:
self._validate_measurement_basis(op)
elif isinstance(observable, Hamiltonian):
elif isinstance(observable, (Hamiltonian, qml.Hamiltonian)):
for op in observable.ops:
self._validate_measurement_basis(op)

Expand Down
6 changes: 3 additions & 3 deletions src/braket/pennylane_plugin/braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
Variance,
)
from pennylane.operation import Operation
from pennylane.ops.qubit.hamiltonian import Hamiltonian
from pennylane.ops import Hamiltonian
from pennylane.tape import QuantumTape

from braket.pennylane_plugin.translation import (
Expand Down Expand Up @@ -166,7 +166,7 @@ def observables(self) -> frozenset[str]:
# This needs to be here bc expectation(ax+by)== a*expectation(x)+b*expectation(y)
# is only true when shots=0
if not self.shots:
return base_observables.union({"Hamiltonian"})
return base_observables.union({"Hamiltonian", "LinearCombination"})
return base_observables

@property
Expand Down Expand Up @@ -227,7 +227,7 @@ def _apply_gradient_result_type(self, circuit, braket_circuit):
f"Braket can only compute gradients for circuits with a single expectation"
f" observable, not a {pl_measurements.return_type} observable."
)
if isinstance(pl_observable, Hamiltonian):
if isinstance(pl_observable, (Hamiltonian, qml.Hamiltonian)):
targets = [self.map_wires(op.wires) for op in pl_observable.ops]
else:
targets = self.map_wires(pl_observable.wires).tolist()
Expand Down
21 changes: 16 additions & 5 deletions src/braket/pennylane_plugin/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from pennylane import numpy as np
from pennylane.measurements import MeasurementProcess, ObservableReturnTypes
from pennylane.operation import Observable, Operation
from pennylane.ops import Adjoint
from pennylane.ops import Adjoint, Hamiltonian
from pennylane.pulse import ParametrizedEvolution

from braket.pennylane_plugin.ops import (
Expand Down Expand Up @@ -558,7 +558,7 @@ def translate_result_type(
return DensityMatrix(targets)
raise NotImplementedError(f"Unsupported return type: {return_type}")

if isinstance(measurement.obs, qml.Hamiltonian):
if isinstance(measurement.obs, (Hamiltonian, qml.Hamiltonian)):
if return_type is ObservableReturnTypes.Expectation:
return tuple(
Expectation(_translate_observable(term), term.wires) for term in measurement.obs.ops
Expand All @@ -581,8 +581,9 @@ def _translate_observable(observable):
raise qml.DeviceError(f"Unsupported observable: {type(observable)}")


@_translate_observable.register
def _(H: qml.Hamiltonian):
@_translate_observable.register(Hamiltonian)
@_translate_observable.register(qml.Hamiltonian)
def _(H: Union[Hamiltonian, qml.Hamiltonian]):
# terms is structured like [C, O] where C is a tuple of all the coefficients, and O is
# a tuple of all the corresponding observable terms (X, Y, Z, H, etc or a tensor product
# of them)
Expand Down Expand Up @@ -651,6 +652,16 @@ def _(t: qml.ops.Prod):
return reduce(lambda x, y: x @ y, [_translate_observable(factor) for factor in t.operands])


@_translate_observable.register
def _(t: qml.ops.SProd):
return t.scalar * _translate_observable(t.base)


@_translate_observable.register
def _(t: qml.ops.Sum):
return reduce(lambda x, y: x + y, [_translate_observable(operator) for operator in t.operands])


def translate_result(
braket_result: GateModelQuantumTaskResult,
measurement: MeasurementProcess,
Expand Down Expand Up @@ -688,7 +699,7 @@ def translate_result(
for i in sorted(key_indices)
]
translated = translate_result_type(measurement, targets, supported_result_types)
if isinstance(observable, qml.Hamiltonian):
if isinstance(observable, (Hamiltonian, qml.Hamiltonian)):
coeffs, _ = observable.terms()
return sum(
coeff * braket_result.get_value_by_result_type(result_type)
Expand Down
2 changes: 2 additions & 0 deletions test/unit_tests/test_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,8 @@ def _result_meta() -> dict:
),
(1.25 * observables.H(), 1.25 * qml.Hadamard(wires=0)),
(observables.X() @ observables.Y(), qml.ops.Prod(qml.PauliX(0), qml.PauliY(1))),
(observables.X() + observables.Y(), qml.ops.Sum(qml.PauliX(0), qml.PauliY(1))),
(observables.X(), qml.ops.SProd(scalar=4, base=qml.PauliX(0))),
],
)
def test_translate_hamiltonian_observable(expected_braket_H, pl_H):
Expand Down

0 comments on commit 012a3e2

Please sign in to comment.