Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support qml.counts for circuits #267

Merged
merged 7 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/braket/pennylane_plugin/braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from pennylane import numpy as np
from pennylane.gradients import param_shift
from pennylane.measurements import (
Counts,
Expectation,
MeasurementProcess,
MeasurementTransform,
Expand All @@ -78,7 +79,7 @@

from ._version import __version__

RETURN_TYPES = [Expectation, Variance, Sample, Probability, State]
RETURN_TYPES = [Expectation, Variance, Sample, Probability, State, Counts]
MIN_SIMULATOR_BILLED_MS = 3000
OBS_LIST = (qml.PauliX, qml.PauliY, qml.PauliZ)

Expand Down Expand Up @@ -260,9 +261,7 @@ def statistics(
results = []
for mp in measurements:
if mp.return_type not in RETURN_TYPES:
raise QuantumFunctionError(
"Unsupported return type specified for observable {}".format(mp.obs.name)
)
raise QuantumFunctionError("Unsupported return type: {}".format(mp.return_type))
results.append(self._get_statistic(braket_result, mp))
return results

Expand Down
32 changes: 27 additions & 5 deletions src/braket/pennylane_plugin/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from collections import Counter
from functools import partial, reduce, singledispatch
from typing import Any, Optional, Union

Expand Down Expand Up @@ -530,7 +531,7 @@ def get_adjoint_gradient_result_type(
return AdjointGradient(observable=braket_observable, target=targets, parameters=parameters)


def translate_result_type(
def translate_result_type( # noqa: C901
measurement: MeasurementProcess, targets: list[int], supported_result_types: frozenset[str]
) -> Union[ResultType, tuple[ResultType, ...]]:
"""Translates a PennyLane ``MeasurementProcess`` into the corresponding Braket ``ResultType``.
Expand All @@ -547,6 +548,7 @@ def translate_result_type(
then this will return a result type for each term.
"""
return_type = measurement.return_type
observable = measurement.obs

if return_type is ObservableReturnTypes.Probability:
return Probability(targets)
Expand All @@ -558,19 +560,24 @@ def translate_result_type(
return DensityMatrix(targets)
raise NotImplementedError(f"Unsupported return type: {return_type}")

if isinstance(measurement.obs, (Hamiltonian, qml.Hamiltonian)):
if isinstance(observable, (Hamiltonian, qml.Hamiltonian)):
if return_type is ObservableReturnTypes.Expectation:
return tuple(
Expectation(_translate_observable(term), term.wires) for term in measurement.obs.ops
Expectation(_translate_observable(term), term.wires) for term in observable.ops
)
raise NotImplementedError(f"Return type {return_type} unsupported for Hamiltonian")

braket_observable = _translate_observable(measurement.obs)
if observable is None:
if return_type is ObservableReturnTypes.Counts:
return tuple(Sample(observables.Z(), target) for target in targets or measurement.wires)
raise NotImplementedError(f"Unsupported return type: {return_type}")

braket_observable = _translate_observable(observable)
if return_type is ObservableReturnTypes.Expectation:
return Expectation(braket_observable, targets)
elif return_type is ObservableReturnTypes.Variance:
return Variance(braket_observable, targets)
elif return_type is ObservableReturnTypes.Sample:
elif return_type in (ObservableReturnTypes.Sample, ObservableReturnTypes.Counts):
return Sample(braket_observable, targets)
else:
raise NotImplementedError(f"Unsupported return type: {return_type}")
Expand Down Expand Up @@ -698,12 +705,27 @@ def translate_result(
ag_result.value["gradient"][f"p_{i}"]
for i in sorted(key_indices)
]

if measurement.return_type is ObservableReturnTypes.Counts and observable is None:
if targets:
new_dict = {}
for key, value in braket_result.measurement_counts.items():
new_key = "".join(key[i] for i in targets)
if new_key not in new_dict:
new_dict[new_key] = 0
new_dict[new_key] += value
return new_dict

return dict(braket_result.measurement_counts)

translated = translate_result_type(measurement, targets, supported_result_types)
if isinstance(observable, (Hamiltonian, qml.Hamiltonian)):
coeffs, _ = observable.terms()
return sum(
coeff * braket_result.get_value_by_result_type(result_type)
for coeff, result_type in zip(coeffs, translated)
)
elif measurement.return_type is ObservableReturnTypes.Counts:
return dict(Counter(braket_result.get_value_by_result_type(translated)))
else:
return braket_result.get_value_by_result_type(translated)
87 changes: 87 additions & 0 deletions test/integ_tests/test_counts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

"""Tests that counts are correctly computed in the plugin device"""

import numpy as np
import pennylane as qml
import pytest

np.random.seed(42)


@pytest.mark.parametrize("shots", [8192])
class TestCounts:
"""Tests for the count return type"""

def test_counts_values(self, device, shots, tol):
"""Tests if the result returned by counts have
the correct values
"""
dev = device(2)

@qml.qnode(dev)
def circuit():
qml.RX(np.pi / 3, wires=0)
return qml.counts()

result = circuit().item()

# The sample should only contain 00 and 10
assert "00" in result
assert "10" in result
assert "11" not in result
assert "01" not in result
assert result["00"] + result["10"] == shots
assert np.allclose(result["00"] / shots, 0.75, **tol)
assert np.allclose(result["10"] / shots, 0.25, **tol)

def test_counts_values_specify_target(self, device, shots, tol):
"""Tests if the result returned by counts have
the correct values when specifying a target
"""
dev = device(2)

@qml.qnode(dev)
def circuit():
qml.RX(np.pi / 3, wires=0)
return qml.counts(wires=[0])

result = circuit().item()

# The sample should only contain 00 and 10
assert "0" in result
assert "1" in result
assert result["0"] + result["1"] == shots
assert np.allclose(result["0"] / shots, 0.75, **tol)
assert np.allclose(result["1"] / shots, 0.25, **tol)

def test_counts_values_with_observable(self, device, shots, tol):
"""Tests if the result returned by counts have
the correct values when specifying an observable
"""
dev = device(2)

@qml.qnode(dev)
def circuit():
qml.RX(np.pi / 3, wires=0)
return qml.counts(op=qml.PauliZ(wires=[0]))

result = circuit().item()

# The sample should only contain 00 and 10
assert -1 in result
assert 1 in result
assert result[-1] + result[1] == shots
assert np.allclose(result[1] / shots, 0.75, **tol)
assert np.allclose(result[-1] / shots, 0.25, **tol)
172 changes: 169 additions & 3 deletions test/unit_tests/test_braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# language governing permissions and limitations under the License.

import json
from collections import Counter
from enum import Enum
from typing import Any, Optional
from unittest import mock
from unittest.mock import Mock, PropertyMock, patch
Expand Down Expand Up @@ -862,9 +864,8 @@ def test_parametrized_evolution_in_oqc_lucy_supported_ops():
def test_bad_statistics():
"""Test if a QuantumFunctionError is raised for an invalid return type"""
dev = _aws_device(wires=1, foo="bar")
observable = qml.Identity(wires=0)
tape = qml.tape.QuantumTape(measurements=[qml.counts(observable)])
with pytest.raises(QuantumFunctionError, match="Unsupported return type specified"):
tape = qml.tape.QuantumTape(measurements=[qml.classical_shadow(wires=[0])])
with pytest.raises(QuantumFunctionError, match="Unsupported return type:"):
dev.statistics(None, tape.measurements)


Expand Down Expand Up @@ -1233,6 +1234,171 @@ def test_execute_some_samples(mock_run):
assert results[1] == 0.0


@patch.object(AwsDevice, "run")
@pytest.mark.parametrize(
"num_wires, op, wires, measurements, measurement_counts, result_types, expected_result",
[
(
2,
None,
None,
[[0, 0], [1, 1], [0, 0], [1, 1]],
Counter({"00": 2, "11": 2}),
[
{
"type": {"observable": ["z", "z"], "targets": [0, 1], "type": "sample"},
"value": [1, 1, 1, 1],
},
],
{"00": 2, "11": 2},
),
(
2,
qml.PauliZ(0),
None,
[[0, 0], [1, 1], [0, 0], [1, 1]],
Counter({"00": 2, "11": 2}),
[
{
"type": {"observable": ["z"], "targets": [0], "type": "sample"},
"value": [1, -1, 1, -1],
},
],
{1: 2, -1: 2},
),
(
2,
None,
[0],
[[0, 0], [1, 1], [0, 0], [1, 1]],
Counter({"00": 2, "11": 2}),
[
{
"type": {"observable": ["z", "z"], "targets": [0, 1], "type": "sample"},
"value": [1, 1, 1, 1],
},
],
{"0": 2, "1": 2},
),
(
3,
None,
[2],
[[0, 0, 0], [1, 1, 0], [0, 0, 0], [1, 1, 0]],
Counter({"000": 2, "110": 2}),
[
{
"type": {"observable": ["z"], "targets": [2], "type": "sample"},
"value": [1, 1, 1, 1],
},
],
{"0": 4},
),
],
)
def test_execute_counts(
mock_run,
num_wires,
op,
wires,
measurements,
measurement_counts,
result_types,
expected_result,
):
result = GateModelQuantumTaskResult.from_string(
json.dumps(
{
"braketSchemaHeader": {
"name": "braket.task_result.gate_model_task_result",
"version": "1",
},
"measurements": measurements,
"measurement_counts": measurement_counts,
"resultTypes": result_types,
"measuredQubits": [0, 1],
"taskMetadata": {
"braketSchemaHeader": {
"name": "braket.task_result.task_metadata",
"version": "1",
},
"id": "task_arn",
"shots": 4,
"deviceId": "default",
},
"additionalMetadata": {
"action": {
"braketSchemaHeader": {
"name": "braket.ir.openqasm.program",
"version": "1",
},
"source": "qubit[2] q; cnot q[0], q[1]; measure q;",
},
},
}
)
)

task = Mock()
task.result.return_value = result
mock_run.return_value = task

dev = _aws_device(wires=num_wires, shots=4)

with QuantumTape() as circuit:
qml.Hadamard(wires=0)
qml.CNOT(wires=[0, 1])
qml.counts(op=op, wires=wires)

results = dev.execute(circuit)

assert results == expected_result


def test_counts_all_outcomes_fails():
"""Tests that the calling counts with 'all_outcomes=True' raises an error"""
dev = _aws_device(wires=2, shots=4)

with QuantumTape() as circuit:
qml.Hadamard(wires=0)
qml.CNOT(wires=[0, 1])
qml.counts(all_outcomes=True)

does_not_support = "Unsupported return type: ObservableReturnTypes.AllCounts"
with pytest.raises(NotImplementedError, match=does_not_support):
dev.execute(circuit)


def test_sample_fails():
"""Tests that the calling sample raises an error"""
dev = _aws_device(wires=2, shots=4)

with QuantumTape() as circuit:
qml.Hadamard(wires=0)
qml.CNOT(wires=[0, 1])
qml.sample()

does_not_support = "Unsupported return type: ObservableReturnTypes.Sample"
with pytest.raises(NotImplementedError, match=does_not_support):
dev.execute(circuit)


def test_unsupported_return_type():
"""Tests that using an unsupported return type for measurement raises an error"""
dev = _aws_device(wires=2, shots=4)

mock_measurement = Mock()
mock_measurement.return_type = Enum("ObservableReturnTypes", {"Foo": "foo"}).Foo
mock_measurement.obs = qml.PauliZ(0)
mock_measurement.wires = qml.wires.Wires([0])

tape = qml.tape.QuantumTape(measurements=[mock_measurement])

does_not_support = "Unsupported return type: ObservableReturnTypes.Foo"
with pytest.raises(NotImplementedError, match=does_not_support):
dev.execute(tape)


@patch.object(AwsDevice, "type", new_callable=mock.PropertyMock)
@patch.object(AwsDevice, "name", new_callable=mock.PropertyMock)
def test_non_circuit_device(name_mock, type_mock):
Expand Down
Loading