Skip to content

Commit

Permalink
feat: wrap non-dict results and update results on subsequent calls (#721)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajberdy authored Oct 4, 2023
1 parent 82ced09 commit 5c1e4e7
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 30 deletions.
13 changes: 2 additions & 11 deletions src/braket/aws/aws_quantum_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,14 @@
S3DataSourceConfig,
StoppingCondition,
)
from braket.jobs.data_persistence import load_job_result
from braket.jobs.metrics_data.cwl_insights_metrics_fetcher import CwlInsightsMetricsFetcher

# TODO: Have added metric file in metrics folder, but have to decide on the name for keep
# for the files, since all those metrics are retrieved from the CW.
from braket.jobs.metrics_data.definitions import MetricStatistic, MetricType
from braket.jobs.quantum_job import QuantumJob
from braket.jobs.quantum_job_creation import prepare_quantum_job
from braket.jobs.serialization import deserialize_values
from braket.jobs_data import PersistedJobData


class AwsQuantumJob(QuantumJob):
Expand Down Expand Up @@ -482,15 +481,7 @@ def result(

@staticmethod
def _read_and_deserialize_results(temp_dir: str, job_name: str) -> Dict[str, Any]:
try:
with open(f"{temp_dir}/{job_name}/{AwsQuantumJob.RESULTS_FILENAME}", "r") as f:
persisted_data = PersistedJobData.parse_raw(f.read())
deserialized_data = deserialize_values(
persisted_data.dataDictionary, persisted_data.dataFormat
)
return deserialized_data
except FileNotFoundError:
return {}
return load_job_result(Path(temp_dir, job_name, AwsQuantumJob.RESULTS_FILENAME))

def download_result(
self,
Expand Down
80 changes: 65 additions & 15 deletions src/braket/jobs/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from typing import Any, Dict
from pathlib import Path
from typing import Any, Dict, Union

from braket.jobs.environment_variables import get_checkpoint_dir, get_job_name, get_results_dir
from braket.jobs.serialization import deserialize_values, serialize_values
Expand Down Expand Up @@ -104,33 +104,83 @@ def load_job_checkpoint(job_name: str, checkpoint_file_suffix: str = "") -> Dict
return deserialized_data


def _load_persisted_data(filename: Union[str, Path, None] = None) -> PersistedJobData:
    """Read previously persisted job data from `filename`.

    Args:
        filename (Union[str, Path, None]): Location of the persisted results file.
            Defaults to `results.json` in the job results directory.

    Returns:
        PersistedJobData: The parsed persisted payload, or an empty PLAINTEXT
        payload if the file does not exist yet (e.g. first save in a fresh job).
    """
    filename = filename or Path(get_results_dir()) / "results.json"
    try:
        with open(filename, mode="r") as f:
            return PersistedJobData.parse_raw(f.read())
    except FileNotFoundError:
        # Nothing persisted yet: behave as if an empty plaintext payload exists.
        return PersistedJobData(
            dataDictionary={},
            dataFormat=PersistedJobDataFormat.PLAINTEXT,
        )


def load_job_result(filename: Union[str, Path, None] = None) -> Dict[str, Any]:
    """Load the job result of the currently running job.

    Args:
        filename (Union[str, Path, None]): Location of job results. Default
            `results.json` in the job results directory in a job instance or in
            the working directory locally. This file must be in the format used
            by `save_job_result`.

    Returns:
        Dict[str, Any]: Job result data of the current job. Empty dict if no
        results have been saved yet.
    """
    persisted_data = _load_persisted_data(filename)
    # Values may be base64-encoded pickles; deserialize according to the
    # recorded data format.
    return deserialize_values(persisted_data.dataDictionary, persisted_data.dataFormat)


def save_job_result(
    result_data: Union[Dict[str, Any], Any],
    data_format: Union[PersistedJobDataFormat, None] = None,
) -> None:
    """
    Saves the `result_data` to the local output directory that is specified by the container
    environment variable `AMZN_BRAKET_JOB_RESULTS_DIR`, with the filename 'results.json'.
    The `result_data` values are serialized to the specified `data_format`.

    On subsequent calls, new result data is merged into the previously saved
    results (new keys win).

    Note: This function for storing the results is only for use inside the job container
    as it writes data to directories and references env variables set in the containers.

    Args:
        result_data (Union[Dict[str, Any], Any]): Dict that specifies the result data to be
            persisted. If result data is not a dict, then it will be wrapped as
            `{"result": result_data}`.
        data_format (Union[PersistedJobDataFormat, None]): The data format used to serialize
            the values. Note that for `PICKLED` data formats, the values are base64 encoded
            after serialization. Default: PersistedJobDataFormat.PLAINTEXT, unless
            the existing results are already pickled, in which case pickled format
            is maintained.

    Raises:
        TypeError: If existing results were saved in a pickled format and
            `data_format` is explicitly PLAINTEXT (pickled data cannot be
            downgraded to plaintext).
    """
    if not isinstance(result_data, dict):
        # Wrap scalar/arbitrary results so the persisted payload is always a dict.
        result_data = {"result": result_data}

    current_persisted_data = _load_persisted_data()

    if current_persisted_data.dataFormat == PersistedJobDataFormat.PICKLED_V4:
        # if results are already pickled, maintain pickled format
        # if user explicitly specifies plaintext, raise error
        if data_format == PersistedJobDataFormat.PLAINTEXT:
            raise TypeError(
                "Cannot update results object serialized with "
                f"{current_persisted_data.dataFormat.value} using data format "
                f"{data_format.value}."
            )

        data_format = PersistedJobDataFormat.PICKLED_V4

    # if not specified or already pickled, default to plaintext
    data_format = data_format or PersistedJobDataFormat.PLAINTEXT

    current_results = deserialize_values(
        current_persisted_data.dataDictionary,
        current_persisted_data.dataFormat,
    )
    # Merge: keys from the new call overwrite previously saved keys.
    updated_results = {**current_results, **result_data}

    with open(Path(get_results_dir()) / "results.json", "w") as f:
        serialized_data = serialize_values(updated_results, data_format)
        persisted_data = PersistedJobData(dataDictionary=serialized_data, dataFormat=data_format)
        f.write(persisted_data.json())
72 changes: 68 additions & 4 deletions test/unit_tests/braket/jobs/test_data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
import numpy as np
import pytest

from braket.jobs.data_persistence import load_job_checkpoint, save_job_checkpoint, save_job_result
from braket.jobs.data_persistence import (
load_job_checkpoint,
load_job_result,
save_job_checkpoint,
save_job_result,
)
from braket.jobs_data import PersistedJobDataFormat


Expand Down Expand Up @@ -266,9 +271,68 @@ def test_save_job_result(data_format, result_data, expected_saved_data):
assert expected_file.read() == expected_saved_data


@pytest.mark.parametrize("result_data", [{}, None])
def test_save_job_result_does_not_raise_error_empty_data(result_data):
    # Empty dict and None are both accepted: the former ValueError for empty
    # result data was removed; None is wrapped as {"result": None}.
    with tempfile.TemporaryDirectory() as tmp_dir:
        with patch.dict(os.environ, {"AMZN_BRAKET_JOB_RESULTS_DIR": tmp_dir}):
            save_job_result(result_data)


@pytest.mark.parametrize(
    "first_result_data, first_data_format, second_result_data, second_data_format,"
    " expected_result_data",
    [
        (
            "hello",
            PersistedJobDataFormat.PLAINTEXT,
            "goodbye",
            PersistedJobDataFormat.PLAINTEXT,
            {"result": "goodbye"},
        ),
        (
            "hello",
            PersistedJobDataFormat.PLAINTEXT,
            "goodbye",
            PersistedJobDataFormat.PICKLED_V4,
            {"result": "goodbye"},
        ),
        (
            "hello",
            PersistedJobDataFormat.PICKLED_V4,
            "goodbye",
            None,
            {"result": "goodbye"},
        ),
        (
            # not json serializable
            PersistedJobDataFormat,
            PersistedJobDataFormat.PICKLED_V4,
            {"other_field": "value"},
            None,
            {"result": PersistedJobDataFormat, "other_field": "value"},
        ),
    ],
)
def test_update_result_data(
    first_result_data,
    first_data_format,
    second_result_data,
    second_data_format,
    expected_result_data,
):
    """Saving twice merges the second payload into the first before reloading."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        with patch.dict(os.environ, {"AMZN_BRAKET_JOB_RESULTS_DIR": tmp_dir}):
            save_job_result(first_result_data, first_data_format)
            save_job_result(second_result_data, second_data_format)

            assert load_job_result() == expected_result_data


def test_update_pickled_results_as_plaintext_error():
    """Updating pickled results with an explicit plaintext format raises TypeError."""
    with tempfile.TemporaryDirectory() as tmp_dir, patch.dict(
        os.environ, {"AMZN_BRAKET_JOB_RESULTS_DIR": tmp_dir}
    ):
        save_job_result(np.arange(5), PersistedJobDataFormat.PICKLED_V4)

        expected_message = (
            "Cannot update results object serialized with "
            "pickled_v4 using data format plaintext."
        )
        with pytest.raises(TypeError, match=expected_message):
            save_job_result("hello", PersistedJobDataFormat.PLAINTEXT)

0 comments on commit 5c1e4e7

Please sign in to comment.