# Copyright 2019-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import annotations
import asyncio
import time
from functools import singledispatch
from logging import Logger, getLogger
from typing import Any, Dict, Optional, Union
import boto3
from braket.annealing.problem import Problem
from braket.aws.aws_session import AwsSession
from braket.circuits.circuit import Circuit
from braket.tasks import AnnealingQuantumTaskResult, GateModelQuantumTaskResult, QuantumTask
class AwsQuantumTask(QuantumTask):
    """Amazon Braket implementation of a quantum task. A task can be a circuit or an annealing
    problem."""

    # TODO: Add API documentation that defines these states. Make it clear this is the contract.
    NO_RESULT_TERMINAL_STATES = {"FAILED", "CANCELLED"}
    RESULTS_READY_STATES = {"COMPLETED"}

    GATE_IR_TYPE = "jaqcd"
    ANNEALING_IR_TYPE = "annealing"
    # Shot count used by `create()` when the caller passes `shots=None`; documented in
    # `create()` as 1_000. `create()` reads this attribute, so it must exist.
    DEFAULT_SHOTS = 1_000

    DEFAULT_RESULTS_POLL_TIMEOUT = 120
    DEFAULT_RESULTS_POLL_INTERVAL = 0.25

    @staticmethod
    def create(
        aws_session: AwsSession,
        device_arn: str,
        task_specification: Union[Circuit, Problem],
        s3_destination_folder: AwsSession.S3DestinationFolder,
        shots: Optional[int] = None,
        backend_parameters: Optional[Dict[str, Any]] = None,
        *args,
        **kwargs,
    ) -> AwsQuantumTask:
        """AwsQuantumTask factory method that serializes a quantum task specification
        (either a quantum circuit or annealing problem), submits it to Amazon Braket,
        and returns back an AwsQuantumTask tracking the execution.

        Args:
            aws_session (AwsSession): AwsSession to connect to AWS with.
            device_arn (str): The ARN of the quantum device.
            task_specification (Union[Circuit, Problem]): The specification of the task
                to run on device.
            s3_destination_folder (AwsSession.S3DestinationFolder): NamedTuple, with bucket
                for index 0 and key for index 1, that specifies the Amazon S3 bucket and folder
                to store task results in.
            shots (int): The number of times to run the task on the device. If the device is a
                simulator, this implies the state is sampled N times, where N = `shots`. Default
                shots = 1_000 (`AwsQuantumTask.DEFAULT_SHOTS`).
            backend_parameters (Dict[str, Any]): Additional parameters to send to the device.
                For example, for D-Wave:
                `{"dWaveParameters": {"postprocessingType": "OPTIMIZATION"}}`
            *args: Forwarded to the `AwsQuantumTask` constructor.
            **kwargs: Forwarded to the `AwsQuantumTask` constructor.

        Returns:
            AwsQuantumTask: AwsQuantumTask tracking the task execution on the device.

        Raises:
            ValueError: If `s3_destination_folder` does not have exactly two elements.
            TypeError: If `task_specification` is neither a `Circuit` nor a `Problem`.

        Note:
            The following arguments are typically defined via clients of Device.
                - `task_specification`
                - `s3_destination_folder`
                - `shots`
        """
        if len(s3_destination_folder) != 2:
            raise ValueError(
                "s3_destination_folder must be of size 2 with a 'bucket' and 'key' respectively."
            )

        create_task_kwargs = _create_common_params(
            device_arn,
            s3_destination_folder,
            shots if shots is not None else AwsQuantumTask.DEFAULT_SHOTS,
        )
        # Dispatches on the type of `task_specification` (see the singledispatch
        # `_create_internal` registrations for Circuit and Problem).
        return _create_internal(
            task_specification,
            aws_session,
            create_task_kwargs,
            backend_parameters or {},
            *args,
            **kwargs,
        )

    def __init__(
        self,
        arn: str,
        aws_session: AwsSession = None,
        poll_timeout_seconds: int = DEFAULT_RESULTS_POLL_TIMEOUT,
        poll_interval_seconds: float = DEFAULT_RESULTS_POLL_INTERVAL,
        logger: Logger = getLogger(__name__),
    ):
        """
        Args:
            arn (str): The ARN of the task.
            aws_session (AwsSession, optional): The `AwsSession` for connecting to AWS services.
                Default is `None`, in which case an `AwsSession` object will be created with the
                region of the task.
            poll_timeout_seconds (int): The polling timeout for result(), default is 120 seconds.
            poll_interval_seconds (float): The polling interval for result(), default is 0.25
                seconds.
            logger (Logger): Logger object with which to write logs, such as task statuses
                while waiting for task to be in a terminal state. Default is `getLogger(__name__)`

        Examples:
            >>> task = AwsQuantumTask(arn='task_arn')
            >>> task.state()
            'COMPLETED'
            >>> result = task.result()
            AnnealingQuantumTaskResult(...)

            >>> task = AwsQuantumTask(arn='task_arn', poll_timeout_seconds=300)
            >>> result = task.result()
            GateModelQuantumTaskResult(...)
        """
        self._arn: str = arn
        self._aws_session: AwsSession = aws_session or AwsQuantumTask._aws_session_for_task_arn(
            task_arn=arn
        )
        self._poll_timeout_seconds = poll_timeout_seconds
        self._poll_interval_seconds = poll_interval_seconds
        self._logger = logger

        # Cache for the most recent GetQuantumTask response; see `metadata()`.
        self._metadata: Dict[str, Any] = {}
        self._result: Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult] = None

        try:
            asyncio.get_event_loop()
        except RuntimeError as e:
            # `get_event_loop` raises RuntimeError when the current thread has no loop
            # (e.g. a non-main thread); install one so polling can be scheduled below.
            self._logger.debug(e)
            self._logger.info("No event loop found; creating new event loop")
            asyncio.set_event_loop(asyncio.new_event_loop())
        # Start polling immediately; `_create_future` schedules `_wait_for_completion`
        # without blocking on it.
        self._future = asyncio.get_event_loop().run_until_complete(self._create_future())

    @staticmethod
    def _aws_session_for_task_arn(task_arn: str) -> AwsSession:
        """
        Get an AwsSession for the Task ARN. The AWS session should be in the region of the task.

        Args:
            task_arn (str): The task ARN; its fourth colon-separated field is used as the region.

        Returns:
            AwsSession: `AwsSession` object with default `boto_session` in task's region
        """
        task_region = task_arn.split(":")[3]
        boto_session = boto3.Session(region_name=task_region)
        return AwsSession(boto_session=boto_session)

    @property
    def id(self) -> str:
        """str: The ARN of the quantum task."""
        return self._arn

    def cancel(self) -> None:
        """Cancel the quantum task. This cancels the future and the task in Amazon Braket."""
        self._future.cancel()
        self._aws_session.cancel_quantum_task(self._arn)

    def metadata(self, use_cached_value: bool = False) -> Dict[str, Any]:
        """
        Get task metadata defined in Amazon Braket.

        Args:
            use_cached_value (bool, optional): If `True`, returns the value most recently
                retrieved from the Amazon Braket `GetQuantumTask` operation without calling
                Amazon Braket. If `False`, calls the `GetQuantumTask` operation to retrieve
                metadata, which also updates the cached value. Default = `False`.

        Returns:
            Dict[str, Any]: The response from the Amazon Braket `GetQuantumTask` operation.
            This is an empty dict if `use_cached_value` is `True` and nothing has been
            retrieved yet.
        """
        if not use_cached_value:
            # NOTE(review): assumes `AwsSession.get_quantum_task(arn)` wraps the Amazon Braket
            # `GetQuantumTask` operation and returns its response dict; confirm in AwsSession.
            self._metadata = self._aws_session.get_quantum_task(self._arn)
        return self._metadata

    def state(self, use_cached_value: bool = False) -> str:
        """
        The state of the quantum task.

        Args:
            use_cached_value (bool, optional): If `True`, uses the value most recently retrieved
                from the Amazon Braket `GetQuantumTask` operation. If `False`, calls the
                `GetQuantumTask` operation to retrieve metadata, which also updates the cached
                value. Default = `False`.

        Returns:
            str: The value of `status` in `metadata()`. This is the value of the `status` key
            in the Amazon Braket `GetQuantumTask` operation. If `use_cached_value` is `True`,
            the value most recently returned from the `GetQuantumTask` operation is used.

        See Also:
            `metadata()`
        """
        return self.metadata(use_cached_value).get("status")

    def result(self) -> Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]:
        """
        Get the quantum task result by polling Amazon Braket to see if the task is completed.
        Once the task is completed, the result is retrieved from S3 and returned as a
        `GateModelQuantumTaskResult` or `AnnealingQuantumTaskResult`.

        This method is a blocking thread call and synchronously returns a result. Call
        async_result() if you require an asynchronous invocation.
        Consecutive calls to this method return a cached result from the preceding request.

        Returns:
            Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]: The task result, or
            `None` if polling timed out, the task ended without a result, or the future was
            cancelled.
        """
        try:
            return asyncio.get_event_loop().run_until_complete(self.async_result())
        except asyncio.CancelledError:
            # Future was cancelled, return whatever is in self._result if anything
            self._logger.warning("Task future was cancelled")
            return self._result

    def async_result(self) -> asyncio.Task:
        """
        Get the quantum task result asynchronously. Consecutive calls to this method return
        the result cached from the most recent request.

        If the previous polling future finished without a result and the task is not in a
        no-result terminal state (i.e. polling timed out), polling is restarted.

        Returns:
            asyncio.Task: The future tracking result retrieval.
        """
        if self._future.done() and self._result is None:  # timed out and no result
            task_status = self.metadata()["status"]
            if task_status in self.NO_RESULT_TERMINAL_STATES:
                self._logger.warning(
                    f"Task is in terminal state {task_status} and no result is available"
                )
                # Return done future. Don't restart polling.
                return self._future
            else:
                self._future = asyncio.get_event_loop().run_until_complete(self._create_future())
        return self._future

    async def _create_future(self) -> asyncio.Task:
        """
        Wrap the `_wait_for_completion` coroutine inside a future-like object.
        Invoking this method starts the coroutine and returns back the future-like object
        that contains it. Note that this does not block on the coroutine to finish.

        Returns:
            asyncio.Task: An asyncio Task that contains the _wait_for_completion() coroutine.
        """
        # `asyncio.create_task` needs a running loop, hence this is a coroutine run via
        # `run_until_complete` by the callers.
        return asyncio.create_task(self._wait_for_completion())

    def _get_results_formatter(
        self,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Union[GateModelQuantumTaskResult.from_string, AnnealingQuantumTaskResult.from_string]:
        """
        Get results formatter based on the `irType` of the task metadata.

        Args:
            metadata (Dict[str, Any], optional): Already-retrieved task metadata to read
                `irType` from. If `None` (the default), fresh metadata is fetched with
                `self.metadata()`.

        Returns:
            Union[GateModelQuantumTaskResult.from_string, AnnealingQuantumTaskResult.from_string]:
            function that deserializes a string into a results structure

        Raises:
            ValueError: If `irType` is neither the gate-model nor the annealing IR type.
        """
        current_metadata = metadata if metadata is not None else self.metadata()
        ir_type = current_metadata["irType"]
        if ir_type == AwsQuantumTask.ANNEALING_IR_TYPE:
            return AnnealingQuantumTaskResult.from_string
        elif ir_type == AwsQuantumTask.GATE_IR_TYPE:
            return GateModelQuantumTaskResult.from_string
        else:
            raise ValueError("Unknown IR type")

    async def _wait_for_completion(
        self,
    ) -> Optional[Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]]:
        """
        Waits for the quantum task to be completed, then returns the result from the S3 bucket.

        Returns:
            Optional[Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]]: If the task
            is in the `AwsQuantumTask.RESULTS_READY_STATES` state within the specified time
            limit, the result from the S3 bucket is loaded and returned. `None` is returned if a
            timeout occurs or task state is in `AwsQuantumTask.NO_RESULT_TERMINAL_STATES`.

        Note:
            Timeout and sleep intervals are defined in the constructor fields
                `poll_timeout_seconds` and `poll_interval_seconds` respectively.
        """
        self._logger.debug(f"Task {self._arn}: start polling for completion")
        start_time = time.time()

        while (time.time() - start_time) < self._poll_timeout_seconds:
            current_metadata = self.metadata()
            task_status = current_metadata["status"]
            self._logger.debug(f"Task {self._arn}: task status {task_status}")
            if task_status in AwsQuantumTask.RESULTS_READY_STATES:
                result_string = self._aws_session.retrieve_s3_object_body(
                    current_metadata["resultsS3Bucket"], current_metadata["resultsS3ObjectKey"]
                )
                # Reuse the metadata just fetched instead of issuing another GetQuantumTask call.
                self._result = self._get_results_formatter(current_metadata)(result_string)
                return self._result
            elif task_status in AwsQuantumTask.NO_RESULT_TERMINAL_STATES:
                self._logger.warning(
                    f"Task is in terminal state {task_status} and no result is available"
                )
                self._result = None
                return None
            else:
                await asyncio.sleep(self._poll_interval_seconds)

        # Timed out
        self._logger.warning(
            f"Task {self._arn}: polling for task completion timed out after "
            + f"{time.time()-start_time} secs"
        )
        self._result = None
        return None

    def __repr__(self) -> str:
        return f"AwsQuantumTask('id':{self.id})"

    def __eq__(self, other) -> bool:
        # Tasks are equal when they track the same ARN.
        if isinstance(other, AwsQuantumTask):
            return self.id == other.id
        return NotImplemented

    def __hash__(self) -> int:
        # Consistent with __eq__: hash on the ARN.
        return hash(self.id)
