From 8d088305fe620779ef31d81bd841019c3499f382 Mon Sep 17 00:00:00 2001 From: Kshitij Chhabra Date: Mon, 17 Aug 2020 11:21:33 -0700 Subject: [PATCH] Move casts to a shared method --- src/braket/aws/aws_quantum_task.py | 7 +---- .../tasks/gate_model_quantum_task_result.py | 31 +++++++++++++------ 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/src/braket/aws/aws_quantum_task.py b/src/braket/aws/aws_quantum_task.py index 4590f8579..36433313c 100644 --- a/src/braket/aws/aws_quantum_task.py +++ b/src/braket/aws/aws_quantum_task.py @@ -426,12 +426,7 @@ def _format_result(result): @_format_result.register def _(result: GateModelTaskResult) -> GateModelQuantumTaskResult: - if result.resultTypes: - for result_type in result.resultTypes: - type = result_type.type.type - if type == "amplitude": - for state in result_type.value: - result_type.value[state] = complex(*result_type.value[state]) + GateModelQuantumTaskResult.cast_result_types(result) return GateModelQuantumTaskResult.from_object(result) diff --git a/src/braket/tasks/gate_model_quantum_task_result.py b/src/braket/tasks/gate_model_quantum_task_result.py index caabb9c5f..464d028b3 100644 --- a/src/braket/tasks/gate_model_quantum_task_result.py +++ b/src/braket/tasks/gate_model_quantum_task_result.py @@ -217,16 +217,7 @@ def from_string(result: str) -> GateModelQuantumTaskResult: in the result dict """ obj = GateModelTaskResult.parse_raw(result) - if obj.resultTypes: - for result_type in obj.resultTypes: - type = result_type.type.type - if type == "probability": - result_type.value = np.array(result_type.value) - elif type == "statevector": - result_type.value = np.array([complex(*value) for value in result_type.value]) - elif type == "amplitude": - for state in result_type.value: - result_type.value[state] = complex(*result_type.value[state]) + GateModelQuantumTaskResult.cast_result_types(obj) return GateModelQuantumTaskResult._from_object_internal(obj) @classmethod @@ -303,6 +294,26 @@ def _from_dict_internal_simulator_only(cls, result: GateModelTaskResult): values=values, ) + @staticmethod + def cast_result_types(gate_model_task_result: GateModelTaskResult) -> None: + """ + Casts the result types to the types expected by the SDK. + + Args: + gate_model_task_result (GateModelTaskResult): GateModelTaskResult representing the + results. + """ + if gate_model_task_result.resultTypes: + for result_type in gate_model_task_result.resultTypes: + type = result_type.type.type + if type == "probability": + result_type.value = np.array(result_type.value) + elif type == "statevector": + result_type.value = np.array([complex(*value) for value in result_type.value]) + elif type == "amplitude": + for state in result_type.value: + result_type.value[state] = complex(*result_type.value[state]) + @staticmethod def _calculate_result_types( ir_string: str, measurements: np.ndarray, measured_qubits: List[int]