Coverage for src/braket/aws/aws_quantum_task.py : 100%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright 2019-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License"). You
4# may not use this file except in compliance with the License. A copy of
5# the License is located at
6#
7# http://aws.amazon.com/apache2.0/
8#
9# or in the "license" file accompanying this file. This file is
10# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11# ANY KIND, either express or implied. See the License for the specific
12# language governing permissions and limitations under the License.
14from __future__ import annotations
16import asyncio
17import time
18from functools import singledispatch
19from logging import Logger, getLogger
20from typing import Any, Dict, Optional, Union
22import boto3
23from braket.annealing.problem import Problem
24from braket.aws.aws_session import AwsSession
25from braket.circuits.circuit import Circuit
26from braket.tasks import AnnealingQuantumTaskResult, GateModelQuantumTaskResult, QuantumTask
class AwsQuantumTask(QuantumTask):
    """Amazon Braket implementation of a quantum task.

    A task can be a circuit or an annealing problem.
    """

    # TODO: Add API documentation that defines these states. Make it clear this is the contract.
    # Task statuses after which polling stops and no result will be produced.
    NO_RESULT_TERMINAL_STATES = {"FAILED", "CANCELLED"}
    # Task statuses for which the result object can be fetched from S3.
    RESULTS_READY_STATES = {"COMPLETED"}

    # Values sent as `irType` when creating a task and compared against the
    # `irType` metadata field when choosing a results formatter.
    GATE_IR_TYPE = "jaqcd"
    ANNEALING_IR_TYPE = "annealing"
    DEFAULT_SHOTS = 1_000

    # Defaults (in seconds) for the polling loop in `_wait_for_completion`.
    DEFAULT_RESULTS_POLL_TIMEOUT = 120
    DEFAULT_RESULTS_POLL_INTERVAL = 0.25

    @staticmethod
    def create(
        aws_session: AwsSession,
        device_arn: str,
        task_specification: Union[Circuit, Problem],
        s3_destination_folder: AwsSession.S3DestinationFolder,
        shots: Optional[int] = None,
        backend_parameters: Optional[Dict[str, Any]] = None,
        *args,
        **kwargs,
    ) -> AwsQuantumTask:
        """AwsQuantumTask factory method that serializes a quantum task specification
        (either a quantum circuit or annealing problem), submits it to Amazon Braket,
        and returns back an AwsQuantumTask tracking the execution.

        Args:
            aws_session (AwsSession): AwsSession to connect to AWS with.

            device_arn (str): The ARN of the quantum device.

            task_specification (Union[Circuit, Problem]): The specification of the task
                to run on device.

            s3_destination_folder (AwsSession.S3DestinationFolder): NamedTuple, with bucket
                for index 0 and key for index 1, that specifies the Amazon S3 bucket and folder
                to store task results in.

            shots (int): The number of times to run the task on the device. If the device is a
                simulator, this implies the state is sampled N times, where N = `shots`. Default
                shots = 1_000 (used when `shots` is `None`).

            backend_parameters (Dict[str, Any]): Additional parameters to send to the device.
                For example, for D-Wave:
                `{"dWaveParameters": {"postprocessingType": "OPTIMIZATION"}}`
                `None` is treated as an empty dict.

            *args: Extra positional arguments forwarded to the `AwsQuantumTask` constructor.

            **kwargs: Extra keyword arguments forwarded to the `AwsQuantumTask` constructor.

        Returns:
            AwsQuantumTask: AwsQuantumTask tracking the task execution on the device.

        Raises:
            ValueError: If `s3_destination_folder` does not have exactly two elements.
            TypeError: If `task_specification` is of a type with no registered
                `_create_internal` implementation.

        Note:
            The following arguments are typically defined via clients of Device.
                - `task_specification`
                - `s3_destination_folder`
                - `shots`
        """
        if len(s3_destination_folder) != 2:
            raise ValueError(
                "s3_destination_folder must be of size 2 with a 'bucket' and 'key' respectively."
            )

        create_task_kwargs = _create_common_params(
            device_arn,
            s3_destination_folder,
            # Only fall back to the default when shots was not supplied at all.
            shots if shots is not None else AwsQuantumTask.DEFAULT_SHOTS,
        )
        return _create_internal(
            task_specification,
            aws_session,
            create_task_kwargs,
            backend_parameters or {},
            *args,
            **kwargs,
        )

    def __init__(
        self,
        arn: str,
        aws_session: AwsSession = None,
        poll_timeout_seconds: int = DEFAULT_RESULTS_POLL_TIMEOUT,
        poll_interval_seconds: float = DEFAULT_RESULTS_POLL_INTERVAL,
        logger: Logger = getLogger(__name__),
    ):
        """
        Args:
            arn (str): The ARN of the task.
            aws_session (AwsSession, optional): The `AwsSession` for connecting to AWS services.
                Default is `None`, in which case an `AwsSession` object will be created with the
                region of the task.
            poll_timeout_seconds (int): The polling timeout for result(), default is 120 seconds.
            poll_interval_seconds (float): The polling interval for result(), default is 0.25
                seconds.
            logger (Logger): Logger object with which to write logs, such as task statuses
                while waiting for task to be in a terminal state. Default is `getLogger(__name__)`

        Examples:
            >>> task = AwsQuantumTask(arn='task_arn')
            >>> task.state()
            'COMPLETED'
            >>> result = task.result()
            AnnealingQuantumTaskResult(...)

            >>> task = AwsQuantumTask(arn='task_arn', poll_timeout_seconds=300)
            >>> result = task.result()
            GateModelQuantumTaskResult(...)
        """

        self._arn: str = arn
        self._aws_session: AwsSession = aws_session or AwsQuantumTask._aws_session_for_task_arn(
            task_arn=arn
        )
        self._poll_timeout_seconds = poll_timeout_seconds
        self._poll_interval_seconds = poll_interval_seconds
        self._logger = logger

        # Cache of the most recent GetQuantumTask response and of the retrieved result.
        self._metadata: Dict[str, Any] = {}
        self._result: Optional[
            Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]
        ] = None
        try:
            asyncio.get_event_loop()
        except Exception as e:
            # Deliberately broad: any failure to obtain a loop is handled by
            # installing a fresh one for the current thread.
            self._logger.debug(e)
            self._logger.info("No event loop found; creating new event loop")
            asyncio.set_event_loop(asyncio.new_event_loop())
        # Schedule the polling coroutine; this returns once the asyncio Task has been
        # created, not once polling has finished.
        self._future = asyncio.get_event_loop().run_until_complete(self._create_future())

    @staticmethod
    def _aws_session_for_task_arn(task_arn: str) -> AwsSession:
        """
        Get an AwsSession for the Task ARN. The AWS session should be in the region of the task.

        Args:
            task_arn (str): The ARN of the task. The region is read from the fourth
                colon-separated field (index 3).

        Returns:
            AwsSession: `AwsSession` object with default `boto_session` in task's region
        """
        task_region = task_arn.split(":")[3]
        boto_session = boto3.Session(region_name=task_region)
        return AwsSession(boto_session=boto_session)

    @property
    def id(self) -> str:
        """str: The ARN of the quantum task."""
        return self._arn

    def cancel(self) -> None:
        """Cancel the quantum task. This cancels the future and the task in Amazon Braket."""
        self._future.cancel()
        self._aws_session.cancel_quantum_task(self._arn)

    def metadata(self, use_cached_value: bool = False) -> Dict[str, Any]:
        """
        Get task metadata defined in Amazon Braket.

        Args:
            use_cached_value (bool, optional): If `True`, uses the value most recently retrieved
                from the Amazon Braket `GetQuantumTask` operation. If `False`, calls the
                `GetQuantumTask` operation to retrieve metadata, which also updates the cached
                value. Default = `False`.

        Returns:
            Dict[str, Any]: The response from the Amazon Braket `GetQuantumTask` operation.
            If `use_cached_value` is `True`, Amazon Braket is not called and the most recently
            retrieved value is used. The cache starts out as an empty dict.
        """
        if not use_cached_value:
            self._metadata = self._aws_session.get_quantum_task(self._arn)
        return self._metadata

    def state(self, use_cached_value: bool = False) -> str:
        """
        The state of the quantum task.

        Args:
            use_cached_value (bool, optional): If `True`, uses the value most recently retrieved
                from the Amazon Braket `GetQuantumTask` operation. If `False`, calls the
                `GetQuantumTask` operation to retrieve metadata, which also updates the cached
                value. Default = `False`.

        Returns:
            str: The value of `status` in `metadata()`. This is the value of the `status` key
            in the Amazon Braket `GetQuantumTask` operation. If `use_cached_value` is `True`,
            the value most recently returned from the `GetQuantumTask` operation is used.
            `None` is returned when the metadata has no `status` key.

        See Also:
            `metadata()`
        """
        return self.metadata(use_cached_value).get("status")

    def result(self) -> Optional[Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]]:
        """
        Get the quantum task result by polling Amazon Braket to see if the task is completed.
        Once the task is completed, the result is retrieved from S3 and returned as a
        `GateModelQuantumTaskResult` or `AnnealingQuantumTaskResult`

        This method is a blocking thread call and synchronously returns a result. Call
        async_result() if you require an asynchronous invocation.
        Consecutive calls to this method return a cached result from the preceding request.

        Returns:
            Optional[Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]]: The task
            result, or `None` if polling timed out, the task ended in one of
            `NO_RESULT_TERMINAL_STATES`, or the future was cancelled before a result was cached.
        """
        try:
            return asyncio.get_event_loop().run_until_complete(self.async_result())
        except asyncio.CancelledError:
            # Future was cancelled, return whatever is in self._result if anything
            self._logger.warning("Task future was cancelled")
            return self._result

    def async_result(self) -> asyncio.Task:
        """
        Get the quantum task result asynchronously. Consecutive calls to this method return
        the result cached from the most recent request.

        If the current future has finished without producing a result, the task status is
        fetched again: for a status in `NO_RESULT_TERMINAL_STATES` the finished future is
        returned as is, otherwise a new polling future is started.

        Returns:
            asyncio.Task: The task wrapping the polling coroutine.
        """
        if self._future.done() and self._result is None:  # timed out and no result
            task_status = self.metadata()["status"]
            if task_status in self.NO_RESULT_TERMINAL_STATES:
                self._logger.warning(
                    f"Task is in terminal state {task_status} and no result is available"
                )
                # Return done future. Don't restart polling.
                return self._future
            else:
                self._future = asyncio.get_event_loop().run_until_complete(self._create_future())
        return self._future

    async def _create_future(self) -> asyncio.Task:
        """
        Wrap the `_wait_for_completion` coroutine inside a future-like object.
        Invoking this method starts the coroutine and returns back the future-like object
        that contains it. Note that this does not block on the coroutine to finish.

        Returns:
            asyncio.Task: An asyncio Task that contains the _wait_for_completion() coroutine.
        """
        return asyncio.create_task(self._wait_for_completion())

    def _get_results_formatter(
        self,
    ) -> Union[GateModelQuantumTaskResult.from_string, AnnealingQuantumTaskResult.from_string]:
        """
        Get results formatter based on irType of self.metadata()

        Note that this calls `metadata()` without `use_cached_value`, so the metadata is
        fetched again from Amazon Braket.

        Returns:
            Union[GateModelQuantumTaskResult.from_string, AnnealingQuantumTaskResult.from_string]:
            function that deserializes a string into a results structure

        Raises:
            ValueError: If the `irType` in the metadata is neither `ANNEALING_IR_TYPE` nor
                `GATE_IR_TYPE`.
        """
        current_metadata = self.metadata()
        ir_type = current_metadata["irType"]
        if ir_type == AwsQuantumTask.ANNEALING_IR_TYPE:
            return AnnealingQuantumTaskResult.from_string
        elif ir_type == AwsQuantumTask.GATE_IR_TYPE:
            return GateModelQuantumTaskResult.from_string
        else:
            raise ValueError("Unknown IR type")

    async def _wait_for_completion(
        self,
    ) -> Optional[Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]]:
        """
        Waits for the quantum task to be completed, then returns the result from the S3 bucket.

        Returns:
            Optional[Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]]: If the task
            is in the `AwsQuantumTask.RESULTS_READY_STATES` state within the specified time
            limit, the result from the S3 bucket is loaded and returned. `None` is returned if
            a timeout occurs or task state is in `AwsQuantumTask.NO_RESULT_TERMINAL_STATES`.
            The returned value is also stored in `self._result`.

        Note:
            Timeout and sleep intervals are defined in the constructor fields
            `poll_timeout_seconds` and `poll_interval_seconds` respectively.
        """
        self._logger.debug(f"Task {self._arn}: start polling for completion")
        start_time = time.time()

        while (time.time() - start_time) < self._poll_timeout_seconds:
            current_metadata = self.metadata()
            task_status = current_metadata["status"]
            self._logger.debug(f"Task {self._arn}: task status {task_status}")
            if task_status in AwsQuantumTask.RESULTS_READY_STATES:
                result_string = self._aws_session.retrieve_s3_object_body(
                    current_metadata["resultsS3Bucket"], current_metadata["resultsS3ObjectKey"]
                )
                self._result = self._get_results_formatter()(result_string)
                return self._result
            elif task_status in AwsQuantumTask.NO_RESULT_TERMINAL_STATES:
                # The two halves of this message used to be concatenated without a
                # separating space, which glued the status onto the word "and".
                self._logger.warning(
                    f"Task is in terminal state {task_status} and no result is available"
                )
                self._result = None
                return None
            else:
                # Yield to the event loop between polls instead of blocking it.
                await asyncio.sleep(self._poll_interval_seconds)

        # Timed out
        self._logger.warning(
            f"Task {self._arn}: polling for task completion timed out after "
            + f"{time.time()-start_time} secs"
        )
        self._result = None
        return None

    def __repr__(self) -> str:
        return f"AwsQuantumTask('id':{self.id})"

    def __eq__(self, other) -> bool:
        # Two tasks are equal when they refer to the same task ARN.
        if isinstance(other, AwsQuantumTask):
            return self.id == other.id
        return NotImplemented

    def __hash__(self) -> int:
        # Kept consistent with __eq__: hash on the task ARN.
        return hash(self.id)
