Coverage for src/braket/aws/aws_quantum_task.py : 100%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright 2019-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License"). You
4# may not use this file except in compliance with the License. A copy of
5# the License is located at
6#
7# http://aws.amazon.com/apache2.0/
8#
9# or in the "license" file accompanying this file. This file is
10# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11# ANY KIND, either express or implied. See the License for the specific
12# language governing permissions and limitations under the License.
14from __future__ import annotations
16import asyncio
17import time
18from functools import singledispatch
19from logging import Logger, getLogger
20from typing import Any, Dict, Optional, Union
22import boto3
23from braket.annealing.problem import Problem
24from braket.aws.aws_session import AwsSession
25from braket.circuits.circuit import Circuit
26from braket.tasks import AnnealingQuantumTaskResult, GateModelQuantumTaskResult, QuantumTask
class AwsQuantumTask(QuantumTask):
    """Amazon Braket implementation of a quantum task. A task can be a circuit or an annealing
    problem."""

    # TODO: Add API documentation that defines these states. Make it clear this is the contract.
    NO_RESULT_TERMINAL_STATES = {"FAILED", "CANCELLED"}
    RESULTS_READY_STATES = {"COMPLETED"}

    GATE_IR_TYPE = "jaqcd"
    ANNEALING_IR_TYPE = "annealing"

    # Shot count used by `create()` when the caller passes `shots=None`; matches the
    # documented default of 1_000 shots.
    DEFAULT_SHOTS = 1_000

    DEFAULT_RESULTS_POLL_TIMEOUT = 120
    DEFAULT_RESULTS_POLL_INTERVAL = 0.25

    @staticmethod
    def create(
        aws_session: AwsSession,
        device_arn: str,
        task_specification: Union[Circuit, Problem],
        s3_destination_folder: AwsSession.S3DestinationFolder,
        shots: Optional[int] = None,
        backend_parameters: Optional[Dict[str, Any]] = None,
        *args,
        **kwargs,
    ) -> AwsQuantumTask:
        """AwsQuantumTask factory method that serializes a quantum task specification
        (either a quantum circuit or annealing problem), submits it to Amazon Braket,
        and returns back an AwsQuantumTask tracking the execution.

        Args:
            aws_session (AwsSession): AwsSession to connect to AWS with.

            device_arn (str): The ARN of the quantum device.

            task_specification (Union[Circuit, Problem]): The specification of the task
                to run on device.

            s3_destination_folder (AwsSession.S3DestinationFolder): NamedTuple, with bucket
                for index 0 and key for index 1, that specifies the Amazon S3 bucket and folder
                to store task results in.

            shots (int, optional): The number of times to run the task on the device. If the
                device is a simulator, this implies the state is sampled N times, where
                N = `shots`. Default is `AwsQuantumTask.DEFAULT_SHOTS` (1_000).

            backend_parameters (Dict[str, Any], optional): Additional parameters to send to
                the device. For example, for D-Wave:
                `{"dWaveParameters": {"postprocessingType": "OPTIMIZATION"}}`.
                Currently only forwarded for annealing problems. Default is `None`, treated
                as an empty dict.

            *args, **kwargs: Forwarded to the `AwsQuantumTask` constructor.

        Returns:
            AwsQuantumTask: AwsQuantumTask tracking the task execution on the device.

        Raises:
            ValueError: If `s3_destination_folder` does not have exactly two elements.
            TypeError: If `task_specification` is neither a `Circuit` nor a `Problem`.

        Note:
            The following arguments are typically defined via clients of Device.
                - `task_specification`
                - `s3_destination_folder`
                - `shots`
        """
        if len(s3_destination_folder) != 2:
            raise ValueError(
                "s3_destination_folder must be of size 2 with a 'bucket' and 'key' respectively."
            )

        create_task_kwargs = _create_common_params(
            device_arn,
            s3_destination_folder,
            shots if shots is not None else AwsQuantumTask.DEFAULT_SHOTS,
        )
        return _create_internal(
            task_specification,
            aws_session,
            create_task_kwargs,
            backend_parameters or {},
            *args,
            **kwargs,
        )

    def __init__(
        self,
        arn: str,
        aws_session: AwsSession = None,
        poll_timeout_seconds: int = DEFAULT_RESULTS_POLL_TIMEOUT,
        poll_interval_seconds: float = DEFAULT_RESULTS_POLL_INTERVAL,
        logger: Logger = getLogger(__name__),
    ):
        """
        Args:
            arn (str): The ARN of the task.
            aws_session (AwsSession, optional): The `AwsSession` for connecting to AWS services.
                Default is `None`, in which case an `AwsSession` object will be created with the
                region of the task.
            poll_timeout_seconds (int): The polling timeout for result(), default is 120 seconds.
            poll_interval_seconds (float): The polling interval for result(), default is 0.25
                seconds.
            logger (Logger): Logger object with which to write logs, such as task statuses
                while waiting for task to be in a terminal state. Default is `getLogger(__name__)`

        Examples:
            >>> task = AwsQuantumTask(arn='task_arn')
            >>> task.state()
            'COMPLETED'
            >>> result = task.result()
            AnnealingQuantumTaskResult(...)

            >>> task = AwsQuantumTask(arn='task_arn', poll_timeout_seconds=300)
            >>> result = task.result()
            GateModelQuantumTaskResult(...)
        """

        self._arn: str = arn
        self._aws_session: AwsSession = aws_session or AwsQuantumTask._aws_session_for_task_arn(
            task_arn=arn
        )
        self._poll_timeout_seconds = poll_timeout_seconds
        self._poll_interval_seconds = poll_interval_seconds
        self._logger = logger

        self._metadata: Dict[str, Any] = {}
        self._result: Optional[
            Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]
        ] = None
        try:
            asyncio.get_event_loop()
        except RuntimeError as e:
            # `get_event_loop()` raises RuntimeError when the current thread has no event loop
            # (e.g. a non-main thread); install a fresh one so the polling task can be scheduled.
            self._logger.debug(e)
            self._logger.info("No event loop found; creating new event loop")
            asyncio.set_event_loop(asyncio.new_event_loop())
        # Schedule the polling task on the current loop; it advances whenever that loop runs
        # (e.g. inside `result()`).
        self._future = asyncio.get_event_loop().run_until_complete(self._create_future())

    @staticmethod
    def _aws_session_for_task_arn(task_arn: str) -> AwsSession:
        """
        Get an AwsSession for the Task ARN. The AWS session should be in the region of the task.

        Args:
            task_arn (str): The task ARN; its fourth colon-separated field is used as the region.

        Returns:
            AwsSession: `AwsSession` object with default `boto_session` in task's region
        """
        task_region = task_arn.split(":")[3]
        boto_session = boto3.Session(region_name=task_region)
        return AwsSession(boto_session=boto_session)

    @property
    def id(self) -> str:
        """str: The ARN of the quantum task."""
        return self._arn

    def cancel(self) -> None:
        """Cancel the quantum task. This cancels the future and the task in Amazon Braket."""
        self._future.cancel()
        self._aws_session.cancel_quantum_task(self._arn)

    def metadata(self, use_cached_value: bool = False) -> Dict[str, Any]:
        """
        Get task metadata defined in Amazon Braket.

        Args:
            use_cached_value (bool, optional): If `True`, uses the value most recently retrieved
                from the Amazon Braket `GetQuantumTask` operation. If `False`, calls the
                `GetQuantumTask` operation to retrieve metadata, which also updates the cached
                value. Default = `False`.
        Returns:
            Dict[str, Any]: The response from the Amazon Braket `GetQuantumTask` operation.
            If `use_cached_value` is `True`, Amazon Braket is not called and the most recently
            retrieved value is used (an empty dict if nothing has been retrieved yet).
        """
        if not use_cached_value:
            self._metadata = self._aws_session.get_quantum_task(self._arn)
        return self._metadata

    def state(self, use_cached_value: bool = False) -> str:
        """
        The state of the quantum task.

        Args:
            use_cached_value (bool, optional): If `True`, uses the value most recently retrieved
                from the Amazon Braket `GetQuantumTask` operation. If `False`, calls the
                `GetQuantumTask` operation to retrieve metadata, which also updates the cached
                value. Default = `False`.
        Returns:
            str: The value of `status` in `metadata()`. This is the value of the `status` key
            in the Amazon Braket `GetQuantumTask` operation. If `use_cached_value` is `True`,
            the value most recently returned from the `GetQuantumTask` operation is used.
            `None` if the metadata has no `status` key.
        See Also:
            `metadata()`
        """
        return self.metadata(use_cached_value).get("status")

    def result(self) -> Optional[Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]]:
        """
        Get the quantum task result by polling Amazon Braket to see if the task is completed.
        Once the task is completed, the result is retrieved from S3 and returned as a
        `GateModelQuantumTaskResult` or `AnnealingQuantumTaskResult`

        This method is a blocking thread call and synchronously returns a result. Call
        async_result() if you require an asynchronous invocation.
        Consecutive calls to this method return a cached result from the preceding request.

        Returns:
            Optional[Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]]: The task
            result, or `None` if the task ended without a result, polling timed out, or the
            polling future was cancelled before a result was stored.
        """
        try:
            return asyncio.get_event_loop().run_until_complete(self.async_result())
        except asyncio.CancelledError:
            # Future was cancelled, return whatever is in self._result if anything
            self._logger.warning("Task future was cancelled")
            return self._result

    def async_result(self) -> asyncio.Task:
        """
        Get the quantum task result asynchronously. Consecutive calls to this method return
        the result cached from the most recent request.

        If the previous polling task finished without a result and the task is not in one of
        `NO_RESULT_TERMINAL_STATES`, a new polling task is started.

        Returns:
            asyncio.Task: The task whose result is the quantum task result, or `None`.
        """
        if self._future.done() and self._result is None:  # timed out and no result
            task_status = self.metadata()["status"]
            if task_status in self.NO_RESULT_TERMINAL_STATES:
                self._logger.warning(
                    f"Task is in terminal state {task_status} and no result is available"
                )
                # Return done future. Don't restart polling.
                return self._future
            else:
                self._future = asyncio.get_event_loop().run_until_complete(self._create_future())
        return self._future

    async def _create_future(self) -> asyncio.Task:
        """
        Wrap the `_wait_for_completion` coroutine inside a future-like object.
        Invoking this method starts the coroutine and returns back the future-like object
        that contains it. Note that this does not block on the coroutine to finish.

        Returns:
            asyncio.Task: An asyncio Task that contains the _wait_for_completion() coroutine.
        """
        return asyncio.create_task(self._wait_for_completion())

    def _get_results_formatter(
        self,
        use_cached_value: bool = False,
    ) -> Union[GateModelQuantumTaskResult.from_string, AnnealingQuantumTaskResult.from_string]:
        """
        Get results formatter based on irType of self.metadata()

        Args:
            use_cached_value (bool, optional): If `True`, reads `irType` from the most recently
                retrieved metadata instead of calling `GetQuantumTask` again. Default = `False`.

        Returns:
            Union[GateModelQuantumTaskResult.from_string, AnnealingQuantumTaskResult.from_string]:
            function that deserializes a string into a results structure

        Raises:
            ValueError: If the task's `irType` is neither the annealing nor the gate IR type.
        """
        current_metadata = self.metadata(use_cached_value)
        ir_type = current_metadata["irType"]
        if ir_type == AwsQuantumTask.ANNEALING_IR_TYPE:
            return AnnealingQuantumTaskResult.from_string
        elif ir_type == AwsQuantumTask.GATE_IR_TYPE:
            return GateModelQuantumTaskResult.from_string
        else:
            raise ValueError("Unknown IR type")

    async def _wait_for_completion(
        self,
    ) -> Optional[Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]]:
        """
        Waits for the quantum task to be completed, then returns the result from the S3 bucket.

        Returns:
            Optional[Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]]: If the task
            is in the `AwsQuantumTask.RESULTS_READY_STATES` state within the specified time
            limit, the result from the S3 bucket is loaded and returned. `None` is returned if a
            timeout occurs or task state is in `AwsQuantumTask.NO_RESULT_TERMINAL_STATES`.
        Note:
            Timeout and sleep intervals are defined in the constructor fields
            `poll_timeout_seconds` and `poll_interval_seconds` respectively.
        """
        self._logger.debug(f"Task {self._arn}: start polling for completion")
        start_time = time.time()

        while (time.time() - start_time) < self._poll_timeout_seconds:
            current_metadata = self.metadata()
            task_status = current_metadata["status"]
            self._logger.debug(f"Task {self._arn}: task status {task_status}")
            if task_status in AwsQuantumTask.RESULTS_READY_STATES:
                result_string = self._aws_session.retrieve_s3_object_body(
                    current_metadata["resultsS3Bucket"], current_metadata["resultsS3ObjectKey"]
                )
                # The metadata fetched just above is now cached; reuse it for `irType` instead
                # of issuing another GetQuantumTask request.
                formatter = self._get_results_formatter(use_cached_value=True)
                self._result = formatter(result_string)
                return self._result
            elif task_status in AwsQuantumTask.NO_RESULT_TERMINAL_STATES:
                self._logger.warning(
                    f"Task is in terminal state {task_status} and no result is available"
                )
                self._result = None
                return None
            else:
                await asyncio.sleep(self._poll_interval_seconds)

        # Timed out
        self._logger.warning(
            f"Task {self._arn}: polling for task completion timed out after "
            + f"{time.time()-start_time} secs"
        )
        self._result = None
        return None

    def __repr__(self) -> str:
        return f"AwsQuantumTask('id':{self.id})"

    def __eq__(self, other) -> bool:
        # Tasks are equal when they track the same ARN.
        if isinstance(other, AwsQuantumTask):
            return self.id == other.id
        return NotImplemented

    def __hash__(self) -> int:
        # Consistent with __eq__: hash on the ARN.
        return hash(self.id)
