Skip to content

Commit

Permalink
test: add full coverage (#731)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajberdy authored Oct 9, 2023
1 parent a7e799c commit c9fd192
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
2 changes: 0 additions & 2 deletions src/braket/jobs/hybrid_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,6 @@ def _process_input_data(input_data):
else (multiple matches, possibly including exact):
cwd/prefix_match -> channel/prefix_match, for each match
"""
from braket.aws import AwsSession

input_data = input_data or {}
if not isinstance(input_data, dict):
input_data = {"input": input_data}
Expand Down
53 changes: 53 additions & 0 deletions test/unit_tests/braket/jobs/test_hybrid_job.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import ast
import importlib
import re
import tempfile
from logging import getLogger
from pathlib import Path
Expand All @@ -7,11 +9,13 @@

import job_module
import pytest
from cloudpickle import cloudpickle

from braket.aws import AwsQuantumJob
from braket.devices import Devices
from braket.jobs import hybrid_job
from braket.jobs.config import CheckpointConfig, InstanceConfig, OutputDataConfig, StoppingCondition
from braket.jobs.hybrid_job import _serialize_entry_point
from braket.jobs.local import LocalQuantumJob


Expand Down Expand Up @@ -158,6 +162,41 @@ def my_entry(a, b: int, c=0, d: float = 1.0, **extras) -> str:
mock_stdout.write.assert_any_call(s3_not_linked)


@patch("time.time", return_value=123.0)
@patch("builtins.open")
@patch("tempfile.TemporaryDirectory")
@patch.object(AwsQuantumJob, "create")
def test_decorator_non_dict_input(mock_create, mock_tempdir, _mock_open, mock_time):
input_prefix = "my_input"

@hybrid_job(device=None, input_data=input_prefix)
def my_entry():
return "my entry return value"

mock_tempdir_name = "job_temp_dir_00000"
mock_tempdir.return_value.__enter__.return_value = mock_tempdir_name

source_module = mock_tempdir_name
entry_point = f"{mock_tempdir_name}.entry_point:my_entry"
wait_until_complete = False

device = "local:none/none"

my_entry()

mock_create.assert_called_with(
device=device,
source_module=source_module,
entry_point=entry_point,
wait_until_complete=wait_until_complete,
job_name="my-entry-123000",
hyperparameters={},
logger=getLogger("braket.jobs.hybrid_job"),
input_data=input_prefix,
)
assert mock_tempdir.return_value.__exit__.called


@patch("time.time", return_value=123.0)
@patch("builtins.open")
@patch("tempfile.TemporaryDirectory")
Expand Down Expand Up @@ -340,3 +379,17 @@ def fails_serialization():
)
with pytest.raises(RuntimeError, match=serialization_failed):
fails_serialization()


def test_serialization_wrapping():
def my_entry(*args, **kwargs):
print("something with \" and ' and \n")
return args, kwargs

args, kwargs = (1, "two"), {"three": 3}
template = _serialize_entry_point(my_entry, args, kwargs)
pickled_str = re.search(r"(?s)cloudpickle.loads\((.*?)\)\ndef my_entry", template).group(1)
byte_str = ast.literal_eval(pickled_str)

recovered = cloudpickle.loads(byte_str)
assert recovered() == (args, kwargs)

0 comments on commit c9fd192

Please sign in to comment.