Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly cache status #174

Merged
merged 2 commits into from
Nov 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions src/braket/aws/aws_quantum_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,15 +219,16 @@ def metadata(self, use_cached_value: bool = False) -> Dict[str, Any]:

Args:
use_cached_value (bool, optional): If `True`, uses the value most recently retrieved
from the Amazon Braket `GetQuantumTask` operation. If `False`, calls the
`GetQuantumTask` operation to retrieve metadata, which also updates the cached
value. Default = `False`.
from the Amazon Braket `GetQuantumTask` operation, if it exists; if not,
`GetQuantumTask` will be called to retrieve the metadata. If `False`, always calls
`GetQuantumTask`, which also updates the cached value. Default: `False`.
Returns:
Dict[str, Any]: The response from the Amazon Braket `GetQuantumTask` operation.
If `use_cached_value` is `True`, Amazon Braket is not called and the most recently
retrieved value is used.
retrieved value is used, unless `GetQuantumTask` was never called, in which case
it wil still be called to populate the metadata for the first time.
"""
if not use_cached_value:
if not use_cached_value or not self._metadata:
self._metadata = self._aws_session.get_quantum_task(self._arn)
return self._metadata

Expand Down Expand Up @@ -255,6 +256,13 @@ def _status(self, use_cached_value=False):
self._logger.warning(f"Task is in terminal state {status} and no result is available")
return status

def _update_status_if_nonterminal(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is still not ideal: consider the case where the metadata is empty and the status is "RUNNING".
Then first self._status(True) is called, filling the metadata and returning the "RUNNING" status.
Since it's not a terminal state, self._status(False) will be called again. This uses two GetQuantumTask calls.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, there will be at most one extra call to GetQuantumTask, and since the status isn't terminal, I don't think it hurts to make the extra call; it's going to have to poll again regardless. Easy enough to get rid of it though.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the scenario I described it's two successive calls to GetQuantumTask, so the second call will likely return the same result and be wasted. In any case, thanks for the fix.

# If metadata has not been populated, the first call to _status will fetch it,
# so the second _status call will no longer need to
metadata_absent = self._metadata is None
cached = self._status(True)
return cached if cached in self.TERMINAL_STATES else self._status(metadata_absent)

def result(self) -> Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]:
"""
Get the quantum task result by polling Amazon Braket to see if the task is completed.
Expand Down Expand Up @@ -293,7 +301,8 @@ def _get_future(self):
self._future.done()
and not self._future.cancelled()
and self._result is None
and self._status() not in self.NO_RESULT_TERMINAL_STATES # timed out and no result
# timed out and no result
and self._update_status_if_nonterminal() not in self.NO_RESULT_TERMINAL_STATES
):
self._future = asyncio.get_event_loop().run_until_complete(self._create_future())
return self._future
Expand Down Expand Up @@ -349,10 +358,8 @@ async def _wait_for_completion(
)
continue
# Used cached metadata if cached status is terminal
current_metadata = self.metadata(
self._status(False) not in AwsQuantumTask.TERMINAL_STATES
)
task_status = self._status(False)
task_status = self._update_status_if_nonterminal()
current_metadata = self.metadata(True)
self._logger.debug(f"Task {self._arn}: task status {task_status}")
if task_status in AwsQuantumTask.RESULTS_READY_STATES:
result_string = self._aws_session.retrieve_s3_object_body(
Expand Down
17 changes: 15 additions & 2 deletions test/unit_tests/braket/aws/test_aws_quantum_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def test_is_polling_time(window, cur_time, expected_val, quantum_task):


def test_result_not_polling(quantum_task):
quantum_task._metadata = {"a": 0}
quantum_task._poll_outside_execution_window = False
quantum_task._poll_timeout_seconds = 0.01
window = {
Expand All @@ -258,6 +259,13 @@ def test_metadata(quantum_task):
assert quantum_task.metadata(use_cached_value=True) == metadata_1


def test_metadata_call_if_none(quantum_task):
metadata_1 = {"status": "RUNNING"}
quantum_task._aws_session.get_quantum_task.return_value = metadata_1
assert quantum_task.metadata(use_cached_value=True) == metadata_1
quantum_task._aws_session.get_quantum_task.assert_called_with(quantum_task.id)


def test_state(quantum_task):
state_1 = "RUNNING"
_mock_metadata(quantum_task._aws_session, state_1)
Expand Down Expand Up @@ -421,6 +429,12 @@ def test_timeout_completed(aws_session):
assert quantum_task.result() == GateModelQuantumTaskResult.from_string(
MockS3.MOCK_S3_RESULT_GATE_MODEL
)
# Cached status is still COMPLETED, so result should be fetched
_mock_metadata(aws_session, "RUNNING")
quantum_task._result = None
assert quantum_task.result() == GateModelQuantumTaskResult.from_string(
MockS3.MOCK_S3_RESULT_GATE_MODEL
)


def test_timeout_no_result_terminal_state(aws_session):
Expand Down Expand Up @@ -632,12 +646,11 @@ def _assert_create_quantum_task_called_with(


def _mock_metadata(aws_session, state):
return_value = {
aws_session.get_quantum_task.return_value = {
"status": state,
"outputS3Bucket": S3_TARGET.bucket,
"outputS3Directory": S3_TARGET.key,
}
aws_session.get_quantum_task.return_value = return_value


def _mock_s3(aws_session, result):
Expand Down