Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Endpoints #8

Merged
merged 2 commits into from
Dec 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
*.swp
*.idea
*.DS_Store
build_files.tar.gz

.ycm_extra_conf.py
.tox
Expand Down
62 changes: 46 additions & 16 deletions src/braket/aws/aws_qpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import boto3
from braket.aws.aws_qpu_arns import AwsQpuArns
from braket.aws.aws_quantum_task import AwsQuantumTask
from braket.aws.aws_session import AwsSession
from braket.devices.device import Device
Expand All @@ -22,17 +23,29 @@ class AwsQpu(Device):
Use this class to retrieve the latest metadata about the QPU, and run a circuit on the QPU.
"""

QPU_REGIONS = {AwsQpuArns.RIGETTI: ["us-west-1"], AwsQpuArns.IONQ: ["us-east-1"]}

def __init__(self, arn: str, aws_session=None):
"""
Args:
arn (str): QPU ARN, e.g. "arn:aws:aqx:::qpu:ionq"
aws_session (AwsSession, optional) aws_session: AWS session object. Default = None.

Raises:
ValueError: If unknown `arn` is supplied.

Note:
QPUs are physically located in specific AWS regions. If the supplied `aws_session`
is connected to a region that the QPU is not in then a cloned `aws_session`
will be created for the QPU region.

See `braket.aws.aws_qpu.AwsQpu.QPU_REGIONS` for the regions the QPUs are located in.
"""
super().__init__(
name=None, status=None, status_reason=None, supported_quantum_operations=None
)
self._arn = arn
self._aws_session = aws_session or AwsSession()
self._aws_session = self._aws_session_for_qpu(arn, aws_session)
self._qubit_count: int = None
# TODO: convert into graph object of qubits, type TBD
self._connectivity_graph = None
Expand All @@ -54,11 +67,11 @@ def run(self, *aws_quantum_task_args, **aws_quantum_task_kwargs) -> AwsQuantumTa

Examples:
>>> circuit = Circuit().h(0).cnot(0, 1)
>>> device = AwsQpu("ionq_arn")
>>> device = AwsQpu("arn:aws:aqx:::qpu:rigetti")
>>> device.run(circuit, ("bucket-foo", "key-bar"))

>>> circuit = Circuit().h(0).cnot(0, 1)
>>> device = AwsQpu("ionq_arn")
>>> device = AwsQpu("arn:aws:aqx:::qpu:rigetti")
>>> device.run(circuit=circuit, s3_destination_folder=("bucket-foo", "key-bar"))

See Also:
Expand All @@ -84,28 +97,45 @@ def refresh_metadata(self) -> None:

@property
def arn(self) -> str:
"""
Return arn of QPU

:rtype: str
"""
"""str: Return arn of QPU."""
return self._arn

@property
def qubit_count(self) -> int:
"""
Return maximum number of qubits that can be run on QPU

