Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support dependency list for decorator hybrid jobs #764

Merged
merged 6 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/twine-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ jobs:
run: python -m pip install --user --upgrade wheel
- name: Install twine
run: python -m pip install --user --upgrade twine
- name: Install setuptools
run: python -m pip install --user --upgrade setuptools
- name: Build a binary wheel and a source tarball
run: python setup.py sdist bdist_wheel
- name: Check that long description will render correctly on PyPI.
Expand Down
28 changes: 20 additions & 8 deletions src/braket/jobs/hybrid_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from logging import Logger, getLogger
from pathlib import Path
from types import ModuleType
from typing import Any, Dict, List
from typing import Any

import cloudpickle

Expand All @@ -47,7 +47,7 @@ def hybrid_job(
*,
device: str,
include_modules: str | ModuleType | Iterable[str | ModuleType] = None,
dependencies: str | Path = None,
dependencies: str | Path | list[str] = None,
local: bool = False,
job_name: str = None,
image_uri: str = None,
Expand Down Expand Up @@ -85,8 +85,10 @@ def hybrid_job(
modules to be included. Any references to members of these modules in the hybrid job
algorithm code will be serialized as part of the algorithm code. Default value `[]`

dependencies (str | Path): Path (absolute or relative) to a requirements.txt
file to be used for the hybrid job.
dependencies (str | Path | list[str]): Path (absolute or relative) to a requirements.txt
file, or alternatively a list of strings, with each string being a `requirement
specifier <https://pip.pypa.io/en/stable/reference/requirement-specifiers/
#requirement-specifiers>`_, to be used for the hybrid job.

local (bool): Whether to use local mode for the hybrid job. Default `False`

Expand Down Expand Up @@ -178,7 +180,7 @@ def job_wrapper(*args, **kwargs) -> Callable:
entry_point_file.write(template)

if dependencies:
shutil.copy(Path(dependencies).resolve(), temp_dir_path / "requirements.txt")
_process_dependencies(dependencies, temp_dir_path)

job_args = {
"device": device or "local:none/none",
Expand Down Expand Up @@ -241,6 +243,16 @@ def _validate_python_version(image_uri: str | None, aws_session: AwsSession | No
)


def _process_dependencies(dependencies: str | Path | list[str], temp_dir: Path) -> None:
if isinstance(dependencies, (str, Path)):
# requirements file
shutil.copy(Path(dependencies).resolve(), temp_dir / "requirements.txt")
else:
# list of packages
with open(temp_dir / "requirements.txt", "w") as f:
f.write("\n".join(dependencies))


class _IncludeModules:
def __init__(self, modules: str | ModuleType | Iterable[str | ModuleType] = None):
modules = modules or []
Expand Down Expand Up @@ -285,7 +297,7 @@ def wrapped_entry_point() -> Any:
)


def _log_hyperparameters(entry_point: Callable, args: tuple, kwargs: dict) -> Dict:
def _log_hyperparameters(entry_point: Callable, args: tuple, kwargs: dict) -> dict:
"""Capture function arguments as hyperparameters"""
signature = inspect.signature(entry_point)
bound_args = signature.bind(*args, **kwargs)
Expand Down Expand Up @@ -330,7 +342,7 @@ def _sanitize(hyperparameter: Any) -> str:
return sanitized


def _process_input_data(input_data: Dict) -> List[str]:
def _process_input_data(input_data: dict) -> list[str]:
"""
Create symlinks to data

Expand All @@ -344,7 +356,7 @@ def _process_input_data(input_data: Dict) -> List[str]:
if not isinstance(input_data, dict):
input_data = {"input": input_data}

def matches(prefix: str) -> List[str]:
def matches(prefix: str) -> list[str]:
return [
str(path) for path in Path(prefix).parent.iterdir() if str(path).startswith(str(prefix))
]
Expand Down
49 changes: 48 additions & 1 deletion test/unit_tests/braket/jobs/test_hybrid_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,53 @@ def my_entry():
assert mock_tempdir.return_value.__exit__.called


@patch.object(sys.modules["braket.jobs.hybrid_job"], "retrieve_image")
@patch("time.time", return_value=123.0)
@patch("builtins.open")
@patch("tempfile.TemporaryDirectory")
@patch.object(AwsQuantumJob, "create")
def test_decorator_list_dependencies(
mock_create, mock_tempdir, _mock_open, mock_time, mock_retrieve, aws_session
):
mock_retrieve.return_value = "00000000.dkr.ecr.us-west-2.amazonaws.com/latest"
dependency_list = ["dep_1", "dep_2", "dep_3"]

@hybrid_job(
device=None,
aws_session=aws_session,
dependencies=dependency_list,
)
def my_entry(c=0, d: float = 1.0, **extras):
return "my entry return value"

mock_tempdir_name = "job_temp_dir_00000"
mock_tempdir.return_value.__enter__.return_value = mock_tempdir_name

source_module = mock_tempdir_name
entry_point = f"{mock_tempdir_name}.entry_point:my_entry"
wait_until_complete = False

device = "local:none/none"

my_entry()

mock_create.assert_called_with(
device=device,
source_module=source_module,
entry_point=entry_point,
wait_until_complete=wait_until_complete,
job_name="my-entry-123000",
hyperparameters={"c": "0", "d": "1.0"},
logger=getLogger("braket.jobs.hybrid_job"),
aws_session=aws_session,
)
assert mock_tempdir.return_value.__exit__.called
_mock_open.assert_called_with(Path(mock_tempdir_name) / "requirements.txt", "w")
_mock_open.return_value.__enter__.return_value.write.assert_called_with(
"\n".join(dependency_list)
)


@patch.object(sys.modules["braket.jobs.hybrid_job"], "retrieve_image")
@patch("time.time", return_value=123.0)
@patch("builtins.open")
Expand Down Expand Up @@ -487,7 +534,7 @@ def my_job():
),
(
"?" * 2600,
f"{'?'*2477}...{'?'*20}",
f"{'?' * 2477}...{'?' * 20}",
),
),
)
Expand Down
Loading