diff --git a/google/cloud/aiplatform/utils/source_utils.py b/google/cloud/aiplatform/utils/source_utils.py
index 2a13daf452..dc3c14a759 100644
--- a/google/cloud/aiplatform/utils/source_utils.py
+++ b/google/cloud/aiplatform/utils/source_utils.py
@@ -171,7 +171,13 @@ def make_package(self, package_directory: str) -> str:
             fp.write(setup_py_output)
 
         if os.path.isdir(self.script_path):
-            shutil.copytree(self.script_path, trainer_path, dirs_exist_ok=True)
+            # copytree() requires that dst does not exist (dirs_exist_ok is only
+            # available on Python 3.8+), so clear any existing destination first.
+            if os.path.isdir(trainer_path):
+                shutil.rmtree(trainer_path)
+
+            # Copy folder recursively
+            shutil.copytree(src=self.script_path, dst=trainer_path)
         else:
             # The module that will contain the script
             script_out_path = trainer_path / f"{self.task_module_name}.py"
diff --git a/tests/unit/aiplatform/test_training_utils.py b/tests/unit/aiplatform/test_training_utils.py
index 99d49b7ead..24a60d95ff 100644
--- a/tests/unit/aiplatform/test_training_utils.py
+++ b/tests/unit/aiplatform/test_training_utils.py
@@ -16,11 +16,14 @@
 #
 
 from importlib import reload
+import filecmp
 import json
 import os
 import pytest
+import tempfile
 
 from google.cloud.aiplatform.training_utils import environment_variables
+from google.cloud.aiplatform.utils import source_utils
 from unittest import mock
 
 _TEST_TRAINING_DATA_URI = "gs://training-data-uri"
@@ -203,3 +206,77 @@ def test_http_handler_port(self):
     def test_http_handler_port_none(self):
         reload(environment_variables)
         assert environment_variables.http_handler_port is None
+
+    @pytest.fixture()
+    def mock_temp_file_name(self):
+        # Create a throwaway script file. The enclosing temporary directory
+        # (and the file inside it) is removed on fixture teardown.
+        with tempfile.TemporaryDirectory() as directory_name:
+            file_name = os.path.join(directory_name, "script.py")
+
+            with open(file_name, "w") as handle:
+                handle.write("test")
+
+            yield file_name
+
+    def test_package_file(self, mock_temp_file_name):
+        # Test that the packager properly copies the source file to the destination file
+
+        packager = source_utils._TrainingScriptPythonPackager(
+            script_path=mock_temp_file_name
+        )
+
+        with tempfile.TemporaryDirectory() as destination_directory_name:
+            _ = packager.make_package(package_directory=destination_directory_name)
+
+            # The script is expected to land in the trainer package as the task module
+            destination_inner_path = os.path.join(
+                destination_directory_name,
+                packager._TRAINER_FOLDER,
+                packager._ROOT_MODULE,
+                f"{packager.task_module_name}.py",
+            )
+
+            assert filecmp.cmp(
+                mock_temp_file_name, destination_inner_path, shallow=False
+            )
+
+    @pytest.fixture()
+    def mock_temp_folder_name(self):
+        # Create a throwaway folder containing a single file. Both are removed
+        # on fixture teardown.
+        with tempfile.TemporaryDirectory() as folder_name:
+            with open(os.path.join(folder_name, "script.py"), "w") as handle:
+                handle.write("test")
+
+            yield folder_name
+
+    def test_package_folder(self, mock_temp_folder_name):
+        # Test that the packager properly copies the source folder to the destination folder
+
+        packager = source_utils._TrainingScriptPythonPackager(
+            script_path=mock_temp_folder_name
+        )
+
+        with tempfile.TemporaryDirectory() as destination_directory_name:
+            # Pre-populate the destination directory so packaging runs against a
+            # non-empty location
+            existing_file_name = os.path.join(destination_directory_name, "existing")
+
+            with open(existing_file_name, "w") as handle:
+                handle.write("existing")
+
+            _ = packager.make_package(package_directory=destination_directory_name)
+
+            # Check that the packaged module folder mirrors the source folder
+            destination_inner_path = os.path.join(
+                destination_directory_name,
+                packager._TRAINER_FOLDER,
+                packager._ROOT_MODULE,
+            )
+
+            dcmp = filecmp.dircmp(mock_temp_folder_name, destination_inner_path)
+
+            assert not dcmp.diff_files
+            assert not dcmp.left_only
+            assert not dcmp.right_only