@singledispatch
def _create_internal(
    task_specification: Union[Circuit, Problem],
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    backend_parameters: Dict[str, Any],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    """Fallback used when no registered overload matches `task_specification`'s type.

    Supported specification types are handled by functions registered with
    `_create_internal.register`; reaching this implementation means the type is unsupported.

    Raises:
        TypeError: Always.
    """
    error_message = "Invalid task specification type"
    raise TypeError(error_message)
@_create_internal.register
def _(
    circuit: Circuit,
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    backend_parameters: Dict[str, Any],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    """Submit a gate-model `Circuit` and return an `AwsQuantumTask` tracking it.

    Adds the serialized JAQCD IR, the IR type, and the circuit's qubit count to
    `create_task_kwargs` (mutating it in place), then calls `create_quantum_task`.
    `backend_parameters` is accepted for signature parity with the other overloads but is
    not referenced here.
    """
    create_task_kwargs["ir"] = circuit.to_ir().json()
    create_task_kwargs["irType"] = AwsQuantumTask.GATE_IR_TYPE
    create_task_kwargs["backendParameters"] = {
        "gateModelParameters": {"qubitCount": circuit.qubit_count}
    }
    new_task_arn = aws_session.create_quantum_task(**create_task_kwargs)
    return AwsQuantumTask(new_task_arn, aws_session, *args, **kwargs)
@_create_internal.register
def _(
    problem: Problem,
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    backend_parameters: Dict[str, Any],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    """Submit an annealing `Problem` and return an `AwsQuantumTask` tracking it.

    Adds the serialized problem IR, the annealing IR type, and `backend_parameters` (nested
    under `annealingModelParameters`) to `create_task_kwargs` (mutating it in place), then
    calls `create_quantum_task`.
    """
    annealing_request_fields = {
        "ir": problem.to_ir().json(),
        "irType": AwsQuantumTask.ANNEALING_IR_TYPE,
        "backendParameters": {"annealingModelParameters": backend_parameters},
    }
    create_task_kwargs.update(annealing_request_fields)

    new_task_arn = aws_session.create_quantum_task(**create_task_kwargs)
    return AwsQuantumTask(new_task_arn, aws_session, *args, **kwargs)
def _create_common_params(
    device_arn: str, s3_destination_folder: AwsSession.S3DestinationFolder, shots: int
) -> Dict[str, Any]:
    """Build the `create_quantum_task` keyword arguments shared by every specification type.

    Args:
        device_arn (str): ARN of the target device, sent as `backendArn`.
        s3_destination_folder (AwsSession.S3DestinationFolder): Element 0 is the results
            bucket, element 1 the results key prefix.
        shots (int): Number of shots, passed through unchanged.

    Returns:
        Dict[str, Any]: A new dict with `backendArn`, `resultsS3Bucket`, `resultsS3Prefix`
        and `shots`.
    """
    results_bucket = s3_destination_folder[0]
    results_prefix = s3_destination_folder[1]
    return dict(
        backendArn=device_arn,
        resultsS3Bucket=results_bucket,
        resultsS3Prefix=results_prefix,
        shots=shots,
    )