# Copyright 2019-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import annotations
import asyncio
import time
from functools import singledispatch
from logging import Logger, getLogger
from typing import Any, Dict, Union
import boto3
from braket.annealing.problem import Problem
from braket.aws.aws_session import AwsSession
from braket.circuits.circuit import Circuit
from braket.circuits.circuit_helpers import validate_circuit_and_shots
from braket.tasks import AnnealingQuantumTaskResult, GateModelQuantumTaskResult, QuantumTask
class AwsQuantumTask(QuantumTask):
    """Amazon Braket implementation of a quantum task. A task can be a circuit or an annealing
    problem."""

    # TODO: Add API documentation that defines these states. Make it clear this is the contract.
    # Task statuses after which no result will ever be produced.
    NO_RESULT_TERMINAL_STATES = {"FAILED", "CANCELLED"}
    # Task statuses in which the result object can be fetched from S3.
    RESULTS_READY_STATES = {"COMPLETED"}

    # `irType` values sent when creating a task and read back to choose the result parser.
    GATE_IR_TYPE = "jaqcd"
    ANNEALING_IR_TYPE = "annealing"

    # Shot count that `create` falls back to when it is called with `shots=None`.
    # NOTE(review): value mirrors the SDK's historical default of 1000 shots; confirm per device.
    DEFAULT_SHOTS = 1_000

    DEFAULT_RESULTS_POLL_TIMEOUT = 120
    DEFAULT_RESULTS_POLL_INTERVAL = 0.25

    @staticmethod
    def create(
        aws_session: AwsSession,
        device_arn: str,
        task_specification: Union[Circuit, Problem],
        s3_destination_folder: AwsSession.S3DestinationFolder,
        shots: int,
        backend_parameters: Dict[str, Any] = None,
        *args,
        **kwargs,
    ) -> AwsQuantumTask:
        """AwsQuantumTask factory method that serializes a quantum task specification
        (either a quantum circuit or annealing problem), submits it to Amazon Braket,
        and returns back an AwsQuantumTask tracking the execution.

        Args:
            aws_session (AwsSession): AwsSession to connect to AWS with.
            device_arn (str): The ARN of the quantum device.
            task_specification (Union[Circuit, Problem]): The specification of the task
                to run on device.
            s3_destination_folder (AwsSession.S3DestinationFolder): NamedTuple, with bucket
                for index 0 and key for index 1, that specifies the Amazon S3 bucket and folder
                to store task results in.
            shots (int): The number of times to run the task on the device. If the device is a
                simulator, this implies the state is sampled N times, where N = `shots`.
                `shots=0` is only available on simulators and means that the simulator
                will compute the exact results based on the task specification.
                If `None`, `AwsQuantumTask.DEFAULT_SHOTS` is used.
            backend_parameters (Dict[str, Any]): Additional parameters to send to the device.
                For example, for D-Wave:
                `{"dWaveParameters": {"postprocessingType": "OPTIMIZATION"}}`
            *args, **kwargs: Forwarded to the `AwsQuantumTask` constructor.

        Returns:
            AwsQuantumTask: AwsQuantumTask tracking the task execution on the device.

        Raises:
            ValueError: If `s3_destination_folder` does not have exactly two elements.
            TypeError: If `task_specification` is neither a `Circuit` nor a `Problem`.

        Note:
            The following arguments are typically defined via clients of Device.
                - `task_specification`
                - `s3_destination_folder`
                - `shots`

        See Also:
            `braket.aws.aws_quantum_simulator.AwsQuantumSimulator.run()`
            `braket.aws.aws_qpu.AwsQpu.run()`
        """
        if len(s3_destination_folder) != 2:
            raise ValueError(
                "s3_destination_folder must be of size 2 with a 'bucket' and 'key' respectively."
            )
        create_task_kwargs = _create_common_params(
            device_arn,
            s3_destination_folder,
            shots if shots is not None else AwsQuantumTask.DEFAULT_SHOTS,
        )
        # Dispatches on the runtime type of `task_specification` (Circuit / Problem).
        return _create_internal(
            task_specification,
            aws_session,
            create_task_kwargs,
            backend_parameters or {},
            *args,
            **kwargs,
        )

    def __init__(
        self,
        arn: str,
        aws_session: AwsSession = None,
        poll_timeout_seconds: float = DEFAULT_RESULTS_POLL_TIMEOUT,
        poll_interval_seconds: float = DEFAULT_RESULTS_POLL_INTERVAL,
        logger: Logger = getLogger(__name__),
    ):
        """
        Args:
            arn (str): The ARN of the task.
            aws_session (AwsSession, optional): The `AwsSession` for connecting to AWS services.
                Default is `None`, in which case an `AwsSession` object will be created with the
                region of the task.
            poll_timeout_seconds (float): The polling timeout for result(), default is 120
                seconds.
            poll_interval_seconds (float): The polling interval for result(), default is 0.25
                seconds.
            logger (Logger): Logger object with which to write logs, such as task statuses
                while waiting for task to be in a terminal state. Default is `getLogger(__name__)`

        Examples:
            >>> task = AwsQuantumTask(arn='task_arn')
            >>> task.state()
            'COMPLETED'
            >>> result = task.result()
            AnnealingQuantumTaskResult(...)

            >>> task = AwsQuantumTask(arn='task_arn', poll_timeout_seconds=300)
            >>> result = task.result()
            GateModelQuantumTaskResult(...)
        """
        self._arn: str = arn
        self._aws_session: AwsSession = aws_session or AwsQuantumTask._aws_session_for_task_arn(
            task_arn=arn
        )
        self._poll_timeout_seconds = poll_timeout_seconds
        self._poll_interval_seconds = poll_interval_seconds
        self._logger = logger

        # Cache of the most recent GetQuantumTask response; populated by `metadata()`.
        self._metadata: Dict[str, Any] = {}
        self._result: Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult] = None

        # Best effort: make sure the current thread has an event loop before scheduling the
        # polling task on it. Any failure to fetch one results in a fresh loop being installed.
        try:
            asyncio.get_event_loop()
        except Exception as e:
            self._logger.debug(e)
            self._logger.info("No event loop found; creating new event loop")
            asyncio.set_event_loop(asyncio.new_event_loop())
        self._future = asyncio.get_event_loop().run_until_complete(self._create_future())

    @staticmethod
    def _aws_session_for_task_arn(task_arn: str) -> AwsSession:
        """
        Get an AwsSession for the Task ARN. The AWS session should be in the region of the task.

        Args:
            task_arn (str): The ARN of the task.

        Returns:
            AwsSession: `AwsSession` object with default `boto_session` in task's region
        """
        # The region is the fourth colon-separated field of the ARN.
        task_region = task_arn.split(":")[3]
        boto_session = boto3.Session(region_name=task_region)
        return AwsSession(boto_session=boto_session)

    @property
    def id(self) -> str:
        """str: The ARN of the quantum task."""
        return self._arn

    def cancel(self) -> None:
        """Cancel the quantum task. This cancels the future and the task in Amazon Braket."""
        self._future.cancel()
        self._aws_session.cancel_quantum_task(self._arn)

    def metadata(self, use_cached_value: bool = False) -> Dict[str, Any]:
        """
        Get task metadata defined in Amazon Braket.

        Args:
            use_cached_value (bool, optional): If `True`, returns the value most recently
                retrieved from the Amazon Braket `GetQuantumTask` operation; if nothing has been
                retrieved yet, the operation is called. If `False`, always calls the
                `GetQuantumTask` operation, which also updates the cached value.
                Default = `False`.

        Returns:
            Dict[str, Any]: The response from the Amazon Braket `GetQuantumTask` operation.
        """
        if not use_cached_value or not self._metadata:
            self._metadata = self._aws_session.get_quantum_task(self._arn)
        return self._metadata

    def state(self, use_cached_value: bool = False) -> str:
        """
        The state of the quantum task.

        Args:
            use_cached_value (bool, optional): If `True`, uses the value most recently retrieved
                from the Amazon Braket `GetQuantumTask` operation. If `False`, calls the
                `GetQuantumTask` operation to retrieve metadata, which also updates the cached
                value. Default = `False`.

        Returns:
            str: The value of `status` in `metadata()`. This is the value of the `status` key
            in the Amazon Braket `GetQuantumTask` operation. If `use_cached_value` is `True`,
            the value most recently returned from the `GetQuantumTask` operation is used.

        See Also:
            `metadata()`
        """
        return self.metadata(use_cached_value).get("status")

    def result(self) -> Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]:
        """
        Get the quantum task result by polling Amazon Braket to see if the task is completed.
        Once the task is completed, the result is retrieved from S3 and returned as a
        `GateModelQuantumTaskResult` or `AnnealingQuantumTaskResult`

        This method is a blocking thread call and synchronously returns a result. Call
        async_result() if you require an asynchronous invocation.
        Consecutive calls to this method return a cached result from the preceding request.

        Returns:
            Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]: The result, or
            `None` if polling timed out, the task ended without a result, or the future
            was cancelled.
        """
        try:
            return asyncio.get_event_loop().run_until_complete(self.async_result())
        except asyncio.CancelledError:
            # Future was cancelled, return whatever is in self._result if anything
            self._logger.warning("Task future was cancelled")
            return self._result

    def async_result(self) -> asyncio.Task:
        """
        Get the quantum task result asynchronously. Consecutive calls to this method return
        the result cached from the most recent request.

        If the previous polling future finished without a result (for example, it timed
        out) and the task is not in a no-result terminal state, a new polling future is
        started.

        Returns:
            asyncio.Task: The future wrapping the polling coroutine.
        """
        if self._future.done() and self._result is None:  # timed out and no result
            task_status = self.metadata()["status"]
            if task_status in self.NO_RESULT_TERMINAL_STATES:
                self._logger.warning(
                    f"Task is in terminal state {task_status} and no result is available"
                )
            else:
                self._future = asyncio.get_event_loop().run_until_complete(self._create_future())
        return self._future

    async def _create_future(self) -> asyncio.Task:
        """
        Wrap the `_wait_for_completion` coroutine inside a future-like object.
        Invoking this method starts the coroutine and returns back the future-like object
        that contains it. Note that this does not block on the coroutine to finish.

        `asyncio.create_task` requires a running loop, which is why callers drive this
        coroutine with `run_until_complete`: doing so only schedules the task and hands it back.

        Returns:
            asyncio.Task: An asyncio Task that contains the _wait_for_completion() coroutine.
        """
        return asyncio.create_task(self._wait_for_completion())

    def _get_results_formatter(
        self,
    ) -> Union[GateModelQuantumTaskResult.from_string, AnnealingQuantumTaskResult.from_string]:
        """
        Get results formatter based on irType of self.metadata()

        The cached metadata is used (it is fetched if the cache is empty): `irType` is set when
        the task is created, so a fresh `GetQuantumTask` round trip is unnecessary here.

        Returns:
            Union[GateModelQuantumTaskResult.from_string, AnnealingQuantumTaskResult.from_string]:
            function that deserializes a string into a results structure

        Raises:
            ValueError: If the task's `irType` is not a known IR type.
        """
        current_metadata = self.metadata(use_cached_value=True)
        ir_type = current_metadata["irType"]
        if ir_type == AwsQuantumTask.ANNEALING_IR_TYPE:
            return AnnealingQuantumTaskResult.from_string
        elif ir_type == AwsQuantumTask.GATE_IR_TYPE:
            return GateModelQuantumTaskResult.from_string
        else:
            raise ValueError("Unknown IR type")

    async def _wait_for_completion(
        self,
    ) -> Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]:
        """
        Waits for the quantum task to be completed, then returns the result from the S3 bucket.

        Returns:
            Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]: If the task is in the
            `AwsQuantumTask.RESULTS_READY_STATES` state within the specified time limit,
            the result from the S3 bucket is loaded and returned.
            `None` is returned if a timeout occurs or task state is in
            `AwsQuantumTask.NO_RESULT_TERMINAL_STATES`.

        Note:
            Timeout and sleep intervals are defined in the constructor fields
            `poll_timeout_seconds` and `poll_interval_seconds` respectively.
        """
        self._logger.debug(f"Task {self._arn}: start polling for completion")
        # Monotonic clock so that system clock adjustments cannot stretch or cut the timeout.
        start_time = time.monotonic()

        while (time.monotonic() - start_time) < self._poll_timeout_seconds:
            current_metadata = self.metadata()
            task_status = current_metadata["status"]
            self._logger.debug(f"Task {self._arn}: task status {task_status}")
            if task_status in AwsQuantumTask.RESULTS_READY_STATES:
                result_string = self._aws_session.retrieve_s3_object_body(
                    current_metadata["resultsS3Bucket"], current_metadata["resultsS3ObjectKey"]
                )
                self._result = self._get_results_formatter()(result_string)
                return self._result
            elif task_status in AwsQuantumTask.NO_RESULT_TERMINAL_STATES:
                self._logger.warning(
                    f"Task is in terminal state {task_status} and no result is available"
                )
                self._result = None
                return None
            else:
                await asyncio.sleep(self._poll_interval_seconds)

        # Timed out
        elapsed_seconds = time.monotonic() - start_time
        self._logger.warning(
            f"Task {self._arn}: polling for task completion timed out after "
            f"{elapsed_seconds} secs"
        )
        self._result = None
        return None

    def __repr__(self) -> str:
        return f"AwsQuantumTask('id':{self.id})"

    def __eq__(self, other) -> bool:
        if isinstance(other, AwsQuantumTask):
            return self.id == other.id
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.id)
@singledispatch
def _create_internal(
task_specification: Union[Circuit, Problem],
aws_session: AwsSession,
create_task_kwargs: Dict[str, Any],
backend_parameters: Dict[str, Any],
*args,
**kwargs,
) -> AwsQuantumTask:
raise TypeError("Invalid task specification type")
@_create_internal.register
def _(
    circuit: Circuit,
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    backend_parameters: Dict[str, Any],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    """Submit a gate-model `Circuit` and return an `AwsQuantumTask` tracking it.

    The circuit is validated against the requested shot count, then `create_task_kwargs` is
    filled in place with the serialized IR, the gate IR type and the gate-model backend
    parameters before being passed to `aws_session.create_quantum_task`.

    Note: `backend_parameters` is accepted so every overload shares one signature, but it is
    not forwarded for circuits; only the circuit's qubit count is sent.
    """
    validate_circuit_and_shots(circuit, create_task_kwargs["shots"])

    create_task_kwargs["ir"] = circuit.to_ir().json()
    create_task_kwargs["irType"] = AwsQuantumTask.GATE_IR_TYPE
    create_task_kwargs["backendParameters"] = {
        "gateModelParameters": {"qubitCount": circuit.qubit_count}
    }

    new_task_arn = aws_session.create_quantum_task(**create_task_kwargs)
    return AwsQuantumTask(new_task_arn, aws_session, *args, **kwargs)
@_create_internal.register
def _(
    problem: Problem,
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    backend_parameters: Dict[str, Any],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    """Submit an annealing `Problem` and return an `AwsQuantumTask` tracking it.

    `create_task_kwargs` is filled in place with the serialized IR, the annealing IR type and
    the caller's `backend_parameters` (nested under `annealingModelParameters`), then passed
    to `aws_session.create_quantum_task`.
    """
    create_task_kwargs["ir"] = problem.to_ir().json()
    create_task_kwargs["irType"] = AwsQuantumTask.ANNEALING_IR_TYPE
    create_task_kwargs["backendParameters"] = {"annealingModelParameters": backend_parameters}

    new_task_arn = aws_session.create_quantum_task(**create_task_kwargs)
    return AwsQuantumTask(new_task_arn, aws_session, *args, **kwargs)
def _create_common_params(
device_arn: str, s3_destination_folder: AwsSession.S3DestinationFolder, shots: int
) -> Dict[str, Any]:
return {
"backendArn": device_arn,
"resultsS3Bucket": s3_destination_folder[0],
"resultsS3Prefix": s3_destination_folder[1],
"shots": shots,
}