Skip to content

Commit

Permalink
update env variables (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajberdy authored Aug 27, 2021
1 parent 5b06e24 commit 8e48544
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
8 changes: 4 additions & 4 deletions src/braket/jobs/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def save_job_checkpoint(
"""
if not checkpoint_data:
raise ValueError("checkpoint_data can not be empty")
checkpoint_directory = os.environ["CHECKPOINT_DIR"]
job_name = os.environ["JOB_NAME"]
checkpoint_directory = os.environ["AMZN_BRAKET_CHECKPOINT_DIR"]
job_name = os.environ["AMZN_BRAKET_JOB_NAME"]
checkpoint_file_path = (
f"{checkpoint_directory}/{job_name}_{checkpoint_file_suffix}.json"
if checkpoint_file_suffix
Expand Down Expand Up @@ -88,7 +88,7 @@ def load_job_checkpoint(job_name: str, checkpoint_file_suffix: str = "") -> Dict
ValueError: If the data stored in the checkpoint file can't be deserialized (possibly due to
corruption).
"""
checkpoint_directory = os.environ["CHECKPOINT_DIR"]
checkpoint_directory = os.environ["AMZN_BRAKET_CHECKPOINT_DIR"]
checkpoint_file_path = (
f"{checkpoint_directory}/{job_name}_{checkpoint_file_suffix}.json"
if checkpoint_file_suffix
Expand Down Expand Up @@ -125,7 +125,7 @@ def save_job_result(
"""
if not result_data:
raise ValueError("result_data can not be empty")
result_directory = os.environ["OUTPUT_DIR"]
result_directory = os.environ["AMZN_BRAKET_JOB_RESULTS_DIR"]
result_path = f"{result_directory}/results.json"
with open(result_path, "w") as f:
serialized_data = serialize_values(result_data or {}, data_format)
Expand Down
28 changes: 20 additions & 8 deletions test/unit_tests/braket/jobs/test_data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def test_save_job_checkpoint(
job_name, file_suffix, data_format, checkpoint_data, expected_saved_data
):
with tempfile.TemporaryDirectory() as tmp_dir:
with patch.dict(os.environ, {"CHECKPOINT_DIR": tmp_dir, "JOB_NAME": job_name}):
with patch.dict(
os.environ, {"AMZN_BRAKET_CHECKPOINT_DIR": tmp_dir, "AMZN_BRAKET_JOB_NAME": job_name}
):
save_job_checkpoint(checkpoint_data, file_suffix, data_format)

expected_file_location = (
Expand All @@ -85,7 +87,9 @@ def test_save_job_checkpoint(
def test_save_job_checkpoint_raises_error_empty_data(checkpoint_data):
job_name = "foo"
with tempfile.TemporaryDirectory() as tmp_dir:
with patch.dict(os.environ, {"CHECKPOINT_DIR": tmp_dir, "JOB_NAME": job_name}):
with patch.dict(
os.environ, {"AMZN_BRAKET_CHECKPOINT_DIR": tmp_dir, "AMZN_BRAKET_JOB_NAME": job_name}
):
save_job_checkpoint(checkpoint_data)


Expand Down Expand Up @@ -141,7 +145,9 @@ def test_load_job_checkpoint(
with open(file_path, "w") as f:
f.write(saved_data)

with patch.dict(os.environ, {"CHECKPOINT_DIR": tmp_dir, "JOB_NAME": job_name}):
with patch.dict(
os.environ, {"AMZN_BRAKET_CHECKPOINT_DIR": tmp_dir, "AMZN_BRAKET_JOB_NAME": job_name}
):
loaded_data = load_job_checkpoint(job_name, file_suffix)
assert loaded_data == expected_checkpoint_data

Expand All @@ -155,7 +161,9 @@ def test_load_job_checkpoint_raises_error_file_not_exists():
with open(file_path, "w") as _:
pass

with patch.dict(os.environ, {"CHECKPOINT_DIR": tmp_dir, "JOB_NAME": job_name}):
with patch.dict(
os.environ, {"AMZN_BRAKET_CHECKPOINT_DIR": tmp_dir, "AMZN_BRAKET_JOB_NAME": job_name}
):
load_job_checkpoint(job_name, "wrong_suffix")


Expand All @@ -182,7 +190,9 @@ def test_load_job_checkpoint_raises_error_corrupted_data():
)
)

with patch.dict(os.environ, {"CHECKPOINT_DIR": tmp_dir, "JOB_NAME": job_name}):
with patch.dict(
os.environ, {"AMZN_BRAKET_CHECKPOINT_DIR": tmp_dir, "AMZN_BRAKET_JOB_NAME": job_name}
):
load_job_checkpoint(job_name, file_suffix)


Expand All @@ -202,7 +212,9 @@ def test_save_and_load_job_checkpoint():
"none_value": None,
"nested_dict": {"a": {"b": False}},
}
with patch.dict(os.environ, {"CHECKPOINT_DIR": tmp_dir, "JOB_NAME": job_name}):
with patch.dict(
os.environ, {"AMZN_BRAKET_CHECKPOINT_DIR": tmp_dir, "AMZN_BRAKET_JOB_NAME": job_name}
):
save_job_checkpoint(data, data_format=PersistedJobDataFormat.PICKLED_V4)
retrieved = load_job_checkpoint(job_name)
assert retrieved == data
Expand Down Expand Up @@ -246,7 +258,7 @@ def test_save_and_load_job_checkpoint():
)
def test_save_job_result(data_format, result_data, expected_saved_data):
with tempfile.TemporaryDirectory() as tmp_dir:
with patch.dict(os.environ, {"OUTPUT_DIR": tmp_dir}):
with patch.dict(os.environ, {"AMZN_BRAKET_JOB_RESULTS_DIR": tmp_dir}):
save_job_result(result_data, data_format)

expected_file_location = f"{tmp_dir}/results.json"
Expand All @@ -258,5 +270,5 @@ def test_save_job_result(data_format, result_data, expected_saved_data):
@pytest.mark.parametrize("result_data", [{}, None])
def test_save_job_result_raises_error_empty_data(result_data):
with tempfile.TemporaryDirectory() as tmp_dir:
with patch.dict(os.environ, {"CHECKPOINT_DIR": tmp_dir}):
with patch.dict(os.environ, {"AMZN_BRAKET_CHECKPOINT_DIR": tmp_dir}):
save_job_result(result_data)

0 comments on commit 8e48544

Please sign in to comment.