Support qml.counts for circuits (#267)
* feature: support counts

* test: add tests

* test: add unit tests

* change: remove import

* test: add test for unsupported return type

---------

Co-authored-by: Tim (Yi-Ting) <[email protected]>
Co-authored-by: Ryan Shaffer <[email protected]>
3 people authored Jun 17, 2024
1 parent 7b9f700 commit d7c232f
Showing 5 changed files with 287 additions and 14 deletions.
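For orientation, here is a minimal sketch of the workflow this commit enables: returning raw measurement counts from a circuit executed on a Braket device. The device name and shot count below are illustrative choices, not part of the change itself.

import pennylane as qml

# Local Braket simulator registered by this plugin; a managed device via
# "braket.aws.qubit" and a device ARN would be used in the same way.
dev = qml.device("braket.local.qubit", wires=2, shots=1000)

@qml.qnode(dev)
def circuit():
    qml.Hadamard(wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.counts()

print(circuit())  # e.g. {'00': 496, '11': 504}
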
7 changes: 3 additions & 4 deletions src/braket/pennylane_plugin/braket_device.py
@@ -55,6 +55,7 @@
from pennylane import numpy as np
from pennylane.gradients import param_shift
from pennylane.measurements import (
Counts,
Expectation,
MeasurementProcess,
MeasurementTransform,
@@ -78,7 +79,7 @@

from ._version import __version__

RETURN_TYPES = [Expectation, Variance, Sample, Probability, State]
RETURN_TYPES = [Expectation, Variance, Sample, Probability, State, Counts]
MIN_SIMULATOR_BILLED_MS = 3000
OBS_LIST = (qml.PauliX, qml.PauliY, qml.PauliZ)

@@ -260,9 +261,7 @@ def statistics(
results = []
for mp in measurements:
if mp.return_type not in RETURN_TYPES:
raise QuantumFunctionError(
"Unsupported return type specified for observable {}".format(mp.obs.name)
)
raise QuantumFunctionError("Unsupported return type: {}".format(mp.return_type))
results.append(self._get_statistic(braket_result, mp))
return results

32 changes: 27 additions & 5 deletions src/braket/pennylane_plugin/translation.py
@@ -11,6 +11,7 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from collections import Counter
from functools import partial, reduce, singledispatch
from typing import Any, Optional, Union

@@ -530,7 +531,7 @@ def get_adjoint_gradient_result_type(
return AdjointGradient(observable=braket_observable, target=targets, parameters=parameters)


def translate_result_type(
def translate_result_type( # noqa: C901
measurement: MeasurementProcess, targets: list[int], supported_result_types: frozenset[str]
) -> Union[ResultType, tuple[ResultType, ...]]:
"""Translates a PennyLane ``MeasurementProcess`` into the corresponding Braket ``ResultType``.
@@ -547,6 +548,7 @@ def translate_result_type(
then this will return a result type for each term.
"""
return_type = measurement.return_type
observable = measurement.obs

if return_type is ObservableReturnTypes.Probability:
return Probability(targets)
@@ -558,19 +560,24 @@
return DensityMatrix(targets)
raise NotImplementedError(f"Unsupported return type: {return_type}")

if isinstance(measurement.obs, (Hamiltonian, qml.Hamiltonian)):
if isinstance(observable, (Hamiltonian, qml.Hamiltonian)):
if return_type is ObservableReturnTypes.Expectation:
return tuple(
Expectation(_translate_observable(term), term.wires) for term in measurement.obs.ops
Expectation(_translate_observable(term), term.wires) for term in observable.ops
)
raise NotImplementedError(f"Return type {return_type} unsupported for Hamiltonian")

braket_observable = _translate_observable(measurement.obs)
if observable is None:
if return_type is ObservableReturnTypes.Counts:
return tuple(Sample(observables.Z(), target) for target in targets or measurement.wires)
raise NotImplementedError(f"Unsupported return type: {return_type}")

braket_observable = _translate_observable(observable)
if return_type is ObservableReturnTypes.Expectation:
return Expectation(braket_observable, targets)
elif return_type is ObservableReturnTypes.Variance:
return Variance(braket_observable, targets)
elif return_type is ObservableReturnTypes.Sample:
elif return_type in (ObservableReturnTypes.Sample, ObservableReturnTypes.Counts):
return Sample(braket_observable, targets)
else:
raise NotImplementedError(f"Unsupported return type: {return_type}")
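
As a rough illustration of the mapping above (a sketch assuming two target wires; names follow the Braket SDK imports already used in this file): a bare qml.counts() with no observable becomes one Z-basis Sample result type per target, while qml.counts(op=...) falls through to the ordinary Sample translation.

from braket.circuits import observables
from braket.circuits.result_types import Sample

# Roughly what translate_result_type returns for qml.counts() on wires [0, 1]
# with no observable: a computational-basis Sample per target qubit.
counts_result_types = tuple(Sample(observables.Z(), target) for target in [0, 1])

# qml.counts(op=qml.PauliZ(0)) instead translates the observable and samples it,
# equivalent to Sample(observables.Z(), [0]).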
@@ -698,12 +705,27 @@ def translate_result(
ag_result.value["gradient"][f"p_{i}"]
for i in sorted(key_indices)
]

if measurement.return_type is ObservableReturnTypes.Counts and observable is None:
if targets:
new_dict = {}
for key, value in braket_result.measurement_counts.items():
new_key = "".join(key[i] for i in targets)
if new_key not in new_dict:
new_dict[new_key] = 0
new_dict[new_key] += value
return new_dict

return dict(braket_result.measurement_counts)

translated = translate_result_type(measurement, targets, supported_result_types)
if isinstance(observable, (Hamiltonian, qml.Hamiltonian)):
coeffs, _ = observable.terms()
return sum(
coeff * braket_result.get_value_by_result_type(result_type)
for coeff, result_type in zip(coeffs, translated)
)
elif measurement.return_type is ObservableReturnTypes.Counts:
return dict(Counter(braket_result.get_value_by_result_type(translated)))
else:
return braket_result.get_value_by_result_type(translated)
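
The bitstring marginalization added to translate_result can be sketched in isolation; the counts below are made up for illustration.

from collections import Counter

# Hypothetical device-level counts over two qubits.
measurement_counts = Counter({"00": 2, "10": 1, "11": 1})
targets = [0]  # qml.counts(wires=[0]) keeps only the bit at index 0 of each key

marginal = {}
for key, value in measurement_counts.items():
    new_key = "".join(key[i] for i in targets)
    marginal[new_key] = marginal.get(new_key, 0) + value

print(marginal)  # {'0': 2, '1': 2}
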
87 changes: 87 additions & 0 deletions test/integ_tests/test_counts.py
@@ -0,0 +1,87 @@
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

"""Tests that counts are correctly computed in the plugin device"""

import numpy as np
import pennylane as qml
import pytest

np.random.seed(42)


@pytest.mark.parametrize("shots", [8192])
class TestCounts:
"""Tests for the count return type"""

def test_counts_values(self, device, shots, tol):
"""Tests if the result returned by counts have
the correct values
"""
dev = device(2)

@qml.qnode(dev)
def circuit():
qml.RX(np.pi / 3, wires=0)
return qml.counts()

result = circuit().item()

# The counts should only contain 00 and 10
assert "00" in result
assert "10" in result
assert "11" not in result
assert "01" not in result
assert result["00"] + result["10"] == shots
assert np.allclose(result["00"] / shots, 0.75, **tol)
assert np.allclose(result["10"] / shots, 0.25, **tol)

def test_counts_values_specify_target(self, device, shots, tol):
"""Tests if the result returned by counts have
the correct values when specifying a target
"""
dev = device(2)

@qml.qnode(dev)
def circuit():
qml.RX(np.pi / 3, wires=0)
return qml.counts(wires=[0])

result = circuit().item()

# The counts should only contain 0 and 1
assert "0" in result
assert "1" in result
assert result["0"] + result["1"] == shots
assert np.allclose(result["0"] / shots, 0.75, **tol)
assert np.allclose(result["1"] / shots, 0.25, **tol)

def test_counts_values_with_observable(self, device, shots, tol):
"""Tests if the result returned by counts have
the correct values when specifying an observable
"""
dev = device(2)

@qml.qnode(dev)
def circuit():
qml.RX(np.pi / 3, wires=0)
return qml.counts(op=qml.PauliZ(wires=[0]))

result = circuit().item()

# The counts should only contain the eigenvalues -1 and 1
assert -1 in result
assert 1 in result
assert result[-1] + result[1] == shots
assert np.allclose(result[1] / shots, 0.75, **tol)
assert np.allclose(result[-1] / shots, 0.25, **tol)
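
The 0.75/0.25 split asserted in these tests is just the RX rotation probability: applying RX(θ) to |0⟩ gives outcome 0 with probability cos²(θ/2), which at θ = π/3 is cos²(π/6) = 3/4. A quick numerical check:

import numpy as np

theta = np.pi / 3
p0 = np.cos(theta / 2) ** 2  # probability of measuring 0 on the rotated wire
print(round(p0, 2))  # 0.75
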
172 changes: 169 additions & 3 deletions test/unit_tests/test_braket_device.py
@@ -12,6 +12,8 @@
# language governing permissions and limitations under the License.

import json
from collections import Counter
from enum import Enum
from typing import Any, Optional
from unittest import mock
from unittest.mock import Mock, PropertyMock, patch
@@ -862,9 +864,8 @@ def test_parametrized_evolution_in_oqc_lucy_supported_ops():
def test_bad_statistics():
"""Test if a QuantumFunctionError is raised for an invalid return type"""
dev = _aws_device(wires=1, foo="bar")
observable = qml.Identity(wires=0)
tape = qml.tape.QuantumTape(measurements=[qml.counts(observable)])
with pytest.raises(QuantumFunctionError, match="Unsupported return type specified"):
tape = qml.tape.QuantumTape(measurements=[qml.classical_shadow(wires=[0])])
with pytest.raises(QuantumFunctionError, match="Unsupported return type:"):
dev.statistics(None, tape.measurements)


@@ -1233,6 +1234,171 @@ def test_execute_some_samples(mock_run):
assert results[1] == 0.0


@patch.object(AwsDevice, "run")
@pytest.mark.parametrize(
"num_wires, op, wires, measurements, measurement_counts, result_types, expected_result",
[
(
2,
None,
None,
[[0, 0], [1, 1], [0, 0], [1, 1]],
Counter({"00": 2, "11": 2}),
[
{
"type": {"observable": ["z", "z"], "targets": [0, 1], "type": "sample"},
"value": [1, 1, 1, 1],
},
],
{"00": 2, "11": 2},
),
(
2,
qml.PauliZ(0),
None,
[[0, 0], [1, 1], [0, 0], [1, 1]],
Counter({"00": 2, "11": 2}),
[
{
"type": {"observable": ["z"], "targets": [0], "type": "sample"},
"value": [1, -1, 1, -1],
},
],
{1: 2, -1: 2},
),
(
2,
None,
[0],
[[0, 0], [1, 1], [0, 0], [1, 1]],
Counter({"00": 2, "11": 2}),
[
{
"type": {"observable": ["z", "z"], "targets": [0, 1], "type": "sample"},
"value": [1, 1, 1, 1],
},
],
{"0": 2, "1": 2},
),
(
3,
None,
[2],
[[0, 0, 0], [1, 1, 0], [0, 0, 0], [1, 1, 0]],
Counter({"000": 2, "110": 2}),
[
{
"type": {"observable": ["z"], "targets": [2], "type": "sample"},
"value": [1, 1, 1, 1],
},
],
{"0": 4},
),
],
)
def test_execute_counts(
mock_run,
num_wires,
op,
wires,
measurements,
measurement_counts,
result_types,
expected_result,
):
result = GateModelQuantumTaskResult.from_string(
json.dumps(
{
"braketSchemaHeader": {
"name": "braket.task_result.gate_model_task_result",
"version": "1",
},
"measurements": measurements,
"measurement_counts": measurement_counts,
"resultTypes": result_types,
"measuredQubits": [0, 1],
"taskMetadata": {
"braketSchemaHeader": {
"name": "braket.task_result.task_metadata",
"version": "1",
},
"id": "task_arn",
"shots": 4,
"deviceId": "default",
},
"additionalMetadata": {
"action": {
"braketSchemaHeader": {
"name": "braket.ir.openqasm.program",
"version": "1",
},
"source": "qubit[2] q; cnot q[0], q[1]; measure q;",
},
},
}
)
)

task = Mock()
task.result.return_value = result
mock_run.return_value = task

dev = _aws_device(wires=num_wires, shots=4)

with QuantumTape() as circuit:
qml.Hadamard(wires=0)
qml.CNOT(wires=[0, 1])
qml.counts(op=op, wires=wires)

results = dev.execute(circuit)

assert results == expected_result


def test_counts_all_outcomes_fails():
"""Tests that the calling counts with 'all_outcomes=True' raises an error"""
dev = _aws_device(wires=2, shots=4)

with QuantumTape() as circuit:
qml.Hadamard(wires=0)
qml.CNOT(wires=[0, 1])
qml.counts(all_outcomes=True)

does_not_support = "Unsupported return type: ObservableReturnTypes.AllCounts"
with pytest.raises(NotImplementedError, match=does_not_support):
dev.execute(circuit)


def test_sample_fails():
"""Tests that the calling sample raises an error"""
dev = _aws_device(wires=2, shots=4)

with QuantumTape() as circuit:
qml.Hadamard(wires=0)
qml.CNOT(wires=[0, 1])
qml.sample()

does_not_support = "Unsupported return type: ObservableReturnTypes.Sample"
with pytest.raises(NotImplementedError, match=does_not_support):
dev.execute(circuit)


def test_unsupported_return_type():
"""Tests that using an unsupported return type for measurement raises an error"""
dev = _aws_device(wires=2, shots=4)

mock_measurement = Mock()
mock_measurement.return_type = Enum("ObservableReturnTypes", {"Foo": "foo"}).Foo
mock_measurement.obs = qml.PauliZ(0)
mock_measurement.wires = qml.wires.Wires([0])

tape = qml.tape.QuantumTape(measurements=[mock_measurement])

does_not_support = "Unsupported return type: ObservableReturnTypes.Foo"
with pytest.raises(NotImplementedError, match=does_not_support):
dev.execute(tape)


@patch.object(AwsDevice, "type", new_callable=mock.PropertyMock)
@patch.object(AwsDevice, "name", new_callable=mock.PropertyMock)
def test_non_circuit_device(name_mock, type_mock):