Coverage for src/braket/aws/aws_session.py : 100%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
# Copyright 2019-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
14from typing import Any, Dict, NamedTuple
16import boto3
class AwsSession:
    """Manage interactions with AWS services (Amazon Braket and Amazon S3)."""

    # Class syntax (rather than the functional ``NamedTuple(...)`` call) gives the type the
    # qualified name ``AwsSession.S3DestinationFolder``, so instances can be pickled.
    class S3DestinationFolder(NamedTuple):
        """An S3 location given as a bucket name and an object key within that bucket."""

        bucket: str  # S3 bucket name
        key: str  # object key within `bucket` (presumably a folder-style prefix)

    # Amazon Braket API endpoint for each supported AWS Region.
    BRAKET_ENDPOINTS = {
        "us-west-2": "https://7ko20bz2m2.execute-api.us-west-2.amazonaws.com/V3",
        "us-west-1": "https://fdoco1n1x7.execute-api.us-west-1.amazonaws.com/V3",
        "us-east-1": "https://kqjovr0n70.execute-api.us-east-1.amazonaws.com/V3",
    }

    # similar to sagemaker sdk:
    # https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/session.py
    def __init__(self, boto_session=None, braket_client=None) -> None:
        """
        Args:
            boto_session: A boto3 session object. If not provided (or falsy), a new
                `boto3.Session()` is created.
            braket_client: A boto3 Braket client. If not provided (or falsy), one is
                created from `boto_session` using the endpoint listed in
                `BRAKET_ENDPOINTS` for the session's Region.

        Raises:
            ValueError: If `braket_client` is not provided and Braket is not available
                in the Region used for the boto3 session.
        """
        self.boto_session = boto_session or boto3.Session()

        if braket_client:
            self.braket_client = braket_client
        else:
            region = self.boto_session.region_name
            endpoint = AwsSession.BRAKET_ENDPOINTS.get(region)
            if not endpoint:
                supported_regions = list(AwsSession.BRAKET_ENDPOINTS)
                raise ValueError(
                    f"No braket endpoint for {region}, supported Regions are {supported_regions}"
                )

            self.braket_client = self.boto_session.client("braket", endpoint_url=endpoint)

    #
    # Quantum Tasks
    #
    def cancel_quantum_task(self, arn: str) -> None:
        """
        Cancel the quantum task.

        Args:
            arn (str): The ARN of the quantum task to cancel.
        """
        self.braket_client.cancel_quantum_task(quantumTaskArn=arn)

    def create_quantum_task(self, **boto3_kwargs) -> str:
        """
        Create a quantum task.

        Args:
            **boto3_kwargs: Keyword arguments for the Amazon Braket `CreateQuantumTask`
                operation, forwarded unchanged to the Braket client.

        Returns:
            str: The ARN of the quantum task (the `quantumTaskArn` field of the response).
        """
        response = self.braket_client.create_quantum_task(**boto3_kwargs)
        return response["quantumTaskArn"]

    def get_quantum_task(self, arn: str) -> Dict[str, Any]:
        """
        Gets the quantum task.

        Args:
            arn (str): The ARN of the quantum task to retrieve.

        Returns:
            Dict[str, Any]: The response from the Amazon Braket `GetQuantumTask` operation.
        """
        return self.braket_client.get_quantum_task(quantumTaskArn=arn)

    def retrieve_s3_object_body(self, s3_bucket: str, s3_object_key: str) -> str:
        """
        Retrieve the S3 object body, decoded as UTF-8 text.

        Args:
            s3_bucket (str): The S3 bucket name
            s3_object_key (str): The S3 object key within the `s3_bucket`

        Returns:
            str: The body of the S3 object
        """
        s3 = self.boto_session.resource("s3")
        obj = s3.Object(s3_bucket, s3_object_key)
        return obj.get()["Body"].read().decode("utf-8")

    # TODO: add in boto3 exception handling once we have exception types in API
    def get_qpu_metadata(self, arn: str) -> Dict[str, Any]:
        """
        Calls the Amazon Braket `DescribeQpus` (`describe_qpus`) operation to retrieve
        QPU metadata.

        Args:
            arn (str): The ARN of the QPU to retrieve metadata from

        Returns:
            Dict[str, Any]: QPU metadata, i.e. the first entry of the response's `qpus` list.

        Raises:
            KeyError: If the response has no `qpus` entry.
            IndexError: If the response's `qpus` list is empty.
        """
        # Errors raised by the client call propagate unchanged to the caller.
        response = self.braket_client.describe_qpus(qpuArns=[arn])
        return response["qpus"][0]

    # TODO: add in boto3 exception handling once we have exception types in API
    def get_simulator_metadata(self, arn: str) -> Dict[str, Any]:
        """
        Calls the Amazon Braket `DescribeQuantumSimulators` (`describe_quantum_simulators`) to
        retrieve simulator metadata

        Args:
            arn (str): The ARN of the simulator to retrieve metadata from

        Returns:
            Dict[str, Any]: Simulator metadata, i.e. the first entry of the response's
            `quantumSimulators` list.

        Raises:
            KeyError: If the response has no `quantumSimulators` entry.
            IndexError: If the response's `quantumSimulators` list is empty.
        """
        # Errors raised by the client call propagate unchanged to the caller.
        response = self.braket_client.describe_quantum_simulators(quantumSimulatorArns=[arn])
        return response["quantumSimulators"][0]