Coverage for src/braket/aws/aws_session.py : 100%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
# Copyright 2019-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
14from typing import Any, Dict, NamedTuple
16import boto3
class AwsSession:
    """Manage interactions with AWS services."""

    # An S3 location: a bucket name plus an object key within that bucket.
    # Both fields are strings (S3 object keys are strings, matching the `s3_object_key: str`
    # parameter of `retrieve_s3_object_body`).
    S3DestinationFolder = NamedTuple("S3DestinationFolder", [("bucket", str), ("key", str)])

    # Braket API endpoint URL for each AWS Region in which Braket is available; keyed by
    # Region name (looked up with the boto3 session's `region_name` in `__init__`).
    BRAKET_ENDPOINTS = {
        "us-west-1": "https://fdoco1n1x7.execute-api.us-west-1.amazonaws.com/V3",
        "us-west-2": "https://xe15dbdvw6.execute-api.us-west-2.amazonaws.com/V3",
        "us-east-1": "https://kqjovr0n70.execute-api.us-east-1.amazonaws.com/V3",
    }

    # similar to sagemaker sdk:
    # https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/session.py
    def __init__(self, boto_session=None, braket_client=None):
        """
        Args:
            boto_session: A boto3 session object. If falsy, a new default
                `boto3.Session()` is created.
            braket_client: A boto3 Braket client. If falsy, a "braket" client is created
                from `boto_session` using the endpoint registered in `BRAKET_ENDPOINTS`
                for the session's Region.

        Raises:
            ValueError: If no `braket_client` is given and Braket is not available in the
                Region used for the boto3 session.
        """
        self.boto_session = boto_session or boto3.Session()

        if braket_client:
            self.braket_client = braket_client
        else:
            region = self.boto_session.region_name
            endpoint = AwsSession.BRAKET_ENDPOINTS.get(region)
            if not endpoint:
                supported_regions = list(AwsSession.BRAKET_ENDPOINTS.keys())
                raise ValueError(
                    f"No braket endpoint for {region}, supported Regions are {supported_regions}"
                )
            self.braket_client = self.boto_session.client("braket", endpoint_url=endpoint)

    #
    # Quantum Tasks
    #
    def cancel_quantum_task(self, arn: str) -> None:
        """
        Cancel the quantum task.

        Args:
            arn (str): The ARN of the quantum task to cancel.
        """
        self.braket_client.cancel_quantum_task(quantumTaskArn=arn)

    def create_quantum_task(self, **boto3_kwargs) -> str:
        """
        Create a quantum task.

        Args:
            **boto3_kwargs: Keyword arguments for the Amazon Braket `CreateQuantumTask` operation,
                passed through unchanged to the Braket client.

        Returns:
            str: The ARN of the quantum task.
        """
        response = self.braket_client.create_quantum_task(**boto3_kwargs)
        return response["quantumTaskArn"]

    def get_quantum_task(self, arn: str) -> Dict[str, Any]:
        """
        Gets the quantum task.

        Args:
            arn (str): The ARN of the quantum task to retrieve.

        Returns:
            Dict[str, Any]: The response from the Amazon Braket `GetQuantumTask` operation.
        """
        return self.braket_client.get_quantum_task(quantumTaskArn=arn)

    def retrieve_s3_object_body(self, s3_bucket: str, s3_object_key: str) -> str:
        """
        Retrieve the S3 object body.

        Args:
            s3_bucket (str): The S3 bucket name
            s3_object_key (str): The S3 object key within the `s3_bucket`

        Returns:
            str: The body of the S3 object, decoded as UTF-8
        """
        s3 = self.boto_session.resource("s3")
        obj = s3.Object(s3_bucket, s3_object_key)
        return obj.get()["Body"].read().decode("utf-8")

    # TODO: add in boto3 exception handling once we have exception types in API
    def get_qpu_metadata(self, arn: str) -> Dict[str, Any]:
        """
        Calls the Amazon Braket `DescribeQpus` (`describe_qpus`) operation to retrieve
        QPU metadata.

        Args:
            arn (str): The ARN of the QPU to retrieve metadata from

        Returns:
            Dict[str, Any]: QPU metadata, i.e. the first entry of the response's "qpus" list.

        Exceptions raised by the Braket client propagate unchanged to the caller.
        """
        response = self.braket_client.describe_qpus(qpuArns=[arn])
        return response.get("qpus")[0]

    # TODO: add in boto3 exception handling once we have exception types in API
    def get_simulator_metadata(self, arn: str) -> Dict[str, Any]:
        """
        Calls the Amazon Braket `DescribeQuantumSimulators` (`describe_quantum_simulators`)
        operation to retrieve simulator metadata.

        Args:
            arn (str): The ARN of the simulator to retrieve metadata from

        Returns:
            Dict[str, Any]: Simulator metadata, i.e. the first entry of the response's
            "quantumSimulators" list.

        Exceptions raised by the Braket client propagate unchanged to the caller.
        """
        response = self.braket_client.describe_quantum_simulators(quantumSimulatorArns=[arn])
        return response.get("quantumSimulators")[0]