:rtype: int
"""
"""int: Return maximum number of qubits that can be run on QPU."""
return self._qubit_count

@property
def connectivity_graph(self):
"""Return connectivity graph of QPU."""
return self._connectivity_graph

def _aws_session_for_qpu(self, qpu_arn: str, aws_session: AwsSession) -> AwsSession:
"""
Return connectivity graph of QPU
Get an AwsSession for the QPU ARN. QPUs are only available in certain regions so any
supplied AwsSession in a region the QPU doesn't support will need to be adjusted.
"""
return self._connectivity_graph

qpu_regions = AwsQpu.QPU_REGIONS.get(qpu_arn, [])
if not qpu_regions:
raise ValueError(f"Unknown QPU {qpu_arn} was supplied.")

if aws_session:
if aws_session.boto_session.region_name in qpu_regions:
return aws_session
else:
creds = aws_session.boto_session.get_credentials()
boto_session = boto3.Session(
aws_access_key_id=creds.access_key,
aws_secret_access_key=creds.secret_key,
aws_session_token=creds.token,
profile_name=aws_session.boto_session.profile_name,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this OK if profile_name wasn't specified by user?

Copy link
Contributor Author

@dbolt dbolt Dec 13, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes it's safe. if profile_name wasn't specified then it will be "default".

(braket-sdk-3.7.4) laptop:/Volumes/workspaces/braket/braket-sdk  endpoints → echo $AWS_PROFILE

(braket-sdk-3.7.4) laptop:/Volumes/workspaces/braket/braket-sdk  endpoints → python
Python 3.7.4 (default, Oct  4 2019, 13:24:31) 
[Clang 10.0.0 (clang-1000.11.45.5)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import boto3
>>> session = boto3.Session()
>>> session.profile_name
'default'
>>> 

region_name=qpu_regions[0],
)
return AwsSession(boto_session=boto_session)
else:
boto_session = boto3.Session(region_name=qpu_regions[0])
return AwsSession(boto_session=boto_session)

def __repr__(self):
return "QPU('name': {}, 'arn': {})".format(self.name, self.arn)
Expand Down
2 changes: 2 additions & 0 deletions src/braket/aws/aws_quantum_simulator_arns.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@

class AwsQuantumSimulatorArns(str, Enum):
QS1 = "arn:aws:aqx:::quantum-simulator:aqx:qs1"
QS2 = "arn:aws:aqx:::quantum-simulator:aqx:qs2"
QS3 = "arn:aws:aqx:::quantum-simulator:aqx:qs3"
42 changes: 29 additions & 13 deletions src/braket/aws/aws_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,44 @@

import boto3

# TODO: remove this once we have prod stages
PDX_BETA_URL = "https://4tetbmwz5k.execute-api.us-west-2.amazonaws.com/Prod"


class AwsSession(object):
S3DestinationFolder = NamedTuple("S3DestinationFolder", [("bucket", str), ("key", int)])
"""Manage interactions with AWS services."""

"""
Manage interactions with AWS services
S3DestinationFolder = NamedTuple("S3DestinationFolder", [("bucket", str), ("key", int)])

Args:
boto_session: boto3 session object
braket_client: boto3 braket client
"""
BRAKET_ENDPOINTS = {
"us-west-1": "https://fdoco1n1x7.execute-api.us-west-1.amazonaws.com/Prod",
"us-west-2": "https://xe15dbdvw6.execute-api.us-west-2.amazonaws.com/Prod",
"us-east-1": "https://kqjovr0n70.execute-api.us-east-1.amazonaws.com/Prod",
}

# similar to sagemaker sdk:
# https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/session.py
def __init__(self, boto_session=None, braket_client=None):
"""
Args:
boto_session: boto3 session object
braket_client: boto3 braket client

