Skip to content

Commit

Permalink
fix: Flatten observable before getting targets (#287)
Browse files Browse the repository at this point in the history
Fixes #285
  • Loading branch information
speller26 authored Jan 16, 2025
1 parent 2656be0 commit 386985a
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 13 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
install_requires=[
"amazon-braket-sdk>=1.87.0",
"autoray>=0.6.11",
"pennylane>=0.34.0,<0.40",
"pennylane>=0.34.0",
],
entry_points={
"pennylane.plugins": [
Expand All @@ -53,7 +53,7 @@
},
extras_require={
"test": [
"autoray<0.7.0", # autoray.tensorflow_diag no longer works
"autoray<0.7.0", # autoray.tensorflow_diag no longer works
"docutils>=0.19",
"flaky",
"pre-commit",
Expand Down
3 changes: 2 additions & 1 deletion src/braket/pennylane_plugin/braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from braket.device_schema import DeviceActionType
from braket.devices import Device, LocalSimulator
from braket.pennylane_plugin.translation import (
flatten_observable,
get_adjoint_gradient_result_type,
supported_observables,
supported_operations,
Expand Down Expand Up @@ -281,7 +282,7 @@ def _apply_gradient_result_type(self, circuit, braket_circuit):
f" observable, not {len(circuit.observables)} observables."
)
pl_measurements = circuit.measurements[0]
pl_observable = pl_measurements.obs
pl_observable = flatten_observable(pl_measurements.obs)
if pl_measurements.return_type != Expectation:
raise ValueError(
f"Braket can only compute gradients for circuits with a single expectation"
Expand Down
8 changes: 4 additions & 4 deletions src/braket/pennylane_plugin/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def get_adjoint_gradient_result_type(
if "AdjointGradient" not in supported_result_types:
raise NotImplementedError("Unsupported return type: AdjointGradient")

braket_observable = _translate_observable(_flatten_observable(observable))
braket_observable = _translate_observable(observable)
braket_observable = (
braket_observable.item() if hasattr(braket_observable, "item") else braket_observable
)
Expand Down Expand Up @@ -590,7 +590,7 @@ def translate_result_type( # noqa: C901
return tuple(Sample(observables.Z(target)) for target in targets or measurement.wires)
raise NotImplementedError(f"Unsupported return type: {return_type}")

observable = _flatten_observable(observable)
observable = flatten_observable(observable)

if isinstance(observable, qml.ops.LinearCombination):
if return_type is ObservableReturnTypes.Expectation:
Expand All @@ -608,7 +608,7 @@ def translate_result_type( # noqa: C901
raise NotImplementedError(f"Unsupported return type: {return_type}")


def _flatten_observable(observable):
def flatten_observable(observable):
if isinstance(observable, (qml.ops.CompositeOp, qml.ops.SProd)):
simplified = qml.ops.LinearCombination(*observable.terms()).simplify()
coeffs, _ = simplified.terms()
Expand Down Expand Up @@ -735,7 +735,7 @@ def translate_result(
return dict(braket_result.measurement_counts)

translated = translate_result_type(measurement, targets, supported_result_types)
observable = _flatten_observable(observable)
observable = flatten_observable(observable)
if isinstance(observable, qml.ops.LinearCombination):
coeffs, _ = observable.terms()
return sum(
Expand Down
8 changes: 4 additions & 4 deletions test/integ_tests/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_qubit_state_vector(self, init_state, device, tol):

@qml.qnode(dev)
def circuit():
qml.QubitStateVector.compute_decomposition(state, wires=[0])
qml.StatePrep.compute_decomposition(state, wires=[0])
return qml.probs(wires=range(1))

assert np.allclose(circuit(), np.abs(state) ** 2, **tol)
Expand Down Expand Up @@ -177,15 +177,15 @@ def test_qubit_channel(self, init_state, dm_device, kraus, tol):
def assert_op_and_inverse(op, dev, state, wires, tol, op_args):
@qml.qnode(dev)
def circuit():
qml.QubitStateVector.compute_decomposition(state, wires=wires)
qml.StatePrep.compute_decomposition(state, wires=wires)
op(*op_args, wires=wires)
return qml.probs(wires=wires)

assert np.allclose(circuit(), np.abs(op.compute_matrix(*op_args) @ state) ** 2, **tol)

@qml.qnode(dev)
def circuit_inv():
qml.QubitStateVector.compute_decomposition(state, wires=wires)
qml.StatePrep.compute_decomposition(state, wires=wires)
qml.adjoint(op(*op_args, wires=wires))
return qml.probs(wires=wires)

Expand All @@ -197,7 +197,7 @@ def circuit_inv():
def assert_noise_op(op, dev, state, wires, tol, op_args):
@qml.qnode(dev)
def circuit():
qml.QubitStateVector.compute_decomposition(state, wires=wires)
qml.StatePrep.compute_decomposition(state, wires=wires)
op(*op_args, wires=wires)
return qml.probs(wires=wires)

Expand Down
6 changes: 4 additions & 2 deletions test/unit_tests/test_braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,10 @@ def test_execute_parametrize_differentiable(mock_run):
qml.RY(0.543, wires=0),
],
measurements=[
qml.expval(2 * qml.PauliX(0) @ qml.PauliY(1) + 0.75 * qml.PauliY(0) @ qml.PauliZ(1)),
qml.expval(
2 * qml.PauliX(0) @ qml.PauliY(1) @ qml.Identity(2)
+ 0.75 * qml.PauliY(0) @ qml.PauliZ(1)
),
],
)
CIRCUIT_3.trainable_params = [0, 1]
Expand Down Expand Up @@ -569,7 +572,6 @@ def test_execute_with_gradient_no_op_math(
result_types,
expected_pl_result,
):

task = Mock()
type(task).id = PropertyMock(return_value="task_arn")
task.state.return_value = "COMPLETED"
Expand Down

0 comments on commit 386985a

Please sign in to comment.