Skip to content

Commit

Permalink
Move casts to a shared method
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitijc committed Aug 17, 2020
1 parent 7104e53 commit 8d08830
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
7 changes: 1 addition & 6 deletions src/braket/aws/aws_quantum_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,12 +426,7 @@ def _format_result(result):

@_format_result.register
def _(result: GateModelTaskResult) -> GateModelQuantumTaskResult:
if result.resultTypes:
for result_type in result.resultTypes:
type = result_type.type.type
if type == "amplitude":
for state in result_type.value:
result_type.value[state] = complex(*result_type.value[state])
GateModelQuantumTaskResult.cast_result_types(result)
return GateModelQuantumTaskResult.from_object(result)


Expand Down
31 changes: 21 additions & 10 deletions src/braket/tasks/gate_model_quantum_task_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,16 +217,7 @@ def from_string(result: str) -> GateModelQuantumTaskResult:
in the result dict
"""
obj = GateModelTaskResult.parse_raw(result)
if obj.resultTypes:
for result_type in obj.resultTypes:
type = result_type.type.type
if type == "probability":
result_type.value = np.array(result_type.value)
elif type == "statevector":
result_type.value = np.array([complex(*value) for value in result_type.value])
elif type == "amplitude":
for state in result_type.value:
result_type.value[state] = complex(*result_type.value[state])
GateModelQuantumTaskResult.cast_result_types(obj)
return GateModelQuantumTaskResult._from_object_internal(obj)

@classmethod
Expand Down Expand Up @@ -303,6 +294,26 @@ def _from_dict_internal_simulator_only(cls, result: GateModelTaskResult):
values=values,
)

@staticmethod
def cast_result_types(gate_model_task_result: GateModelTaskResult) -> None:
"""
Casts the result types to the types expected by the SDK.
Args:
gate_model_task_result (GateModelTaskResult): GateModelTaskResult representing the
results.
"""
if gate_model_task_result.resultTypes:
for result_type in gate_model_task_result.resultTypes:
type = result_type.type.type
if type == "probability":
result_type.value = np.array(result_type.value)
elif type == "statevector":
result_type.value = np.array([complex(*value) for value in result_type.value])
elif type == "amplitude":
for state in result_type.value:
result_type.value[state] = complex(*result_type.value[state])

@staticmethod
def _calculate_result_types(
ir_string: str, measurements: np.ndarray, measured_qubits: List[int]
Expand Down

0 comments on commit 8d08830

Please sign in to comment.