Skip to content

Commit

Permalink
fix: return empty_dict for job.result() when results not saved (#82)
Browse files Browse the repository at this point in the history
* fix: return empty_dict when job results not saved

* Handle non-404 errors

* check download_from_s3 instead of download_result

* Add check if results not available in json format in tar

* Return empty_dict if result.json not found

* address comment
  • Loading branch information
virajvchaudhari authored Sep 29, 2021
1 parent a5f3406 commit 10beeb3
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 1 deletion.
17 changes: 16 additions & 1 deletion src/braket/aws/aws_quantum_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from typing import Any, Dict, List

import boto3
from botocore.exceptions import ClientError

from braket.aws.aws_session import AwsSession
from braket.jobs import logs
Expand Down Expand Up @@ -495,13 +496,27 @@ def result(

with tempfile.TemporaryDirectory() as temp_dir:
job_name = self.metadata(True)["jobName"]
self.download_result(temp_dir, poll_timeout_seconds, poll_interval_seconds)

try:
self.download_result(temp_dir, poll_timeout_seconds, poll_interval_seconds)
except ClientError as e:
if e.response["Error"]["Code"] == "404":
return {}
else:
raise e
return AwsQuantumJob._read_and_deserialize_results(temp_dir, job_name)

@staticmethod
def _read_and_deserialize_results(temp_dir, job_name):
try:
with open(f"{temp_dir}/{job_name}/{AwsQuantumJob.RESULTS_FILENAME}", "r") as f:
persisted_data = PersistedJobData.parse_raw(f.read())
deserialized_data = deserialize_values(
persisted_data.dataDictionary, persisted_data.dataFormat
)
return deserialized_data
except FileNotFoundError:
return {}

def download_result(
self,
Expand Down
51 changes: 51 additions & 0 deletions test/unit_tests/braket/aws/test_aws_quantum_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,57 @@ def test_download_result_when_extract_path_provided(
assert expected_saved_data == actual_data


def test_empty_dict_returned_when_result_not_saved(
quantum_job, generate_get_job_response, aws_session
):
state = "COMPLETED"
get_job_response_completed = generate_get_job_response(status=state)
aws_session.get_job.return_value = get_job_response_completed

exception_response = {
"Error": {
"Code": "404",
"Message": "Not Found",
}
}
quantum_job._aws_session.download_from_s3 = Mock(
side_effect=ClientError(exception_response, "HeadObject")
)
assert quantum_job.result() == {}


def test_results_raises_error_for_non_404_errors(
quantum_job, generate_get_job_response, aws_session
):
state = "COMPLETED"
get_job_response_completed = generate_get_job_response(status=state)
aws_session.get_job.return_value = get_job_response_completed

error = "An error occurred \\(402\\) when calling the SomeObject operation: Something"

exception_response = {
"Error": {
"Code": "402",
"Message": "Something",
}
}
quantum_job._aws_session.download_from_s3 = Mock(
side_effect=ClientError(exception_response, "SomeObject")
)
with pytest.raises(ClientError, match=error):
quantum_job.result()


@patch("braket.aws.aws_quantum_job.AwsQuantumJob.download_result")
def test_results_json_file_not_in_tar(
result_download, quantum_job, aws_session, generate_get_job_response
):
state = "COMPLETED"
get_job_response_completed = generate_get_job_response(status=state)
quantum_job._aws_session.get_job.return_value = get_job_response_completed
assert quantum_job.result() == {}


@pytest.fixture
def entry_point():
return "test-source-dir.entry_point:func"
Expand Down

0 comments on commit 10beeb3

Please sign in to comment.