diff --git a/src/braket/pennylane_plugin/braket_device.py b/src/braket/pennylane_plugin/braket_device.py index c6d6c87f..86f8f149 100644 --- a/src/braket/pennylane_plugin/braket_device.py +++ b/src/braket/pennylane_plugin/braket_device.py @@ -55,6 +55,7 @@ from pennylane import numpy as np from pennylane.gradients import param_shift from pennylane.measurements import ( + Counts, Expectation, MeasurementProcess, MeasurementTransform, @@ -78,7 +79,7 @@ from ._version import __version__ -RETURN_TYPES = [Expectation, Variance, Sample, Probability, State] +RETURN_TYPES = [Expectation, Variance, Sample, Probability, State, Counts] MIN_SIMULATOR_BILLED_MS = 3000 OBS_LIST = (qml.PauliX, qml.PauliY, qml.PauliZ) @@ -260,9 +261,7 @@ def statistics( results = [] for mp in measurements: if mp.return_type not in RETURN_TYPES: - raise QuantumFunctionError( - "Unsupported return type specified for observable {}".format(mp.obs.name) - ) + raise QuantumFunctionError("Unsupported return type: {}".format(mp.return_type)) results.append(self._get_statistic(braket_result, mp)) return results diff --git a/src/braket/pennylane_plugin/translation.py b/src/braket/pennylane_plugin/translation.py index feb69525..f75f5c0b 100644 --- a/src/braket/pennylane_plugin/translation.py +++ b/src/braket/pennylane_plugin/translation.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +from collections import Counter from functools import partial, reduce, singledispatch from typing import Any, Optional, Union @@ -530,7 +531,7 @@ def get_adjoint_gradient_result_type( return AdjointGradient(observable=braket_observable, target=targets, parameters=parameters) -def translate_result_type( +def translate_result_type( # noqa: C901 measurement: MeasurementProcess, targets: list[int], supported_result_types: frozenset[str] ) -> Union[ResultType, tuple[ResultType, ...]]: """Translates a PennyLane ``MeasurementProcess`` into the corresponding Braket ``ResultType``. @@ -547,6 +548,7 @@ def translate_result_type( then this will return a result type for each term. """ return_type = measurement.return_type + observable = measurement.obs if return_type is ObservableReturnTypes.Probability: return Probability(targets) @@ -558,19 +560,24 @@ def translate_result_type( return DensityMatrix(targets) raise NotImplementedError(f"Unsupported return type: {return_type}") - if isinstance(measurement.obs, (Hamiltonian, qml.Hamiltonian)): + if isinstance(observable, (Hamiltonian, qml.Hamiltonian)): if return_type is ObservableReturnTypes.Expectation: return tuple( - Expectation(_translate_observable(term), term.wires) for term in measurement.obs.ops + Expectation(_translate_observable(term), term.wires) for term in observable.ops ) raise NotImplementedError(f"Return type {return_type} unsupported for Hamiltonian") - braket_observable = _translate_observable(measurement.obs) + if observable is None: + if return_type is ObservableReturnTypes.Counts: + return tuple(Sample(observables.Z(), target) for target in targets or measurement.wires) + raise NotImplementedError(f"Unsupported return type: {return_type}") + + braket_observable = _translate_observable(observable) if return_type is ObservableReturnTypes.Expectation: return Expectation(braket_observable, targets) elif return_type is ObservableReturnTypes.Variance: return Variance(braket_observable, targets) - elif return_type is ObservableReturnTypes.Sample: + elif return_type in (ObservableReturnTypes.Sample, ObservableReturnTypes.Counts): return Sample(braket_observable, targets) else: raise NotImplementedError(f"Unsupported return type: {return_type}") @@ -698,6 +705,19 @@ def translate_result( ag_result.value["gradient"][f"p_{i}"] for i in sorted(key_indices) ] + + if measurement.return_type is ObservableReturnTypes.Counts and observable is None: + if targets: + new_dict = {} + for key, value in braket_result.measurement_counts.items(): + new_key = "".join(key[i] for i in targets) + if new_key not in new_dict: + new_dict[new_key] = 0 + new_dict[new_key] += value + return new_dict + + return dict(braket_result.measurement_counts) + translated = translate_result_type(measurement, targets, supported_result_types) if isinstance(observable, (Hamiltonian, qml.Hamiltonian)): coeffs, _ = observable.terms() @@ -705,5 +725,7 @@ def translate_result( coeff * braket_result.get_value_by_result_type(result_type) for coeff, result_type in zip(coeffs, translated) ) + elif measurement.return_type is ObservableReturnTypes.Counts: + return dict(Counter(braket_result.get_value_by_result_type(translated))) else: return braket_result.get_value_by_result_type(translated) diff --git a/test/integ_tests/test_counts.py b/test/integ_tests/test_counts.py new file mode 100644 index 00000000..48c1eeec --- /dev/null +++ b/test/integ_tests/test_counts.py @@ -0,0 +1,87 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +"""Tests that counts are correctly computed in the plugin device""" + +import numpy as np +import pennylane as qml +import pytest + +np.random.seed(42) + + +@pytest.mark.parametrize("shots", [8192]) +class TestCounts: + """Tests for the count return type""" + + def test_counts_values(self, device, shots, tol): + """Tests if the result returned by counts have + the correct values + """ + dev = device(2) + + @qml.qnode(dev) + def circuit(): + qml.RX(np.pi / 3, wires=0) + return qml.counts() + + result = circuit().item() + + # The sample should only contain 00 and 10 + assert "00" in result + assert "10" in result + assert "11" not in result + assert "01" not in result + assert result["00"] + result["10"] == shots + assert np.allclose(result["00"] / shots, 0.75, **tol) + assert np.allclose(result["10"] / shots, 0.25, **tol) + + def test_counts_values_specify_target(self, device, shots, tol): + """Tests if the result returned by counts have + the correct values when specifying a target + """ + dev = device(2) + + @qml.qnode(dev) + def circuit(): + qml.RX(np.pi / 3, wires=0) + return qml.counts(wires=[0]) + + result = circuit().item() + + # The sample should only contain 00 and 10 + assert "0" in result + assert "1" in result + assert result["0"] + result["1"] == shots + assert np.allclose(result["0"] / shots, 0.75, **tol) + assert np.allclose(result["1"] / shots, 0.25, **tol) + + def test_counts_values_with_observable(self, device, shots, tol): + """Tests if the result returned by counts have + the correct values when specifying an observable + """ + dev = device(2) + + @qml.qnode(dev) + def circuit(): + qml.RX(np.pi / 3, wires=0) + return qml.counts(op=qml.PauliZ(wires=[0])) + + result = circuit().item() + + # The sample should only contain 00 and 10 + assert -1 in result + assert 1 in result + assert result[-1] + result[1] == shots + assert np.allclose(result[1] / shots, 0.75, **tol) + assert np.allclose(result[-1] / shots, 0.25, **tol) diff --git a/test/unit_tests/test_braket_device.py b/test/unit_tests/test_braket_device.py index 7d818570..6890dc86 100644 --- a/test/unit_tests/test_braket_device.py +++ b/test/unit_tests/test_braket_device.py @@ -12,6 +12,8 @@ # language governing permissions and limitations under the License. import json +from collections import Counter +from enum import Enum from typing import Any, Optional from unittest import mock from unittest.mock import Mock, PropertyMock, patch @@ -862,9 +864,8 @@ def test_parametrized_evolution_in_oqc_lucy_supported_ops(): def test_bad_statistics(): """Test if a QuantumFunctionError is raised for an invalid return type""" dev = _aws_device(wires=1, foo="bar") - observable = qml.Identity(wires=0) - tape = qml.tape.QuantumTape(measurements=[qml.counts(observable)]) - with pytest.raises(QuantumFunctionError, match="Unsupported return type specified"): + tape = qml.tape.QuantumTape(measurements=[qml.classical_shadow(wires=[0])]) + with pytest.raises(QuantumFunctionError, match="Unsupported return type:"): dev.statistics(None, tape.measurements) @@ -1233,6 +1234,171 @@ def test_execute_some_samples(mock_run): assert results[1] == 0.0 +@patch.object(AwsDevice, "run") +@pytest.mark.parametrize( + "num_wires, op, wires, measurements, measurement_counts, result_types, expected_result", + [ + ( + 2, + None, + None, + [[0, 0], [1, 1], [0, 0], [1, 1]], + Counter({"00": 2, "11": 2}), + [ + { + "type": {"observable": ["z", "z"], "targets": [0, 1], "type": "sample"}, + "value": [1, 1, 1, 1], + }, + ], + {"00": 2, "11": 2}, + ), + ( + 2, + qml.PauliZ(0), + None, + [[0, 0], [1, 1], [0, 0], [1, 1]], + Counter({"00": 2, "11": 2}), + [ + { + "type": {"observable": ["z"], "targets": [0], "type": "sample"}, + "value": [1, -1, 1, -1], + }, + ], + {1: 2, -1: 2}, + ), + ( + 2, + None, + [0], + [[0, 0], [1, 1], [0, 0], [1, 1]], + Counter({"00": 2, "11": 2}), + [ + { + "type": {"observable": ["z", "z"], "targets": [0, 1], "type": "sample"}, + "value": [1, 1, 1, 1], + }, + ], + {"0": 2, "1": 2}, + ), + ( + 3, + None, + [2], + [[0, 0, 0], [1, 1, 0], [0, 0, 0], [1, 1, 0]], + Counter({"000": 2, "110": 2}), + [ + { + "type": {"observable": ["z"], "targets": [2], "type": "sample"}, + "value": [1, 1, 1, 1], + }, + ], + {"0": 4}, + ), + ], +) +def test_execute_counts( + mock_run, + num_wires, + op, + wires, + measurements, + measurement_counts, + result_types, + expected_result, +): + result = GateModelQuantumTaskResult.from_string( + json.dumps( + { + "braketSchemaHeader": { + "name": "braket.task_result.gate_model_task_result", + "version": "1", + }, + "measurements": measurements, + "measurement_counts": measurement_counts, + "resultTypes": result_types, + "measuredQubits": [0, 1], + "taskMetadata": { + "braketSchemaHeader": { + "name": "braket.task_result.task_metadata", + "version": "1", + }, + "id": "task_arn", + "shots": 4, + "deviceId": "default", + }, + "additionalMetadata": { + "action": { + "braketSchemaHeader": { + "name": "braket.ir.openqasm.program", + "version": "1", + }, + "source": "qubit[2] q; cnot q[0], q[1]; measure q;", + }, + }, + } + ) + ) + + task = Mock() + task.result.return_value = result + mock_run.return_value = task + + dev = _aws_device(wires=num_wires, shots=4) + + with QuantumTape() as circuit: + qml.Hadamard(wires=0) + qml.CNOT(wires=[0, 1]) + qml.counts(op=op, wires=wires) + + results = dev.execute(circuit) + + assert results == expected_result + + +def test_counts_all_outcomes_fails(): + """Tests that the calling counts with 'all_outcomes=True' raises an error""" + dev = _aws_device(wires=2, shots=4) + + with QuantumTape() as circuit: + qml.Hadamard(wires=0) + qml.CNOT(wires=[0, 1]) + qml.counts(all_outcomes=True) + + does_not_support = "Unsupported return type: ObservableReturnTypes.AllCounts" + with pytest.raises(NotImplementedError, match=does_not_support): + dev.execute(circuit) + + +def test_sample_fails(): + """Tests that the calling sample raises an error""" + dev = _aws_device(wires=2, shots=4) + + with QuantumTape() as circuit: + qml.Hadamard(wires=0) + qml.CNOT(wires=[0, 1]) + qml.sample() + + does_not_support = "Unsupported return type: ObservableReturnTypes.Sample" + with pytest.raises(NotImplementedError, match=does_not_support): + dev.execute(circuit) + + +def test_unsupported_return_type(): + """Tests that using an unsupported return type for measurement raises an error""" + dev = _aws_device(wires=2, shots=4) + + mock_measurement = Mock() + mock_measurement.return_type = Enum("ObservableReturnTypes", {"Foo": "foo"}).Foo + mock_measurement.obs = qml.PauliZ(0) + mock_measurement.wires = qml.wires.Wires([0]) + + tape = qml.tape.QuantumTape(measurements=[mock_measurement]) + + does_not_support = "Unsupported return type: ObservableReturnTypes.Foo" + with pytest.raises(NotImplementedError, match=does_not_support): + dev.execute(tape) + + @patch.object(AwsDevice, "type", new_callable=mock.PropertyMock) @patch.object(AwsDevice, "name", new_callable=mock.PropertyMock) def test_non_circuit_device(name_mock, type_mock): diff --git a/test/unit_tests/test_translation.py b/test/unit_tests/test_translation.py index 150762c0..d75aebc2 100644 --- a/test/unit_tests/test_translation.py +++ b/test/unit_tests/test_translation.py @@ -757,8 +757,7 @@ def test_translate_result_type_state_unimplemented(): def test_translate_result_type_unsupported_return(): """Tests if a NotImplementedError is raised by translate_result_type for an unknown return_type""" - obs = qml.Hadamard(wires=0) - tape = qml.tape.QuantumTape(measurements=[qml.counts(obs)]) + tape = qml.tape.QuantumTape(measurements=[qml.purity(wires=[0])]) with pytest.raises(NotImplementedError, match="Unsupported return type"): translate_result_type(tape.measurements[0], [0], frozenset())