Skip to content

Commit

Permalink
feature: process source_dir (#30)
Browse files Browse the repository at this point in the history
* feature: process source_dir

* Fixing windows issues

* Change test_name, windows fix try

* Formatting

* Error formatting, Entrypoint absolute condition check

* Update docstring

* Format

* Refactor, remove unnecessary code

* Validate upload call
  • Loading branch information
virajvchaudhari authored Aug 24, 2021
1 parent 611bddc commit 31de7c0
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 32 deletions.
89 changes: 64 additions & 25 deletions src/braket/aws/aws_quantum_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def create(
aws_session: AwsSession,
entry_point: str,
device_arn: str,
source_dir: str,
source_dir: str = None,
# TODO: Replace with the correct default image name.
# This image_uri will be retrieved from `image_uris.retreive()`, which will be in a different file
# in the `jobs` folder and the function defined in it.
Expand Down Expand Up @@ -87,15 +87,22 @@ def create(
aws_session (AwsSession): AwsSession to connect to AWS with.
entry_point (str): str specifying the 'module' or 'module:method' to be executed as
an entry point for the job.
an entry point for the job. The entry_point should be specified relative to
the source_dir. If source_dir is not provided then all the required modules
for execution of entry_point should be within the folder containing the
entry_point script.
For example,
if `entry_point = foo.bar.my_script:func`. Then all the required modules
should be present within the `foo` or `bar` folder.
device_arn (str): ARN for the AWS device which will be primarily
accessed for the execution of this job.
source_dir (str): Path (absolute, relative or an S3 URI) to a directory with any
other source code dependencies aside from the entry point file. If `source_dir`
is an S3 URI, it must point to a tar.gz file. The structure within this directory is
preserved when executing on Amazon Braket.
preserved when executing on Amazon Braket. Default = `None`.
image_uri (str): str specifying the ECR image to use for executing the job.
`image_uris.retrieve()` function may be used for retrieving the ECR image uris
Expand Down Expand Up @@ -197,8 +204,9 @@ def create(
"checkpointConfig"
]["s3Uri"]
aws_session.copy_s3_directory(checkpoints_to_copy, checkpoint_config.s3Uri)
AwsQuantumJob._validate_entry_point(entry_point)

AwsQuantumJob._process_source_dir(
entry_point,
source_dir,
aws_session,
code_location,
Expand Down Expand Up @@ -478,27 +486,6 @@ def __eq__(self, other) -> bool:
def __hash__(self) -> int:
return hash(self.arn)

@staticmethod
def _process_source_dir(source_dir, aws_session, code_location):
# TODO: check with product about copy in s3 behavior
if source_dir.startswith("s3://"):
if not source_dir.endswith(".tar.gz"):
raise ValueError(
f"If source_dir is an S3 URI, it must point to a tar.gz file. "
f"Not a valid S3 URI for parameter `source_dir`: {source_dir}"
)
aws_session.copy_s3_object(source_dir, f"{code_location}/source.tar.gz")
else:
with tempfile.TemporaryDirectory() as tmpdir:
try:
with tarfile.open(f"{tmpdir}/source.tar.gz", "w:gz") as tar:
tar.add(source_dir, arcname=os.path.basename(source_dir))
except FileNotFoundError:
raise ValueError(f"Source directory not found: {source_dir}")
aws_session.upload_to_s3(
f"{tmpdir}/source.tar.gz", f"{code_location}/source.tar.gz"
)

@staticmethod
def _validate_entry_point(entry_point):
module, _, function_name = entry_point.partition(":")
Expand All @@ -508,3 +495,55 @@ def _validate_entry_point(entry_point):
raise ValueError(f"Entry point module not found: '{module}'")
if function_name and not hasattr(module, function_name):
raise ValueError(f"Entry function '{function_name}' not found in module '{module}'.")

@staticmethod
def _process_source_dir(entry_point, source_dir, aws_session, code_location):
    """Route staging of job source code to the appropriate handler.

    With a ``source_dir`` the code comes either from an S3 tarball or a
    local directory; without one, the code to upload is inferred from
    ``entry_point`` alone.
    """
    if not source_dir:
        AwsQuantumJob._source_dir_not_provided(aws_session, entry_point, code_location)
    elif source_dir.startswith("s3://"):
        AwsQuantumJob._process_s3_source_dir(aws_session, source_dir, code_location)
    else:
        AwsQuantumJob._process_local_source_dir(
            aws_session, entry_point, source_dir, code_location
        )

@staticmethod
def _process_s3_source_dir(aws_session, source_dir, code_location):
if not source_dir.endswith(".tar.gz"):
raise ValueError(
f"If source_dir is an S3 URI, it must point to a tar.gz file. "
f"Not a valid S3 URI for parameter `source_dir`: {source_dir}"
)
aws_session.copy_s3_object(source_dir, f"{code_location}/source.tar.gz")

@staticmethod
def _process_local_source_dir(aws_session, entry_point, source_dir, code_location):
module, _, func = entry_point.partition(":")
entry_file = f"{module.replace('.', '/')}.py"

if not os.path.abspath(entry_file).startswith(os.path.abspath(source_dir)):
raise ValueError(
f"`Entrypoint`: {entry_point} should be " f"within the `source_dir`: {source_dir}"
)

AwsQuantumJob._validate_entry_point(entry_point)
AwsQuantumJob._tar_and_upload_to_code_location(aws_session, source_dir, code_location)

@staticmethod
def _source_dir_not_provided(aws_session, entry_point, code_location):
    """Upload the entry point's top-level package when no source_dir is given.

    The first component of the module path (e.g. ``foo`` in
    ``foo.bar.script:func``) is treated as the directory to tar and ship.
    """
    AwsQuantumJob._validate_entry_point(entry_point)
    module = entry_point.partition(":")[0]
    upload_dir = module.split(".", 1)[0]

    AwsQuantumJob._tar_and_upload_to_code_location(aws_session, upload_dir, code_location)

@staticmethod
def _tar_and_upload_to_code_location(aws_session, source_dir, code_location):
with tempfile.TemporaryDirectory() as temp_dir:
try:
with tarfile.open(f"{temp_dir}/source.tar.gz", "w:gz") as tar:
tar.add(source_dir, arcname=os.path.basename(source_dir))
except FileNotFoundError:
raise ValueError(f"Source directory not found: {source_dir}")
aws_session.upload_to_s3(f"{temp_dir}/source.tar.gz", f"{code_location}/source.tar.gz")
64 changes: 57 additions & 7 deletions test/unit_tests/braket/aws/test_aws_quantum_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,10 +415,8 @@ def s3_prefix(job_name):
return f"{job_name}/non-default"


@pytest.fixture(params=["local_source", "s3_source", "working_directory"])
@pytest.fixture(params=["local_source", "s3_source"])
def source_dir(request, bucket, s3_prefix):
if request.param == "working_directory":
return "."
if request.param == "local_source":
return "test-source-dir"
elif request.param == "s3_source":
Expand Down Expand Up @@ -601,8 +599,8 @@ def test_create_job(
generate_get_job_response,
):
mock_time.return_value = datetime.datetime.now().timestamp()
with tempfile.TemporaryDirectory() as tempdir:
os.chdir(tempdir)
with tempfile.TemporaryDirectory() as temp_dir:
os.chdir(temp_dir)
os.mkdir("test-source-dir")
if source_dir == ".":
os.chdir("test-source-dir")
Expand Down Expand Up @@ -710,11 +708,10 @@ def test_create_job_source_dir_not_found(
with pytest.raises(ValueError) as e:
AwsQuantumJob.create(
aws_session=aws_session,
entry_point=entry_point,
entry_point="fake-source-dir.test_script:func",
device_arn=device_arn,
source_dir=fake_source_dir,
)

assert str(e.value) == f"Source directory not found: {fake_source_dir}"


Expand All @@ -740,6 +737,59 @@ def test_create_job_source_dir_s3_but_not_tar(
)


def test_source_dir_not_in_entry_point_name(entry_point, aws_session, device_arn):
    """create() must reject an entry_point that lives outside source_dir."""
    source_dir = "other-source-dir"
    original_cwd = os.getcwd()

    with tempfile.TemporaryDirectory() as temp_dir:
        os.chdir(temp_dir)
        try:
            os.mkdir("other-source-dir")
            with pytest.raises(ValueError) as e:
                AwsQuantumJob.create(
                    aws_session=aws_session,
                    entry_point=entry_point,
                    device_arn=device_arn,
                    source_dir=source_dir,
                )
        finally:
            # Restore the working directory even if create() fails in an
            # unexpected way, so TemporaryDirectory cleanup succeeds
            # (notably on Windows) and later tests run from the right cwd.
            os.chdir(original_cwd)

    assert (
        str(e.value)
        == f"`Entrypoint`: {entry_point} should be within the `source_dir`: {source_dir}"
    )


@patch("importlib.import_module")
def test_entry_point_when_source_dir_not_provided(
    mock_import, quantum_job, code_location, aws_session, entry_point
):
    """Without source_dir, the entry point's top package is tarred and uploaded."""
    original_cwd = os.getcwd()
    with tempfile.TemporaryDirectory() as temp_dir:
        os.chdir(temp_dir)
        try:
            module, _, func = entry_point.partition(":")
            dir_data = module.split(".")
            folder_path, file_name = "/".join(dir_data[:-1]), f"{dir_data[-1]}.py"

            os.mkdir(folder_path)

            # Make the folder an importable package holding the entry script.
            with open(f"{folder_path}/__init__.py", "w"):
                pass

            with open(f"{folder_path}/{file_name}", "w") as f:
                f.write("def func(): \n\tpass")

            quantum_job._process_source_dir(
                aws_session=aws_session,
                entry_point=entry_point,
                code_location=code_location,
                source_dir=None,
            )

            args_list = aws_session.upload_to_s3.call_args_list
            assert "source.tar.gz" in args_list[0][0][0] and code_location in args_list[0][0][1]
            aws_session.upload_to_s3.assert_called_once()
        finally:
            # Restore cwd even on assertion failure so the temp dir can be
            # removed and subsequent tests are unaffected.
            os.chdir(original_cwd)


@patch("braket.aws.aws_quantum_job.AwsQuantumJob._validate_entry_point")
def test_copy_checkpoints(
mock_validate_entry_point,
Expand Down

0 comments on commit 31de7c0

Please sign in to comment.