Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: job helper functions #720

Merged
merged 5 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
from braket.aws import AwsDevice, AwsQuantumJob
from braket.circuits import Circuit
from braket.devices import Devices
from braket.jobs import save_job_result
from braket.jobs import get_job_device_arn, save_job_result


def run_job():
device = AwsDevice(os.environ.get("AMZN_BRAKET_DEVICE_ARN"))
device = AwsDevice(get_job_device_arn())

bell = Circuit().h(0).cnot(0, 1)
num_tasks = 10
Expand Down
8 changes: 8 additions & 0 deletions src/braket/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,12 @@
save_job_checkpoint,
save_job_result,
)
from braket.jobs.environment_variables import ( # noqa: F401
get_checkpoint_dir,
get_hyperparameters,
get_input_data_dir,
get_job_device_arn,
get_job_name,
get_results_dir,
)
from braket.jobs.image_uris import Framework, retrieve_image # noqa: F401
10 changes: 5 additions & 5 deletions src/braket/jobs/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import os
from typing import Any, Dict

from braket.jobs.environment_variables import get_checkpoint_dir, get_job_name, get_results_dir
from braket.jobs.serialization import deserialize_values, serialize_values
from braket.jobs_data import PersistedJobData, PersistedJobDataFormat

Expand Down Expand Up @@ -49,8 +49,8 @@ def save_job_checkpoint(
"""
if not checkpoint_data:
raise ValueError("The checkpoint_data argument cannot be empty.")
checkpoint_directory = os.environ["AMZN_BRAKET_CHECKPOINT_DIR"]
job_name = os.environ["AMZN_BRAKET_JOB_NAME"]
checkpoint_directory = get_checkpoint_dir()
job_name = get_job_name()
checkpoint_file_path = (
f"{checkpoint_directory}/{job_name}_{checkpoint_file_suffix}.json"
if checkpoint_file_suffix
Expand Down Expand Up @@ -90,7 +90,7 @@ def load_job_checkpoint(job_name: str, checkpoint_file_suffix: str = "") -> Dict
ValueError: If the data stored in the checkpoint file can't be deserialized (possibly due to
corruption).
"""
checkpoint_directory = os.environ["AMZN_BRAKET_CHECKPOINT_DIR"]
checkpoint_directory = get_checkpoint_dir()
checkpoint_file_path = (
f"{checkpoint_directory}/{job_name}_{checkpoint_file_suffix}.json"
if checkpoint_file_suffix
Expand Down Expand Up @@ -128,7 +128,7 @@ def save_job_result(
"""
if not result_data:
raise ValueError("The result_data argument cannot be empty.")
result_directory = os.environ["AMZN_BRAKET_JOB_RESULTS_DIR"]
result_directory = get_results_dir()
result_path = f"{result_directory}/results.json"
with open(result_path, "w") as f:
serialized_data = serialize_values(result_data or {}, data_format)
Expand Down
71 changes: 71 additions & 0 deletions src/braket/jobs/environment_variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import json
import os
from typing import Dict


def get_job_name() -> str:
"""
Get the name of the current job.

Returns:
str: The name of the job if in a job, else an empty string.
"""
return os.getenv("AMZN_BRAKET_JOB_NAME", "")


def get_job_device_arn() -> str:
"""
Get the device ARN of the current job. If not in a job, default to "local:none/none".

Returns:
str: The device ARN of the current job or "local:none/none".
"""
return os.getenv("AMZN_BRAKET_DEVICE_ARN", "local:none/none")


def get_input_data_dir(channel: str = "input") -> str:
"""
Get the job input data directory.

Args:
channel (str): The name of the input channel. Default value
corresponds to the default input channel name, `input`.

Returns:
str: The input directory, defaulting to current working directory.
"""
input_dir = os.getenv("AMZN_BRAKET_INPUT_DIR", ".")
if input_dir != ".":
return f"{input_dir}/{channel}"
return input_dir


def get_results_dir() -> str:
"""
Get the job result directory.

Returns:
str: The results directory, defaulting to current working directory.
"""
return os.getenv("AMZN_BRAKET_JOB_RESULTS_DIR", ".")


def get_checkpoint_dir() -> str:
"""
Get the job checkpoint directory.
Returns:
str: The checkpoint directory, defaulting to current working directory.
"""
return os.getenv("AMZN_BRAKET_CHECKPOINT_DIR", ".")


def get_hyperparameters() -> Dict[str, str]:
"""
Get the job checkpoint directory.
Returns:
str: The checkpoint directory, defaulting to current working directory.
"""
if "AMZN_BRAKET_HP_FILE" in os.environ:
with open(os.getenv("AMZN_BRAKET_HP_FILE"), "r") as f:
return json.load(f)
return {}
16 changes: 8 additions & 8 deletions test/integ_tests/job_test_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import json
import os

from braket.aws import AwsDevice
from braket.circuits import Circuit
from braket.jobs import save_job_checkpoint, save_job_result
from braket.jobs import (
get_hyperparameters,
get_job_device_arn,
save_job_checkpoint,
save_job_result,
)
from braket.jobs_data import PersistedJobDataFormat


def start_here():
hp_file = os.environ["AMZN_BRAKET_HP_FILE"]
with open(hp_file, "r") as f:
hyperparameters = json.load(f)
hyperparameters = get_hyperparameters()

if hyperparameters["test_case"] == "completed":
completed_job_script()
Expand All @@ -40,7 +40,7 @@ def completed_job_script():
print("Test job started!!!!!")

# Use the device declared in the Orchestration Script
device = AwsDevice(os.environ["AMZN_BRAKET_DEVICE_ARN"])
device = AwsDevice(get_job_device_arn())

bell = Circuit().h(0).cnot(0, 1)
for count in range(5):
Expand Down
64 changes: 64 additions & 0 deletions test/unit_tests/braket/jobs/test_environment_variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import json
import os
import tempfile
from unittest.mock import patch

from braket.jobs import (
get_checkpoint_dir,
get_hyperparameters,
get_input_data_dir,
get_job_device_arn,
get_job_name,
get_results_dir,
)


def test_job_name():
assert get_job_name() == ""
job_name = "my_job_name"
with patch.dict(os.environ, {"AMZN_BRAKET_JOB_NAME": job_name}):
assert get_job_name() == job_name


def test_job_device_arn():
assert get_job_device_arn() == "local:none/none"
device_arn = "my_device_arn"
with patch.dict(os.environ, {"AMZN_BRAKET_DEVICE_ARN": device_arn}):
assert get_job_device_arn() == device_arn


def test_input_data_dir():
assert get_input_data_dir() == "."
input_path = "my/input/path"
with patch.dict(os.environ, {"AMZN_BRAKET_INPUT_DIR": input_path}):
assert get_input_data_dir() == f"{input_path}/input"
channel_name = "my_channel"
assert get_input_data_dir(channel_name) == f"{input_path}/{channel_name}"


def test_results_dir():
assert get_results_dir() == "."
results_dir = "my_results_dir"
with patch.dict(os.environ, {"AMZN_BRAKET_JOB_RESULTS_DIR": results_dir}):
assert get_results_dir() == results_dir


def test_checkpoint_dir():
assert get_checkpoint_dir() == "."
checkpoint_dir = "my_checkpoint_dir"
with patch.dict(os.environ, {"AMZN_BRAKET_CHECKPOINT_DIR": checkpoint_dir}):
assert get_checkpoint_dir() == checkpoint_dir


def test_hyperparameters():
assert get_hyperparameters() == {}
hyperparameters = {
"a": "a_val",
"b": 2,
}
with tempfile.NamedTemporaryFile(mode="w+") as temp, patch.dict(
os.environ, {"AMZN_BRAKET_HP_FILE": temp.name}
):
json.dump(hyperparameters, temp)
temp.seek(0)
assert get_hyperparameters() == hyperparameters
Loading