@singledispatch
def _create_internal(
    task_specification: Union[Circuit, Problem],
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    backend_parameters: Dict[str, Any],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    """Fallback for task creation, dispatched on the type of `task_specification`.

    This base implementation runs only when no implementation has been registered
    for the type of `task_specification`; it never creates a task.

    Args:
        task_specification (Union[Circuit, Problem]): The specification of the task.
        aws_session (AwsSession): Unused by the fallback.
        create_task_kwargs (Dict[str, Any]): Unused by the fallback.
        backend_parameters (Dict[str, Any]): Unused by the fallback.
        *args: Unused by the fallback.
        **kwargs: Unused by the fallback.

    Raises:
        TypeError: Always.
    """
    message = "Invalid task specification type"
    raise TypeError(message)
@_create_internal.register
def _(
    circuit: Circuit,
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    backend_parameters: Dict[str, Any],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    """Create a gate-model task from a `Circuit`.

    Adds the serialized circuit, the gate IR type and the circuit's qubit count to
    `create_task_kwargs` (the dict is updated in place), submits the task through
    `aws_session` and wraps the returned ARN in an `AwsQuantumTask`.

    Args:
        circuit (Circuit): The circuit to run.
        aws_session (AwsSession): Session used to create the task; also handed to the
            returned `AwsQuantumTask`.
        create_task_kwargs (Dict[str, Any]): Keyword arguments for
            `aws_session.create_quantum_task`; mutated by this function.
        backend_parameters (Dict[str, Any]): Not used by this implementation.
        *args: Forwarded to the `AwsQuantumTask` constructor.
        **kwargs: Forwarded to the `AwsQuantumTask` constructor.

    Returns:
        AwsQuantumTask: Task object built from the ARN returned by `create_quantum_task`.
    """
    serialized_ir = circuit.to_ir().json()
    device_parameters = {"gateModelParameters": {"qubitCount": circuit.qubit_count}}
    create_task_kwargs.update(
        ir=serialized_ir,
        irType=AwsQuantumTask.GATE_IR_TYPE,
        backendParameters=device_parameters,
    )
    created_arn = aws_session.create_quantum_task(**create_task_kwargs)
    return AwsQuantumTask(created_arn, aws_session, *args, **kwargs)
@_create_internal.register
def _(
    problem: Problem,
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    backend_parameters: Dict[str, Any],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    """Create an annealing task from a `Problem`.

    Adds the serialized problem, the annealing IR type and the supplied backend
    parameters to `create_task_kwargs` (the dict is updated in place), submits the
    task through `aws_session` and wraps the returned ARN in an `AwsQuantumTask`.

    Args:
        problem (Problem): The annealing problem to run.
        aws_session (AwsSession): Session used to create the task; also handed to the
            returned `AwsQuantumTask`.
        create_task_kwargs (Dict[str, Any]): Keyword arguments for
            `aws_session.create_quantum_task`; mutated by this function.
        backend_parameters (Dict[str, Any]): Sent under the `annealingModelParameters`
            key of `backendParameters`.
        *args: Forwarded to the `AwsQuantumTask` constructor.
        **kwargs: Forwarded to the `AwsQuantumTask` constructor.

    Returns:
        AwsQuantumTask: Task object built from the ARN returned by `create_quantum_task`.
    """
    serialized_ir = problem.to_ir().json()
    device_parameters = {"annealingModelParameters": backend_parameters}
    create_task_kwargs.update(
        ir=serialized_ir,
        irType=AwsQuantumTask.ANNEALING_IR_TYPE,
        backendParameters=device_parameters,
    )

    created_arn = aws_session.create_quantum_task(**create_task_kwargs)
    return AwsQuantumTask(created_arn, aws_session, *args, **kwargs)
def _create_common_params(
    device_arn: str, s3_destination_folder: AwsSession.S3DestinationFolder, shots: int
) -> Dict[str, Any]:
    """Build the task-creation arguments shared by every task type.

    Args:
        device_arn (str): Stored under `backendArn`.
        s3_destination_folder (AwsSession.S3DestinationFolder): Index 0 is stored under
            `resultsS3Bucket` and index 1 under `resultsS3Prefix`.
        shots (int): Stored under `shots`.

    Returns:
        Dict[str, Any]: A new dict with the four keys described above.
    """
    results_bucket = s3_destination_folder[0]
    results_prefix = s3_destination_folder[1]
    return dict(
        backendArn=device_arn,
        resultsS3Bucket=results_bucket,
        resultsS3Prefix=results_prefix,
        shots=shots,
    )