Skip to content

Commit

Permalink
Added unit tests and fixed bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanmkc committed May 23, 2022
1 parent 89f7490 commit dad3781
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 3 deletions.
3 changes: 0 additions & 3 deletions google/cloud/aiplatform/utils/source_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,6 @@ def make_package(self, package_directory: str) -> str:
# Remove destination path if it already exists
shutil.rmtree(trainer_path)

# Create destination path
os.makedirs(trainer_path)

# Copy folder recursively
shutil.copytree(src=self.script_path, dst=trainer_path)
else:
Expand Down
78 changes: 78 additions & 0 deletions tests/unit/aiplatform/test_training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@
#

from importlib import reload
import filecmp
import glob
import json
import os
import pytest
import tempfile

from google.cloud.aiplatform.training_utils import environment_variables
from google.cloud.aiplatform.utils import source_utils
from unittest import mock

_TEST_TRAINING_DATA_URI = "gs://training-data-uri"
Expand Down Expand Up @@ -203,3 +207,77 @@ def test_http_handler_port(self):
def test_http_handler_port_none(self):
reload(environment_variables)
assert environment_variables.http_handler_port is None

@pytest.fixture()
def mock_temp_file_name(self):
# Create random files
# tmpdirname = tempfile.TemporaryDirectory()
file = tempfile.NamedTemporaryFile()

with open(file.name, "w") as handle:
handle.write("test")

yield file.name

file.close()

def test_package_file(self, mock_temp_file_name):
# Test that the packager properly copies the source file to the destination file

packager = source_utils._TrainingScriptPythonPackager(
script_path=mock_temp_file_name
)

with tempfile.TemporaryDirectory() as destination_directory_name:
_ = packager.make_package(package_directory=destination_directory_name)

# Check that contents of source_distribution_path is the same as destination_directory_name
destination_inner_path = f"{destination_directory_name}/{packager._TRAINER_FOLDER}/{packager._ROOT_MODULE}/{packager.task_module_name}.py"

assert filecmp.cmp(
mock_temp_file_name, destination_inner_path, shallow=False
)

@pytest.fixture()
def mock_temp_folder_name(self):
# Create random folder
folder = tempfile.TemporaryDirectory()

file = tempfile.NamedTemporaryFile(dir=folder.name)

# Create random file in the folder
with open(file.name, "w") as handle:
handle.write("test")

yield folder.name

file.close()

folder.cleanup()

def test_package_folder(self, mock_temp_folder_name):
# Test that the packager properly copies the source folder to the destination folder

packager = source_utils._TrainingScriptPythonPackager(
script_path=mock_temp_folder_name
)

with tempfile.TemporaryDirectory() as destination_directory_name:
# Add an existing file into the destination directory to check if it gets deleted
existing_file = tempfile.NamedTemporaryFile(dir=destination_directory_name)

with open(existing_file.name, "w") as handle:
handle.write("existing")

_ = packager.make_package(package_directory=destination_directory_name)

# Check that contents of source_distribution_path is the same as destination_directory_name
destination_inner_path = f"{destination_directory_name}/{packager._TRAINER_FOLDER}/{packager._ROOT_MODULE}"

dcmp = filecmp.dircmp(mock_temp_folder_name, destination_inner_path)

assert len(dcmp.diff_files) == 0
assert len(dcmp.left_only) == 0
assert len(dcmp.right_only) == 0

existing_file.close()

0 comments on commit dad3781

Please sign in to comment.