-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add simulator module with mcm simulator implementation
- Loading branch information
Showing
12 changed files
with
763 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
|
||
""" | ||
TODO: Module description | ||
TODO: Example usage | ||
.. code-block:: python | ||
# Python code here | ||
""" | ||
|
||
from .simulator import McmSimulator # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
|
||
from functools import singledispatch | ||
from typing import Any, Union | ||
|
||
import numpy as np | ||
from braket.default_simulator.openqasm._helpers.casting import convert_bool_array_to_string | ||
from braket.default_simulator.openqasm.parser.openqasm_ast import ( | ||
ArrayLiteral, | ||
BitstringLiteral, | ||
BooleanLiteral, | ||
FloatLiteral, | ||
IntegerLiteral, | ||
) | ||
|
||
LiteralType = Union[BooleanLiteral, IntegerLiteral, FloatLiteral, ArrayLiteral, BitstringLiteral] | ||
|
||
|
||
@singledispatch | ||
def convert_to_output(value: LiteralType) -> Any: | ||
raise TypeError(f"converting {value} to output") | ||
|
||
|
||
@convert_to_output.register(IntegerLiteral) | ||
@convert_to_output.register(FloatLiteral) | ||
@convert_to_output.register(BooleanLiteral) | ||
@convert_to_output.register(BitstringLiteral) | ||
def _(value): | ||
return value.value | ||
|
||
|
||
@convert_to_output.register(BitstringLiteral) | ||
def _(value): | ||
return np.array(np.binary_repr(value.value, value.width)) | ||
|
||
|
||
@convert_to_output.register | ||
def _(value: ArrayLiteral): | ||
if isinstance(value.values[0], BooleanLiteral): | ||
return convert_bool_array_to_string(value) | ||
return np.array([convert_to_output(x) for x in value.values]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
|
||
import itertools | ||
from collections.abc import Iterable | ||
|
||
import numpy as np | ||
|
||
|
||
def measurement_sample(prob: float, target_count: int) -> tuple[int]: | ||
"""_summary_ | ||
Args: | ||
prob (float): _description_ | ||
target_count (int): _description_ | ||
Returns: | ||
tuple[int]: _description_ | ||
""" | ||
basis_states = np.array(list(itertools.product([0, 1], repeat=target_count))) | ||
outcome_idx = np.random.choice(list(range(2**target_count)), p=prob) | ||
return tuple(basis_states[outcome_idx]) | ||
|
||
|
||
def measurement_collapse_dm( | ||
dm_tensor: np.ndarray, targets: Iterable[int], outcomes: np.ndarray | ||
) -> np.ndarray: | ||
"""_summary_ | ||
Args: | ||
dm_tensor (np.ndarray): _description_ | ||
targets (Iterable[int]): _description_ | ||
outcomes (np.ndarray): _description_ | ||
Returns: | ||
np.ndarray: _description_ | ||
""" | ||
# TODO: This needs to be modified to not delete qubits | ||
|
||
# move the target qubit to the front of axes | ||
qubit_count = int(np.log2(dm_tensor.shape[0])) | ||
unused_idxs = [idx for idx in range(qubit_count) if idx not in targets] | ||
unused_idxs = [ | ||
p + i * qubit_count for i in range(2) for p in unused_idxs | ||
] # convert indices to dm form | ||
target_indx = [ | ||
p + i * qubit_count for i in range(2) for p in targets | ||
] # convert indices to dm form | ||
permutation = target_indx + unused_idxs | ||
inverse_permutation = np.argsort(permutation) | ||
|
||
# collapse the density matrix based on measuremnt outcome | ||
outcomes = tuple(i for _ in range(2) for i in outcomes) | ||
new_dm_tensor = np.zeros_like(dm_tensor) | ||
new_dm_tensor[outcomes] = np.transpose(dm_tensor, permutation)[outcomes] | ||
new_dm_tensor = np.transpose(new_dm_tensor, inverse_permutation) | ||
|
||
# normalize | ||
new_trace = np.trace(np.reshape(new_dm_tensor, (2**qubit_count, 2**qubit_count))) | ||
new_dm_tensor = new_dm_tensor / new_trace | ||
return new_dm_tensor | ||
|
||
|
||
def measurement_collapse_sv( | ||
state_vector: np.ndarray, targets: Iterable[int], outcome: np.ndarray | ||
) -> np.ndarray: | ||
"""_summary_ | ||
Args: | ||
state_vector (np.ndarray): _description_ | ||
targets (Iterable[int]): _description_ | ||
outcome (np.ndarray): _description_ | ||
Returns: | ||
np.ndarray: _description_ | ||
""" | ||
qubit_count = int(np.log2(state_vector.size)) | ||
state_tensor = state_vector.reshape([2] * qubit_count) | ||
for qubit, measurement in zip(targets, outcome): | ||
state_tensor[(slice(None),) * qubit + (int(not measurement),)] = 0 | ||
|
||
state_tensor /= np.linalg.norm(state_tensor) | ||
return state_tensor.flatten() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
|
||
from copy import deepcopy | ||
from functools import singledispatchmethod | ||
from logging import Logger | ||
from typing import Any, List, Optional, Union | ||
|
||
from braket.default_simulator.openqasm._helpers.casting import cast_to, wrap_value_into_literal | ||
from braket.default_simulator.openqasm.interpreter import Interpreter | ||
from braket.default_simulator.openqasm.parser.openqasm_ast import ( | ||
ArrayLiteral, | ||
BitType, | ||
BooleanLiteral, | ||
ClassicalDeclaration, | ||
IndexedIdentifier, | ||
IODeclaration, | ||
IOKeyword, | ||
QASMNode, | ||
QuantumMeasurement, | ||
QuantumMeasurementStatement, | ||
QuantumReset, | ||
QubitDeclaration, | ||
) | ||
from braket.default_simulator.openqasm.parser.openqasm_parser import parse | ||
from braket.default_simulator.simulation import Simulation | ||
from openqasm3.ast import IntegerLiteral | ||
|
||
from autoqasm.simulator.program_context import McmProgramContext | ||
|
||
|
||
class NativeInterpreter(Interpreter): | ||
def __init__( | ||
self, | ||
simulation: Simulation, | ||
context: Optional[McmProgramContext] = None, | ||
logger: Optional[Logger] = None, | ||
): | ||
self.simulation = simulation | ||
context = context or McmProgramContext() | ||
super().__init__(context, logger) | ||
|
||
def simulate( | ||
self, | ||
source: str, | ||
inputs: Optional[dict[str, Any]] = None, | ||
is_file: bool = False, | ||
shots: int = 1, | ||
) -> dict[str, Any]: | ||
"""_summary_ | ||
Args: | ||
source (str): _description_ | ||
inputs (Optional[dict[str, Any]]): _description_. Defaults to None. | ||
is_file (bool): _description_. Defaults to False. | ||
shots (int): _description_. Defaults to 1. | ||
Returns: | ||
dict[str, Any]: _description_ | ||
""" | ||
if inputs: | ||
self.context.load_inputs(inputs) | ||
|
||
if is_file: | ||
with open(source, encoding="utf-8", mode="r") as f: | ||
source = f.read() | ||
|
||
program = parse(source) | ||
for _ in range(shots): | ||
program_copy = deepcopy(program) | ||
self.visit(program_copy) | ||
self.context.save_output_values() | ||
self.context.num_qubits = 0 | ||
self.simulation.reset() | ||
return self.context.outputs | ||
|
||
@singledispatchmethod | ||
def visit(self, node: Union[QASMNode, List[QASMNode]]) -> Optional[QASMNode]: | ||
"""Generic visit function for an AST node""" | ||
return super().visit(node) | ||
|
||
@visit.register | ||
def _(self, node: QubitDeclaration) -> None: | ||
self.logger.debug(f"Qubit declaration: {node}") | ||
size = self.visit(node.size).value if node.size else 1 | ||
self.context.add_qubits(node.qubit.name, size) | ||
self.simulation.add_qubits(size) | ||
|
||
@visit.register | ||
def _(self, node: QuantumMeasurement) -> Union[BooleanLiteral, ArrayLiteral]: | ||
self.logger.debug(f"Quantum measurement: {node}") | ||
self.simulation.evolve(self.context.pop_instructions()) | ||
targets = self.context.get_qubits(self.visit(node.qubit)) | ||
outcome = self.simulation.measure(targets) | ||
if len(targets) > 1 or ( | ||
isinstance(node.qubit, IndexedIdentifier) | ||
and not len(node.qubit.indices[0]) == 1 | ||
and isinstance(node.qubit.indices[0], IntegerLiteral) | ||
): | ||
return ArrayLiteral([BooleanLiteral(x) for x in outcome]) | ||
return BooleanLiteral(outcome[0]) | ||
|
||
@visit.register | ||
def _(self, node: QuantumMeasurementStatement) -> Union[BooleanLiteral, ArrayLiteral]: | ||
self.logger.debug(f"Quantum measurement statement: {node}") | ||
outcome = self.visit(node.measure) | ||
current_value = self.context.get_value_by_identifier(node.target) | ||
result_type = ( | ||
BooleanLiteral | ||
if isinstance(current_value, BooleanLiteral) or current_value is None | ||
else BitType(size=IntegerLiteral(len(current_value.values))) | ||
) | ||
value = cast_to(result_type, outcome) | ||
self.context.update_value(node.target, value) | ||
|
||
@visit.register | ||
def _(self, node: QuantumReset) -> None: | ||
self.logger.debug(f"Quantum reset: {node}") | ||
self.simulation.evolve(self.context.pop_instructions()) | ||
targets = self.context.get_qubits(self.visit(node.qubits)) | ||
outcome = self.simulation.measure(targets) | ||
for qubit, result in zip(targets, outcome): | ||
if result: | ||
self.simulation.flip(qubit) | ||
|
||
@visit.register | ||
def _(self, node: IODeclaration) -> None: | ||
self.logger.debug(f"IO Declaration: {node}") | ||
if node.io_identifier == IOKeyword.output: | ||
if node.identifier.name not in self.context.outputs: | ||
self.context.add_output(node.identifier.name) | ||
self.context.declare_variable( | ||
node.identifier.name, | ||
node.type, | ||
) | ||
else: # IOKeyword.input: | ||
if node.identifier.name not in self.context.inputs: | ||
raise NameError(f"Missing input variable '{node.identifier.name}'.") | ||
init_value = wrap_value_into_literal(self.context.inputs[node.identifier.name]) | ||
declaration = ClassicalDeclaration(node.type, node.identifier, init_value) | ||
self.visit(declaration) |
Oops, something went wrong.