Skip to content

Commit

Permalink
fix: Removed dirs_exist_ok parameter as it's not backwards compatible (
Browse files Browse the repository at this point in the history
…#1170)

* fix: Removed dirs_exist_ok parameter as it's not backwards compatible

* Added unit tests and fixed bug

* Removed unneeded import

Co-authored-by: Rosie Zou <[email protected]>
  • Loading branch information
ivanmkc and rosiezou authored Jun 23, 2022
1 parent 9ef057a commit 50d4129
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 1 deletion.
6 changes: 5 additions & 1 deletion google/cloud/aiplatform/utils/source_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,11 @@ def make_package(self, package_directory: str) -> str:
fp.write(setup_py_output)

if os.path.isdir(self.script_path):
shutil.copytree(self.script_path, trainer_path, dirs_exist_ok=True)
# Remove destination path if it already exists
shutil.rmtree(trainer_path)

# Copy folder recursively
shutil.copytree(src=self.script_path, dst=trainer_path)
else:
# The module that will contain the script
script_out_path = trainer_path / f"{self.task_module_name}.py"
Expand Down
77 changes: 77 additions & 0 deletions tests/unit/aiplatform/test_training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
#

from importlib import reload
import filecmp
import json
import os
import pytest
import tempfile

from google.cloud.aiplatform.training_utils import environment_variables
from google.cloud.aiplatform.utils import source_utils
from unittest import mock

_TEST_TRAINING_DATA_URI = "gs://training-data-uri"
Expand Down Expand Up @@ -203,3 +206,77 @@ def test_http_handler_port(self):
def test_http_handler_port_none(self):
reload(environment_variables)
assert environment_variables.http_handler_port is None

@pytest.fixture()
def mock_temp_file_name(self):
# Create random files
# tmpdirname = tempfile.TemporaryDirectory()
file = tempfile.NamedTemporaryFile()

with open(file.name, "w") as handle:
handle.write("test")

yield file.name

file.close()

def test_package_file(self, mock_temp_file_name):
# Test that the packager properly copies the source file to the destination file

packager = source_utils._TrainingScriptPythonPackager(
script_path=mock_temp_file_name
)

with tempfile.TemporaryDirectory() as destination_directory_name:
_ = packager.make_package(package_directory=destination_directory_name)

# Check that contents of source_distribution_path is the same as destination_directory_name
destination_inner_path = f"{destination_directory_name}/{packager._TRAINER_FOLDER}/{packager._ROOT_MODULE}/{packager.task_module_name}.py"

assert filecmp.cmp(
mock_temp_file_name, destination_inner_path, shallow=False
)

@pytest.fixture()
def mock_temp_folder_name(self):
# Create random folder
folder = tempfile.TemporaryDirectory()

file = tempfile.NamedTemporaryFile(dir=folder.name)

# Create random file in the folder
with open(file.name, "w") as handle:
handle.write("test")

yield folder.name

file.close()

folder.cleanup()

def test_package_folder(self, mock_temp_folder_name):
# Test that the packager properly copies the source folder to the destination folder

packager = source_utils._TrainingScriptPythonPackager(
script_path=mock_temp_folder_name
)

with tempfile.TemporaryDirectory() as destination_directory_name:
# Add an existing file into the destination directory to check if it gets deleted
existing_file = tempfile.NamedTemporaryFile(dir=destination_directory_name)

with open(existing_file.name, "w") as handle:
handle.write("existing")

_ = packager.make_package(package_directory=destination_directory_name)

# Check that contents of source_distribution_path is the same as destination_directory_name
destination_inner_path = f"{destination_directory_name}/{packager._TRAINER_FOLDER}/{packager._ROOT_MODULE}"

dcmp = filecmp.dircmp(mock_temp_folder_name, destination_inner_path)

assert len(dcmp.diff_files) == 0
assert len(dcmp.left_only) == 0
assert len(dcmp.right_only) == 0

existing_file.close()

0 comments on commit 50d4129

Please sign in to comment.