Source code for braket.aws.aws_quantum_task

# Copyright 2019-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from __future__ import annotations

import asyncio
import time
from functools import singledispatch
from logging import Logger, getLogger
from typing import Any, Callable, Dict, Union

import boto3
from braket.annealing.problem import Problem
from braket.aws.aws_session import AwsSession
from braket.circuits.circuit import Circuit
from braket.circuits.circuit_helpers import validate_circuit_and_shots
from braket.tasks import AnnealingQuantumTaskResult, GateModelQuantumTaskResult, QuantumTask


class AwsQuantumTask(QuantumTask):
    """Amazon Braket implementation of a quantum task. A task can be a circuit or an annealing
    problem."""

    # TODO: Add API documentation that defines these states. Make it clear this is the contract.
    NO_RESULT_TERMINAL_STATES = {"FAILED", "CANCELLED"}
    RESULTS_READY_STATES = {"COMPLETED"}

    # Values written to / read from the `irType` field of the task request and metadata.
    GATE_IR_TYPE = "jaqcd"
    ANNEALING_IR_TYPE = "annealing"

    # Polling defaults used by result(); timeout is in seconds, interval is in seconds.
    DEFAULT_RESULTS_POLL_TIMEOUT = 120
    DEFAULT_RESULTS_POLL_INTERVAL = 0.25

    @staticmethod
    def create(
        aws_session: AwsSession,
        device_arn: str,
        task_specification: Union[Circuit, Problem],
        s3_destination_folder: AwsSession.S3DestinationFolder,
        shots: int,
        backend_parameters: Dict[str, Any] = None,
        *args,
        **kwargs,
    ) -> AwsQuantumTask:
        """AwsQuantumTask factory method that serializes a quantum task specification
        (either a quantum circuit or annealing problem), submits it to Amazon Braket,
        and returns back an AwsQuantumTask tracking the execution.

        Args:
            aws_session (AwsSession): AwsSession to connect to AWS with.
            device_arn (str): The ARN of the quantum device.
            task_specification (Union[Circuit, Problem]): The specification of the task
                to run on device.
            s3_destination_folder (AwsSession.S3DestinationFolder): NamedTuple, with bucket
                for index 0 and key for index 1, that specifies the Amazon S3 bucket and folder
                to store task results in.
            shots (int): The number of times to run the task on the device. If the device is a
                simulator, this implies the state is sampled N times, where N = `shots`.
                `shots=0` is only available on simulators and means that the simulator
                will compute the exact results based on the task specification.
            backend_parameters (Dict[str, Any]): Additional parameters to send to the device.
                For example, for D-Wave:
                `{"dWaveParameters": {"postprocessingType": "OPTIMIZATION"}}`

        Returns:
            AwsQuantumTask: AwsQuantumTask tracking the task execution on the device.

        Raises:
            ValueError: If `s3_destination_folder` does not have exactly two elements.
            TypeError: If `task_specification` is neither a `Circuit` nor a `Problem`.

        Note:
            The following arguments are typically defined via clients of Device.
                - `task_specification`
                - `s3_destination_folder`
                - `shots`

        See Also:
            `braket.aws.aws_quantum_simulator.AwsQuantumSimulator.run()`
            `braket.aws.aws_qpu.AwsQpu.run()`
        """
        if len(s3_destination_folder) != 2:
            raise ValueError(
                "s3_destination_folder must be of size 2 with a 'bucket' and 'key' respectively."
            )

        # NOTE(review): `DEFAULT_SHOTS` is not defined on this class in the visible source.
        # Unless `QuantumTask` provides it, passing `shots=None` raises AttributeError here.
        # Confirm, and define the constant if needed.
        create_task_kwargs = _create_common_params(
            device_arn,
            s3_destination_folder,
            shots if shots is not None else AwsQuantumTask.DEFAULT_SHOTS,
        )
        return _create_internal(
            task_specification,
            aws_session,
            create_task_kwargs,
            backend_parameters or {},
            *args,
            **kwargs,
        )

    def __init__(
        self,
        arn: str,
        aws_session: AwsSession = None,
        poll_timeout_seconds: int = DEFAULT_RESULTS_POLL_TIMEOUT,
        poll_interval_seconds: float = DEFAULT_RESULTS_POLL_INTERVAL,
        logger: Logger = getLogger(__name__),
    ):
        """
        Args:
            arn (str): The ARN of the task.
            aws_session (AwsSession, optional): The `AwsSession` for connecting to AWS services.
                Default is `None`, in which case an `AwsSession` object will be created with the
                region of the task.
            poll_timeout_seconds (int): The polling timeout for result(), default is 120 seconds.
            poll_interval_seconds (float): The polling interval for result(), default is
                0.25 seconds.
            logger (Logger): Logger object with which to write logs, such as task statuses
                while waiting for task to be in a terminal state. Default is
                `getLogger(__name__)`

        Examples:
            >>> task = AwsQuantumTask(arn='task_arn')
            >>> task.state()
            'COMPLETED'
            >>> task.result()
            AnnealingQuantumTaskResult(...)

            >>> task = AwsQuantumTask(arn='task_arn', poll_timeout_seconds=300)
            >>> task.result()
            GateModelQuantumTaskResult(...)
        """

        self._arn: str = arn
        self._aws_session: AwsSession = aws_session or AwsQuantumTask._aws_session_for_task_arn(
            task_arn=arn
        )
        self._poll_timeout_seconds = poll_timeout_seconds
        self._poll_interval_seconds = poll_interval_seconds
        self._logger = logger

        self._metadata: Dict[str, Any] = {}
        self._result: Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult] = None

        # Make sure the current thread has an event loop; create and install one if
        # asyncio.get_event_loop() fails (e.g. in a non-main thread).
        try:
            asyncio.get_event_loop()
        except Exception as e:
            self._logger.debug(e)
            self._logger.info("No event loop found; creating new event loop")
            asyncio.set_event_loop(asyncio.new_event_loop())
        # Start polling immediately; the returned asyncio.Task runs on the loop above.
        self._future = asyncio.get_event_loop().run_until_complete(self._create_future())

    @staticmethod
    def _aws_session_for_task_arn(task_arn: str) -> AwsSession:
        """
        Get an AwsSession for the Task ARN. The AWS session should be in the region of the task.

        Args:
            task_arn (str): The task ARN; its fourth colon-separated field is used as the
                region name.

        Returns:
            AwsSession: `AwsSession` object with default `boto_session` in task's region
        """
        task_region = task_arn.split(":")[3]
        boto_session = boto3.Session(region_name=task_region)
        return AwsSession(boto_session=boto_session)

    @property
    def id(self) -> str:
        """str: The ARN of the quantum task."""
        return self._arn

    def cancel(self) -> None:
        """Cancel the quantum task. This cancels the future and the task in Amazon Braket."""
        self._future.cancel()
        self._aws_session.cancel_quantum_task(self._arn)

    def metadata(self, use_cached_value: bool = False) -> Dict[str, Any]:
        """
        Get task metadata defined in Amazon Braket.

        Args:
            use_cached_value (bool, optional): If `True`, uses the value most recently retrieved
                from the Amazon Braket `GetQuantumTask` operation. If `False`, calls the
                `GetQuantumTask` operation to retrieve metadata, which also updates the cached
                value. Default = `False`.

        Returns:
            Dict[str, Any]: The response from the Amazon Braket `GetQuantumTask` operation.
            If `use_cached_value` is `True`, Amazon Braket is not called and the most recently
            retrieved value is used.
        """
        if not use_cached_value:
            self._metadata = self._aws_session.get_quantum_task(self._arn)
        return self._metadata

    def state(self, use_cached_value: bool = False) -> str:
        """
        The state of the quantum task.

        Args:
            use_cached_value (bool, optional): If `True`, uses the value most recently retrieved
                from the Amazon Braket `GetQuantumTask` operation. If `False`, calls the
                `GetQuantumTask` operation to retrieve metadata, which also updates the cached
                value. Default = `False`.

        Returns:
            str: The value of `status` in `metadata()`. This is the value of the `status` key
            in the Amazon Braket `GetQuantumTask` operation. If `use_cached_value` is `True`,
            the value most recently returned from the `GetQuantumTask` operation is used.

        See Also:
            `metadata()`
        """
        return self.metadata(use_cached_value).get("status")

    def result(self) -> Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]:
        """
        Get the quantum task result by polling Amazon Braket to see if the task is completed.
        Once the task is completed, the result is retrieved from S3 and returned as a
        `GateModelQuantumTaskResult` or `AnnealingQuantumTaskResult`

        This method is a blocking thread call and synchronously returns a result. Call
        async_result() if you require an asynchronous invocation.
        Consecutive calls to this method return a cached result from the preceding request.
        """
        try:
            return asyncio.get_event_loop().run_until_complete(self.async_result())
        except asyncio.CancelledError:
            # Future was cancelled, return whatever is in self._result if anything
            self._logger.warning("Task future was cancelled")
            return self._result

    def async_result(self) -> asyncio.Task:
        """
        Get the quantum task result asynchronously. Consecutive calls to this method return
        the result cached from the most recent request.

        If the previous polling future finished without a result (e.g. it timed out) and the
        task is not in a no-result terminal state, a new polling future is started.
        """
        if self._future.done() and self._result is None:  # timed out and no result
            task_status = self.metadata()["status"]
            if task_status in self.NO_RESULT_TERMINAL_STATES:
                self._logger.warning(
                    f"Task is in terminal state {task_status} and no result is available"
                )
            else:
                self._future = asyncio.get_event_loop().run_until_complete(self._create_future())
        return self._future

    async def _create_future(self) -> asyncio.Task:
        """
        Wrap the `_wait_for_completion` coroutine inside a future-like object.
        Invoking this method starts the coroutine and returns back the future-like object
        that contains it. Note that this does not block on the coroutine to finish.

        Returns:
            asyncio.Task: An asyncio Task that contains the _wait_for_completion() coroutine.
        """
        # create_task needs a running loop, hence this coroutine wrapper driven by
        # run_until_complete() in the callers.
        return asyncio.create_task(self._wait_for_completion())

    def _get_results_formatter(
        self, metadata: Dict[str, Any] = None
    ) -> Callable[[str], Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]]:
        """
        Get the results deserializer matching the task's `irType`.

        Args:
            metadata (Dict[str, Any], optional): Task metadata already retrieved from
                `GetQuantumTask`. If `None` (default), `self.metadata()` is called to fetch it.

        Returns:
            Callable[[str], Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]]:
            function that deserializes a string into a results structure

        Raises:
            ValueError: If the `irType` is neither the annealing nor the gate IR type.
        """
        current_metadata = metadata if metadata is not None else self.metadata()
        ir_type = current_metadata["irType"]
        if ir_type == AwsQuantumTask.ANNEALING_IR_TYPE:
            return AnnealingQuantumTaskResult.from_string
        elif ir_type == AwsQuantumTask.GATE_IR_TYPE:
            return GateModelQuantumTaskResult.from_string
        else:
            raise ValueError("Unknown IR type")

    async def _wait_for_completion(
        self,
    ) -> Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]:
        """
        Waits for the quantum task to be completed, then returns the result from the S3 bucket.

        Returns:
            Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]: If the task is in the
            `AwsQuantumTask.RESULTS_READY_STATES` state within the specified time limit,
            the result from the S3 bucket is loaded and returned. `None` is returned if a
            timeout occurs or task state is in `AwsQuantumTask.NO_RESULT_TERMINAL_STATES`.

        Note:
            Timeout and sleep intervals are defined in the constructor fields
            `poll_timeout_seconds` and `poll_interval_seconds` respectively.
        """
        self._logger.debug(f"Task {self._arn}: start polling for completion")
        start_time = time.time()

        while (time.time() - start_time) < self._poll_timeout_seconds:
            # NOTE(review): metadata() performs a synchronous service call, which blocks the
            # event loop while it runs.
            current_metadata = self.metadata()
            task_status = current_metadata["status"]
            self._logger.debug(f"Task {self._arn}: task status {task_status}")
            if task_status in AwsQuantumTask.RESULTS_READY_STATES:
                result_string = self._aws_session.retrieve_s3_object_body(
                    current_metadata["resultsS3Bucket"], current_metadata["resultsS3ObjectKey"]
                )
                # Reuse the metadata fetched above instead of a second GetQuantumTask call.
                self._result = self._get_results_formatter(current_metadata)(result_string)
                return self._result
            elif task_status in AwsQuantumTask.NO_RESULT_TERMINAL_STATES:
                self._logger.warning(
                    f"Task is in terminal state {task_status} and no result is available"
                )
                self._result = None
                return None
            else:
                await asyncio.sleep(self._poll_interval_seconds)

        # Timed out
        elapsed = time.time() - start_time
        self._logger.warning(
            f"Task {self._arn}: polling for task completion timed out after {elapsed} secs"
        )
        self._result = None
        return None

    def __repr__(self) -> str:
        return f"AwsQuantumTask('id':{self.id})"

    def __eq__(self, other) -> bool:
        # Tasks are equal when they track the same ARN.
        if isinstance(other, AwsQuantumTask):
            return self.id == other.id
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.id)
@singledispatch
def _create_internal(
    task_specification: Union[Circuit, Problem],
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    backend_parameters: Dict[str, Any],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    """Submit a task specification to Amazon Braket, dispatching on its type.

    Type-specific overloads are registered below. This base implementation is reached only
    when the first argument's type has no registered overload.

    Raises:
        TypeError: Always, because the specification type is unsupported.
    """
    raise TypeError("Invalid task specification type")


def _submit_and_track(
    aws_session: AwsSession,
    request: Dict[str, Any],
    extra_args: tuple,
    extra_kwargs: Dict[str, Any],
) -> AwsQuantumTask:
    """Send the create-task request, then wrap the returned ARN in an `AwsQuantumTask`.

    `extra_args`/`extra_kwargs` are forwarded verbatim to the `AwsQuantumTask` constructor.
    """
    new_arn = aws_session.create_quantum_task(**request)
    return AwsQuantumTask(new_arn, aws_session, *extra_args, **extra_kwargs)


@_create_internal.register(Circuit)
def _create_circuit_task(
    circuit: Circuit,
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    backend_parameters: Dict[str, Any],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    """Gate-model overload: validate shots, attach the serialized circuit, and submit.

    `backend_parameters` is not used here; the gate-model backend parameters are derived
    solely from `circuit.qubit_count`.
    """
    validate_circuit_and_shots(circuit, create_task_kwargs["shots"])
    serialized_ir = circuit.to_ir().json()
    create_task_kwargs.update(
        ir=serialized_ir,
        irType=AwsQuantumTask.GATE_IR_TYPE,
        backendParameters={"gateModelParameters": {"qubitCount": circuit.qubit_count}},
    )
    return _submit_and_track(aws_session, create_task_kwargs, args, kwargs)


@_create_internal.register(Problem)
def _create_annealing_task(
    problem: Problem,
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    backend_parameters: Dict[str, Any],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    """Annealing overload: attach the serialized problem plus user parameters, and submit.

    `backend_parameters` is nested under the `annealingModelParameters` key.
    """
    serialized_ir = problem.to_ir().json()
    create_task_kwargs.update(
        ir=serialized_ir,
        irType=AwsQuantumTask.ANNEALING_IR_TYPE,
        backendParameters={"annealingModelParameters": backend_parameters},
    )
    return _submit_and_track(aws_session, create_task_kwargs, args, kwargs)


def _create_common_params(
    device_arn: str, s3_destination_folder: AwsSession.S3DestinationFolder, shots: int
) -> Dict[str, Any]:
    """Build the request fields shared by every task type.

    Index 0 of `s3_destination_folder` is used as the results bucket and index 1 as the
    results key prefix.
    """
    results_bucket = s3_destination_folder[0]
    results_prefix = s3_destination_folder[1]
    return dict(
        backendArn=device_arn,
        resultsS3Bucket=results_bucket,
        resultsS3Prefix=results_prefix,
        shots=shots,
    )