From 9c787ec65eb2127abd8abd13585f295d215f7841 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 1 May 2024 15:38:51 -0400 Subject: [PATCH] fix: Force StateVector and DensityMatrix values to be ndarrays and test --- src/braket/simulator_v2/simulator.py | 37 +++++++++++++++++++ .../test_density_matrix_simulator.py | 33 ++++++++++++++++- .../test_state_vector_simulator.py | 32 ++++++++++++++++ 3 files changed, 101 insertions(+), 1 deletion(-) diff --git a/src/braket/simulator_v2/simulator.py b/src/braket/simulator_v2/simulator.py index cd4f5e2..58f9962 100644 --- a/src/braket/simulator_v2/simulator.py +++ b/src/braket/simulator_v2/simulator.py @@ -1,5 +1,6 @@ import sys +import numpy as np from braket.default_simulator.operation_helpers import from_braket_instruction from braket.default_simulator.result_types import TargetedResultType from braket.default_simulator.simulator import BaseLocalSimulator @@ -8,7 +9,9 @@ GateModelSimulatorDeviceCapabilities, GateModelSimulatorDeviceParameters, ) +from braket.ir.jaqcd import DensityMatrix from braket.ir.jaqcd import Program as JaqcdProgram +from braket.ir.jaqcd import StateVector from braket.ir.openqasm import Program as OpenQASMProgram from braket.task_result import GateModelTaskResult @@ -88,6 +91,16 @@ def run_jaqcd( ) r = jl.simulate(self._device, [circuit_ir], qubit_count, shots) r.additionalMetadata.action = circuit_ir + if not shots: + # need to convert `list` value for `statevector` + # and `densitymatrix` result types to `np.ndarray` + for result_ind, result_type in enumerate(r.resultTypes): + if isinstance(result_type.type, StateVector) or isinstance( + result_type.type, DensityMatrix + ): + r.resultTypes[result_ind].value = np.asarray( + r.resultTypes[result_ind].value + ) return r def run_openqasm( @@ -156,6 +169,16 @@ def run_openqasm( # attach the result types if shots: r.resultTypes = results + else: + # need to convert `list` value for `statevector` + # and `densitymatrix` result types to `np.ndarray` + for result_ind, result_type in enumerate(r.resultTypes): + if isinstance(result_type.type, StateVector) or isinstance( + result_type.type, DensityMatrix + ): + r.resultTypes[result_ind].value = np.asarray( + r.resultTypes[result_ind].value + ) return r @property @@ -467,6 +490,13 @@ def run_jaqcd( ) r = jl.simulate(self._device, [circuit_ir], qubit_count, shots) r.additionalMetadata.action = circuit_ir + if not shots: + # need to convert `list` value for `densitymatrix` result type to `np.ndarray` + for result_ind, result_type in enumerate(r.resultTypes): + if isinstance(result_type.type, DensityMatrix): + r.resultTypes[result_ind].value = np.asarray( + r.resultTypes[result_ind].value + ) return r def run_openqasm( @@ -534,6 +564,13 @@ def run_openqasm( # attach the result types if shots: r.resultTypes = results + else: + # need to convert `list` value for `densitymatrix` result type to `np.ndarray` + for result_ind, result_type in enumerate(r.resultTypes): + if isinstance(result_type.type, DensityMatrix): + r.resultTypes[result_ind].value = np.asarray( + r.resultTypes[result_ind].value + ) return r """A simulator meant to run directly on the user's machine using a Julia backend. diff --git a/test/unit_tests/braket/simulator_v2/test_density_matrix_simulator.py b/test/unit_tests/braket/simulator_v2/test_density_matrix_simulator.py index c2d68f4..730ab63 100644 --- a/test/unit_tests/braket/simulator_v2/test_density_matrix_simulator.py +++ b/test/unit_tests/braket/simulator_v2/test_density_matrix_simulator.py @@ -22,7 +22,7 @@ GateModelSimulatorDeviceCapabilities, GateModelSimulatorDeviceParameters, ) -from braket.ir.jaqcd import Expectation +from braket.ir.jaqcd import DensityMatrix, Expectation from braket.ir.jaqcd import Program as JaqcdProgram from braket.ir.openqasm import Program as OpenQASMProgram from braket.task_result import AdditionalMetadata, TaskMetadata @@ -846,3 +846,34 @@ def test_measure_targets(): assert 400 < np.sum(measurements, axis=0)[0] < 600 assert len(measurements[0]) == 1 assert result.measuredQubits == [0] + + +@pytest.mark.parametrize( + "jaqcd_string, oq3_pragma, jaqcd_type", + [ + ["densitymatrix", "density_matrix", DensityMatrix()], + ], +) +def test_simulator_analytic_value_type(jaqcd_string, oq3_pragma, jaqcd_type): + simulator = DensityMatrixSimulator() + jaqcd = JaqcdProgram.parse_raw( + json.dumps( + { + "instructions": [{"type": "h", "target": 0}], + "results": [{"type": jaqcd_string}], + } + ) + ) + qasm = OpenQASMProgram( + source=f""" + qubit q; + h q; + #pragma braket result {oq3_pragma} + """ + ) + result = simulator.run(jaqcd, qubit_count=2, shots=0) + assert result.resultTypes[0].type == jaqcd_type + assert isinstance(result.resultTypes[0].value, np.ndarray) + result = simulator.run(qasm, shots=0) + assert result.resultTypes[0].type == jaqcd_type + assert isinstance(result.resultTypes[0].value, np.ndarray) diff --git a/test/unit_tests/braket/simulator_v2/test_state_vector_simulator.py b/test/unit_tests/braket/simulator_v2/test_state_vector_simulator.py index 0a24dbe..dc99b3c 100644 --- a/test/unit_tests/braket/simulator_v2/test_state_vector_simulator.py +++ b/test/unit_tests/braket/simulator_v2/test_state_vector_simulator.py @@ -1410,3 +1410,35 @@ def test_rotation_parameter_expressions(operation, state_vector): result = simulator.run(OpenQASMProgram(source=qasm), shots=0) assert result.resultTypes[0].type == StateVector() assert np.allclose(result.resultTypes[0].value, np.array(state_vector)) + + +@pytest.mark.parametrize( + "jaqcd_string, oq3_pragma, jaqcd_type", + [ + ["statevector", "state_vector", StateVector()], + ["densitymatrix", "density_matrix", DensityMatrix()], + ], +) +def test_simulator_analytic_value_type(jaqcd_string, oq3_pragma, jaqcd_type): + simulator = StateVectorSimulator() + jaqcd = JaqcdProgram.parse_raw( + json.dumps( + { + "instructions": [{"type": "h", "target": 0}], + "results": [{"type": jaqcd_string}], + } + ) + ) + qasm = OpenQASMProgram( + source=f""" + qubit q; + h q; + #pragma braket result {oq3_pragma} + """ + ) + result = simulator.run(jaqcd, qubit_count=2, shots=0) + assert result.resultTypes[0].type == jaqcd_type + assert isinstance(result.resultTypes[0].value, np.ndarray) + result = simulator.run(qasm, shots=0) + assert result.resultTypes[0].type == jaqcd_type + assert isinstance(result.resultTypes[0].value, np.ndarray)