Skip to content

Commit

Permalink
Merge pull request #66 from aws/hotfix/0.3.4
Browse files Browse the repository at this point in the history
Hotfix/0.3.4
  • Loading branch information
avawang1 authored Apr 11, 2020
2 parents a35dc5f + 41b0ae4 commit 32bf930
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 68 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

setup(
name="braket-sdk",
version="0.3.2",
version="0.3.4",
license="Apache License 2.0",
python_requires=">= 3.7.2",
packages=find_namespace_packages(where="src", exclude=("test",)),
Expand Down
119 changes: 83 additions & 36 deletions src/braket/aws/aws_quantum_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import time
from functools import singledispatch
from logging import Logger, getLogger
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Dict, Optional, Union

import boto3
from braket.annealing.problem import Problem
from braket.aws.aws_session import AwsSession
from braket.circuits.circuit import Circuit
Expand All @@ -30,7 +31,7 @@ class AwsQuantumTask(QuantumTask):
problem."""

# TODO: Add API documentation that defines these states. Make it clear this is the contract.
TERMINAL_STATES = {"COMPLETED", "FAILED", "CANCELLED"}
NO_RESULT_TERMINAL_STATES = {"FAILED", "CANCELLED"}
RESULTS_READY_STATES = {"COMPLETED"}

GATE_IR_TYPE = "jaqcd"
Expand Down Expand Up @@ -72,7 +73,7 @@ def create(
backend_parameters (Dict[str, Any]): Additional parameters to send to the device.
For example, for D-Wave:
>>> backend_parameters = {"dWaveParameters": {"postprocess": "OPTIMIZATION"}}
`{"dWaveParameters": {"postprocessingType": "OPTIMIZATION"}}`
Returns:
AwsQuantumTask: AwsQuantumTask tracking the task execution on the device.
Expand Down Expand Up @@ -104,28 +105,39 @@ def create(
def __init__(
self,
arn: str,
aws_session: AwsSession,
results_formatter: Callable[[str], Any],
aws_session: AwsSession = None,
poll_timeout_seconds: int = DEFAULT_RESULTS_POLL_TIMEOUT,
poll_interval_seconds: int = DEFAULT_RESULTS_POLL_INTERVAL,
logger: Logger = getLogger(__name__),
):
"""
Args:
arn (str): The ARN of the task.
aws_session (AwsSession): The `AwsSession` for connecting to AWS services.
results_formatter (Callable[[str], Any]): A function that deserializes a string
into a results structure (such as `GateModelQuantumTaskResult`)
aws_session (AwsSession, optional): The `AwsSession` for connecting to AWS services.
Default is `None`, in which case an `AwsSession` object will be created with the
region of the task.
poll_timeout_seconds (int): The polling timeout for result(), default is 120 seconds.
poll_interval_seconds (int): The polling interval for result(), default is 0.25
seconds.
logger (Logger): Logger object with which to write logs, such as task statuses
while waiting for task to be in a terminal state. Default is `getLogger(__name__)`
Examples:
>>> task = AwsQuantumTask(arn='task_arn')
>>> task.state()
'COMPLETED'
>>> result = task.result()
AnnealingQuantumTaskResult(...)
>>> task = AwsQuantumTask(arn='task_arn', poll_timeout_seconds=300)
>>> result = task.result()
GateModelQuantumTaskResult(...)
"""

self._arn: str = arn
self._aws_session: AwsSession = aws_session
self._results_formatter = results_formatter
self._aws_session: AwsSession = aws_session or AwsQuantumTask._aws_session_for_task_arn(
task_arn=arn
)
self._poll_timeout_seconds = poll_timeout_seconds
self._poll_interval_seconds = poll_interval_seconds
self._logger = logger
Expand All @@ -140,6 +152,18 @@ def __init__(
asyncio.set_event_loop(asyncio.new_event_loop())
self._future = asyncio.get_event_loop().run_until_complete(self._create_future())

@staticmethod
def _aws_session_for_task_arn(task_arn: str) -> AwsSession:
"""
Get an AwsSession for the Task ARN. The AWS session should be in the region of the task.
Returns:
AwsSession: `AwsSession` object with default `boto_session` in task's region
"""
task_region = task_arn.split(":")[3]
boto_session = boto3.Session(region_name=task_region)
return AwsSession(boto_session=boto_session)

@property
def id(self) -> str:
"""str: The ARN of the quantum task."""
Expand All @@ -158,7 +182,7 @@ def metadata(self, use_cached_value: bool = False) -> Dict[str, Any]:
use_cached_value (bool, optional): If `True`, uses the value most recently retrieved
from the Amazon Braket `GetQuantumTask` operation. If `False`, calls the
`GetQuantumTask` operation to retrieve metadata, which also updates the cached
value. Default = False.
value. Default = `False`.
Returns:
Dict[str, Any]: The response from the Amazon Braket `GetQuantumTask` operation.
If `use_cached_value` is `True`, Amazon Braket is not called and the most recently
Expand All @@ -176,7 +200,7 @@ def state(self, use_cached_value: bool = False) -> str:
use_cached_value (bool, optional): If `True`, uses the value most recently retrieved
from the Amazon Braket `GetQuantumTask` operation. If `False`, calls the
`GetQuantumTask` operation to retrieve metadata, which also updates the cached
value. Default = False.
value. Default = `False`.
Returns:
str: The value of `status` in `metadata()`. This is the value of the `status` key
in the Amazon Braket `GetQuantumTask` operation. If `use_cached_value` is `True`,
Expand All @@ -190,7 +214,7 @@ def result(self) -> Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult
"""
Get the quantum task result by polling Amazon Braket to see if the task is completed.
Once the task is completed, the result is retrieved from S3 and returned as a
`QuantumTaskResult`.
`GateModelQuantumTaskResult` or `AnnealingQuantumTaskResult`
This method is a blocking thread call and synchronously returns a result. Call
async_result() if you require an asynchronous invocation.
Expand All @@ -200,19 +224,22 @@ def result(self) -> Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult
return asyncio.get_event_loop().run_until_complete(self.async_result())
except asyncio.CancelledError:
# Future was cancelled, return whatever is in self._result if anything
self._logger.warning("Task future was cancelled")
return self._result

def async_result(self) -> asyncio.Task:
"""
Get the quantum task result asynchronously. Consecutive calls to this method return
the result cached from the most recent request.
"""
if (
self._future.done()
and self.metadata(use_cached_value=True).get("status")
not in AwsQuantumTask.TERMINAL_STATES
): # Future timed out
self._future = asyncio.get_event_loop().run_until_complete(self._create_future())
if self._future.done() and self._result is None: # timed out and no result
task_status = self.metadata()["status"]
if task_status in self.NO_RESULT_TERMINAL_STATES:
self._logger.warning(
f"Task is in terminal state {task_status} and no result is available"
)
else:
self._future = asyncio.get_event_loop().run_until_complete(self._create_future())
return self._future

async def _create_future(self) -> asyncio.Task:
Expand All @@ -226,17 +253,37 @@ async def _create_future(self) -> asyncio.Task:
"""
return asyncio.create_task(self._wait_for_completion())

def _get_results_formatter(
self,
) -> Union[GateModelQuantumTaskResult.from_string, AnnealingQuantumTaskResult.from_string]:
"""
Get results formatter based on irType of self.metadata()
Returns:
Union[GateModelQuantumTaskResult.from_string, AnnealingQuantumTaskResult.from_string]:
function that deserializes a string into a results structure
"""
current_metadata = self.metadata()
ir_type = current_metadata["irType"]
if ir_type == AwsQuantumTask.ANNEALING_IR_TYPE:
return AnnealingQuantumTaskResult.from_string
elif ir_type == AwsQuantumTask.GATE_IR_TYPE:
return GateModelQuantumTaskResult.from_string
else:
raise ValueError("Unknown IR type")

async def _wait_for_completion(
self,
) -> Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]:
"""
Waits for the quantum task to be completed, then returns the result from the S3 bucket.
Returns:
Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]: If the task is in
the `AwsQuantumTask.RESULTS_READY_STATES` state within the specified time limit,
the result from the S3 bucket is loaded and returned. `None` is returned if a
timeout occurs or task state is in `AwsQuantumTask.TERMINAL_STATES` but not
`AwsQuantumTask.RESULTS_READY_STATES`.
Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]: If the task is in the
`AwsQuantumTask.RESULTS_READY_STATES` state within the specified time limit,
the result from the S3 bucket is loaded and returned.
`None` is returned if a timeout occurs or task state is in
`AwsQuantumTask.NO_RESULT_TERMINAL_STATES`.
Note:
Timeout and sleep intervals are defined in the constructor fields
`poll_timeout_seconds` and `poll_interval_seconds` respectively.
Expand All @@ -246,22 +293,27 @@ async def _wait_for_completion(

while (time.time() - start_time) < self._poll_timeout_seconds:
current_metadata = self.metadata()
self._logger.debug(f"Task {self._arn}: task status {current_metadata['status']}")
if current_metadata["status"] in AwsQuantumTask.RESULTS_READY_STATES:
task_status = current_metadata["status"]
self._logger.debug(f"Task {self._arn}: task status {task_status}")
if task_status in AwsQuantumTask.RESULTS_READY_STATES:
result_string = self._aws_session.retrieve_s3_object_body(
current_metadata["resultsS3Bucket"], current_metadata["resultsS3ObjectKey"]
)
self._result = self._results_formatter(result_string)
self._result = self._get_results_formatter()(result_string)
return self._result
elif current_metadata["status"] in AwsQuantumTask.TERMINAL_STATES:
elif task_status in AwsQuantumTask.NO_RESULT_TERMINAL_STATES:
self._logger.warning(
f"Task is in terminal state {task_status} and no result is available"
)
self._result = None
return None
else:
await asyncio.sleep(self._poll_interval_seconds)

# Timed out
self._logger.warning(
f"Task {self._arn}: polling timed out after {time.time()-start_time} secs"
f"Task {self._arn}: polling for task completion timed out after "
+ f"{time.time()-start_time} secs"
)
self._result = None
return None
Expand Down Expand Up @@ -306,11 +358,8 @@ def _(
"backendParameters": {"gateModelParameters": {"qubitCount": circuit.qubit_count}},
}
)

task_arn = aws_session.create_quantum_task(**create_task_kwargs)
return AwsQuantumTask(
task_arn, aws_session, GateModelQuantumTaskResult.from_string, *args, **kwargs
)
return AwsQuantumTask(task_arn, aws_session, *args, **kwargs)


@_create_internal.register
Expand All @@ -331,9 +380,7 @@ def _(
)

task_arn = aws_session.create_quantum_task(**create_task_kwargs)
return AwsQuantumTask(
task_arn, aws_session, AnnealingQuantumTaskResult.from_string, *args, **kwargs
)
return AwsQuantumTask(task_arn, aws_session, *args, **kwargs)


def _create_common_params(
Expand Down
Loading

0 comments on commit 32bf930

Please sign in to comment.