Raises:
ValueError: If Amazon Braket does not exist for the `boto_session`'s region.
"""

self.boto_session = boto_session or boto3.Session()
self.braket_client = braket_client or self.boto_session.client(
"aqx", endpoint_url=PDX_BETA_URL
)

if braket_client:
self.braket_client = braket_client
else:
region = self.boto_session.region_name
endpoint = AwsSession.BRAKET_ENDPOINTS.get(region, None)
if not endpoint:
supported_regions = list(AwsSession.BRAKET_ENDPOINTS.keys())
raise ValueError(
f"No braket endpoint for {region}, supported regions are {supported_regions}"
)

self.braket_client = self.boto_session.client("aqx", endpoint_url=endpoint)

#
# Quantum Tasks
Expand Down
5 changes: 5 additions & 0 deletions test/integ_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,8 @@ def s3_prefix():

# strip off the filename extension and test/
return current_test_path.rsplit(".py")[0].replace("test/", "")


@pytest.fixture(scope="module")
def s3_destination_folder(s3_bucket, s3_prefix):
return AwsSession.S3DestinationFolder(s3_bucket, s3_prefix)
40 changes: 23 additions & 17 deletions test/integ_tests/test_device_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,34 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import test_common
from braket.aws import AwsQpu, AwsQuantumSimulator
import pytest
from braket.aws import AwsQpu, AwsQpuArns, AwsQuantumSimulator, AwsQuantumSimulatorArns


def test_aws_qpu_actual(aws_session):
qpu_arn = test_common.TEST_QPU_ARN
@pytest.mark.parametrize(
"qpu_arn,qpu_name", [(AwsQpuArns.RIGETTI, "Rigetti"), (AwsQpuArns.IONQ, "IonQ")]
)
def test_qpu_creation(qpu_arn, qpu_name, aws_session):
qpu = AwsQpu(qpu_arn, aws_session=aws_session)
assert qpu.arn == qpu_arn
assert qpu.connectivity_graph == {"0": ["1", "2"], "1": ["0", "2"], "2": ["0", "1"]}
assert qpu.name == "integ_test_qpu"
assert qpu.qubit_count == 16
assert qpu.status == "AVAILABLE"
assert qpu.status_reason == "Up and running"
assert qpu.supported_quantum_operations == ["CNOT", "H", "RZ", "RY", "RZ", "T"]
assert qpu.name == qpu_name


def test_get_simulator_metadata_actual(aws_session):
simulator_arn = test_common.TEST_SIMULATOR_ARN
def test_device_across_regions(aws_session):
# assert QPUs across different regions can be created using the same aws_session
AwsQpu(AwsQpuArns.RIGETTI, aws_session)
AwsQpu(AwsQpuArns.IONQ, aws_session)


@pytest.mark.parametrize(
"simulator_arn,simulator_name",
[
(AwsQuantumSimulatorArns.QS1, "quantum-simulator-1"),
(AwsQuantumSimulatorArns.QS2, "quantum-simulator-2"),
(AwsQuantumSimulatorArns.QS3, "quantum-simulator-3"),
],
)
def test_simulator_creation(simulator_arn, simulator_name, aws_session):
simulator = AwsQuantumSimulator(simulator_arn, aws_session=aws_session)
assert simulator.arn == simulator_arn
assert simulator.status_reason == "Under maintenance"
assert simulator.status == "UNAVAILABLE"
assert simulator.qubit_count == 30
assert simulator.name == "integ_test_simulator"
assert simulator.supported_quantum_operations == ["CNOT", "H", "RZ", "RY", "RZ", "T"]
assert simulator.name == simulator_name
53 changes: 40 additions & 13 deletions test/integ_tests/test_simulator_quantum_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,53 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import pytest
from braket.aws import AwsQuantumSimulator, AwsQuantumSimulatorArns
from braket.circuits import Circuit

# TODO: sad path once we have exception types in API

@pytest.mark.parametrize(
"simulator_arn", [AwsQuantumSimulatorArns.QS1, AwsQuantumSimulatorArns.QS3]
)
def test_bell_pair(simulator_arn, aws_session, s3_destination_folder):
device = AwsQuantumSimulator(simulator_arn, aws_session)
bell = Circuit().h(0).cnot(0, 1)
result = device.run(bell, s3_destination_folder, shots=750).result()

def test_simulator_quantum_task(aws_session, s3_bucket, s3_prefix):
device = AwsQuantumSimulator(AwsQuantumSimulatorArns.QS1, aws_session)
s3_destination_folder = (s3_bucket, s3_prefix)
assert 0.40 < result.measurement_probabilities["00"] < 0.60
assert 0.40 < result.measurement_probabilities["11"] < 0.60
assert len(result.measurements) == 750

bell = Circuit().h(0).cnot(0, 1)

circ = Circuit()
circ.add(bell)
circ.add(bell, [1, 2])
circ.add(bell, [2, 3])
@pytest.mark.parametrize(
"simulator_arn",
[ # TODO Uncomment out below once proper ordering fix has been applied to QS1
# AwsQuantumSimulatorArns.QS1,
AwsQuantumSimulatorArns.QS3
],
)
def test_qubit_ordering(simulator_arn, aws_session, s3_destination_folder):
device = AwsQuantumSimulator(simulator_arn, aws_session)

task = device.run(bell, s3_destination_folder)
# |110> should get back value of "110"
state_110 = Circuit().x(0).x(1).i(2)
result = device.run(state_110, s3_destination_folder).result()
assert result.measurement_counts.most_common(1)[0][0] == "110"

result = task.result()
# |001> should get back value of "001"
state_001 = Circuit().i(0).i(1).x(2)
result = device.run(state_001, s3_destination_folder).result()
assert result.measurement_counts.most_common(1)[0][0] == "001"

assert 0.40 < result.measurement_probabilities["00"] < 0.60
assert 0.40 < result.measurement_probabilities["11"] < 0.60

def test_qs2_quantum_task(aws_session, s3_destination_folder):
device = AwsQuantumSimulator(AwsQuantumSimulatorArns.QS2, aws_session)

bell = Circuit().h(range(8))
measurements = device.run(bell, s3_destination_folder, shots=1).result().measurements

# 1 shot
assert len(measurements) == 1

# 8 qubits
assert len(measurements[0]) == 8
Loading