@singledispatch
def _create_internal(
    task_specification: Union[Circuit, Problem],
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    backend_parameters: Dict[str, Any],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    """Fallback used when no implementation is registered for the specification's type.

    Type-specific implementations are attached with `@_create_internal.register`; any
    `task_specification` whose type has no registration is dispatched here.

    Raises:
        TypeError: Always, because the type of `task_specification` is unsupported.
    """
    unsupported_message = "Invalid task specification type"
    raise TypeError(unsupported_message)
@_create_internal.register
def _(
    circuit: Circuit,
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    backend_parameters: Dict[str, Any],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    """Submit a gate-model `Circuit` and return an `AwsQuantumTask` tracking it.

    `create_task_kwargs` is updated in place with the serialized circuit IR, the gate-model
    IR type, and gate-model backend parameters carrying the circuit's qubit count.
    `backend_parameters` is not used for circuits. Extra `*args`/`**kwargs` are forwarded
    to the `AwsQuantumTask` constructor.
    """
    serialized_circuit = circuit.to_ir().json()
    gate_model_params = {"gateModelParameters": {"qubitCount": circuit.qubit_count}}
    create_task_kwargs.update(
        ir=serialized_circuit,
        irType=AwsQuantumTask.GATE_IR_TYPE,
        backendParameters=gate_model_params,
    )
    new_task_arn = aws_session.create_quantum_task(**create_task_kwargs)
    return AwsQuantumTask(new_task_arn, aws_session, *args, **kwargs)
@_create_internal.register
def _(
    problem: Problem,
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    backend_parameters: Dict[str, Any],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    """Submit an annealing `Problem` and return an `AwsQuantumTask` tracking it.

    `create_task_kwargs` is updated in place with the serialized problem IR, the annealing
    IR type, and `backend_parameters` wrapped under `annealingModelParameters`. Extra
    `*args`/`**kwargs` are forwarded to the `AwsQuantumTask` constructor.
    """
    serialized_problem = problem.to_ir().json()
    annealing_params = {"annealingModelParameters": backend_parameters}
    create_task_kwargs.update(
        ir=serialized_problem,
        irType=AwsQuantumTask.ANNEALING_IR_TYPE,
        backendParameters=annealing_params,
    )
    arn_of_new_task = aws_session.create_quantum_task(**create_task_kwargs)
    return AwsQuantumTask(arn_of_new_task, aws_session, *args, **kwargs)
def _create_common_params(
    device_arn: str, s3_destination_folder: AwsSession.S3DestinationFolder, shots: int
) -> Dict[str, Any]:
    """Build the keyword arguments for `create_quantum_task` shared by every task type.

    Args:
        device_arn (str): ARN of the target device, sent as `backendArn`.
        s3_destination_folder (AwsSession.S3DestinationFolder): Index 0 is the results
            bucket and index 1 the results key prefix.
        shots (int): Shot count, sent as `shots`.

    Returns:
        Dict[str, Any]: Keys `backendArn`, `resultsS3Bucket`, `resultsS3Prefix`, `shots`.
    """
    results_bucket = s3_destination_folder[0]
    results_prefix = s3_destination_folder[1]
    return dict(
        backendArn=device_arn,
        resultsS3Bucket=results_bucket,
        resultsS3Prefix=results_prefix,
        shots=shots,
    )