Skip to content

Commit

Permalink
KEP-2170: Add unit and Integration tests for model and dataset initializers (#2323)
Browse files Browse the repository at this point in the history

* KEP-2170: Add unit and integration tests for model and dataset initializers

Signed-off-by: wei-chenglai <[email protected]>

* refactor tests

Signed-off-by: wei-chenglai <[email protected]>

---------

Signed-off-by: wei-chenglai <[email protected]>
seanlaii authored Jan 18, 2025

Verified

This commit was created on GitHub.com and signed with GitHub's verified signature. The key has expired.
1 parent 6d58ea9 commit e47d8f7
Showing 20 changed files with 628 additions and 4 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/integration-tests.yaml
Original file line number Diff line number Diff line change
@@ -86,6 +86,13 @@ jobs:
GANG_SCHEDULER_NAME: ${{ matrix.gang-scheduler-name }}
JAX_JOB_IMAGE: kubeflow/jaxjob-dist-spmd-mnist:test

- name: Run initializer_v2 integration tests for Python 3.11+
if: ${{ matrix.python-version == '3.11' }}
run: |
pip install -r ./cmd/initializer_v2/dataset/requirements.txt
pip install -U './sdk_v2'
pytest ./test/integration/initializer_v2
- name: Collect volcano logs
if: ${{ failure() && matrix.gang-scheduler-name == 'volcano' }}
run: |
11 changes: 10 additions & 1 deletion .github/workflows/test-python.yaml
Original file line number Diff line number Diff line change
@@ -32,4 +32,13 @@ jobs:
pip install -U './sdk/python[huggingface]'
- name: Run unit test for training sdk
run: pytest ./sdk/python/kubeflow/training/api/training_client_test.py
run: |
pytest ./sdk/python/kubeflow/training/api/training_client_test.py
- name: Run Python unit tests for v2
run: |
pip install -U './sdk_v2'
export PYTHONPATH="${{ github.workspace }}:$PYTHONPATH"
pytest ./pkg/initializer_v2/model
pytest ./pkg/initializer_v2/dataset
pytest ./pkg/initializer_v2/utils
2 changes: 1 addition & 1 deletion cmd/initializer_v2/model/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
huggingface_hub==0.23.4
huggingface-hub>=0.27.0,<0.28
23 changes: 23 additions & 0 deletions pkg/initializer_v2/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import os

import pytest


@pytest.fixture
def mock_env_vars():
    """Fixture to set and clean up environment variables"""
    saved_environ = os.environ.copy()

    def _set_env_vars(**overrides):
        for name, val in overrides.items():
            if val is None:
                # A None value means "ensure the variable is absent".
                os.environ.pop(name, None)
            else:
                os.environ[name] = str(val)
        return os.environ

    yield _set_env_vars

    # Restore the environment exactly as it was before the test ran.
    os.environ.clear()
    os.environ.update(saved_environ)
Empty file.
7 changes: 6 additions & 1 deletion pkg/initializer_v2/dataset/__main__.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,8 @@
level=logging.INFO,
)

if __name__ == "__main__":

def main():
logging.info("Starting dataset initialization")

try:
@@ -29,3 +30,7 @@
case _:
logging.error("STORAGE_URI must have the valid dataset provider")
raise Exception


if __name__ == "__main__":
main()
95 changes: 95 additions & 0 deletions pkg/initializer_v2/dataset/huggingface_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from unittest.mock import MagicMock, patch

import pytest
from kubeflow.training import DATASET_PATH

import pkg.initializer_v2.utils.utils as utils
from pkg.initializer_v2.dataset.huggingface import HuggingFace


# Test cases for config loading
@pytest.mark.parametrize(
    "test_name, test_config, expected",
    [
        (
            "Full config with token",
            {"storage_uri": "hf://dataset/path", "access_token": "test_token"},
            {"storage_uri": "hf://dataset/path", "access_token": "test_token"},
        ),
        (
            "Minimal config without token",
            {"storage_uri": "hf://dataset/path"},
            {"storage_uri": "hf://dataset/path", "access_token": None},
        ),
    ],
)
def test_load_config(test_name, test_config, expected):
    """Test config loading with different configurations"""
    print(f"Running test: {test_name}")

    initializer = HuggingFace()

    # Stub the env-var parsing so load_config consumes the parametrized config.
    with patch.object(utils, "get_config_from_env", return_value=test_config):
        initializer.load_config()
        cfg = initializer.config
        assert cfg.storage_uri == expected["storage_uri"]
        assert cfg.access_token == expected["access_token"]

    print("Test execution completed")


@pytest.mark.parametrize(
    "test_name, test_case",
    [
        (
            "Successful download with token",
            {
                "config": {
                    "storage_uri": "hf://username/dataset-name",
                    "access_token": "test_token",
                },
                "should_login": True,
                "expected_repo_id": "username/dataset-name",
            },
        ),
        (
            "Successful download without token",
            {
                "config": {"storage_uri": "hf://org/dataset-v1", "access_token": None},
                "should_login": False,
                "expected_repo_id": "org/dataset-v1",
            },
        ),
    ],
)
def test_download_dataset(test_name, test_case):
    """Test dataset download with different configurations"""
    print(f"Running test: {test_name}")

    initializer = HuggingFace()
    initializer.config = MagicMock(**test_case["config"])

    with patch("huggingface_hub.login") as mock_login:
        with patch("huggingface_hub.snapshot_download") as mock_download:
            # Execute download
            initializer.download_dataset()

            # Login must happen exactly when a token is configured.
            if test_case["should_login"]:
                mock_login.assert_called_once_with(test_case["config"]["access_token"])
            else:
                mock_login.assert_not_called()

            # Verify download parameters
            mock_download.assert_called_once_with(
                repo_id=test_case["expected_repo_id"],
                local_dir=DATASET_PATH,
                repo_type="dataset",
            )
    print("Test execution completed")
71 changes: 71 additions & 0 deletions pkg/initializer_v2/dataset/main_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from unittest.mock import MagicMock, patch

import pytest

from pkg.initializer_v2.dataset.__main__ import main


@pytest.mark.parametrize(
    "test_name, test_case",
    [
        (
            "Successful download with HuggingFace provider",
            {
                "storage_uri": "hf://dataset/path",
                "access_token": "test_token",
                "expected_error": None,
            },
        ),
        (
            "Missing storage URI environment variable",
            {
                "storage_uri": None,
                "access_token": None,
                "expected_error": Exception,
            },
        ),
        (
            "Invalid storage URI scheme",
            {
                "storage_uri": "invalid://dataset/path",
                "access_token": None,
                "expected_error": Exception,
            },
        ),
    ],
)
def test_dataset_main(test_name, test_case, mock_env_vars):
    """Test main script with different scenarios"""
    print(f"Running test: {test_name}")

    # Arrange the environment the script reads its config from.
    mock_env_vars(
        STORAGE_URI=test_case["storage_uri"],
        ACCESS_TOKEN=test_case["access_token"],
    )

    # Setup mock HuggingFace instance
    hf_stub = MagicMock()

    with patch(
        "pkg.initializer_v2.dataset.__main__.HuggingFace",
        return_value=hf_stub,
    ) as mock_hf:
        if test_case["expected_error"]:
            # Failure paths must surface as an exception from main().
            with pytest.raises(test_case["expected_error"]):
                main()
        else:
            main()

            # Verify HuggingFace instance methods were called
            hf_stub.load_config.assert_called_once()
            hf_stub.download_dataset.assert_called_once()

        # Verify HuggingFace class instantiation
        if test_case["storage_uri"] and test_case["storage_uri"].startswith("hf://"):
            mock_hf.assert_called_once()

    print("Test execution completed")
Empty file.
7 changes: 6 additions & 1 deletion pkg/initializer_v2/model/__main__.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,8 @@
level=logging.INFO,
)

if __name__ == "__main__":

def main():
logging.info("Starting pre-trained model initialization")

try:
@@ -31,3 +32,7 @@
f"STORAGE_URI must have the valid model provider. STORAGE_URI: {storage_uri}"
)
raise Exception


if __name__ == "__main__":
main()
93 changes: 93 additions & 0 deletions pkg/initializer_v2/model/huggingface_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from unittest.mock import MagicMock, patch

import pytest
from kubeflow.training import MODEL_PATH

import pkg.initializer_v2.utils.utils as utils
from pkg.initializer_v2.model.huggingface import HuggingFace


# Test cases for config loading
@pytest.mark.parametrize(
    "test_name, test_config, expected",
    [
        (
            "Full config with token",
            {"storage_uri": "hf://model/path", "access_token": "test_token"},
            {"storage_uri": "hf://model/path", "access_token": "test_token"},
        ),
        (
            "Minimal config without token",
            {"storage_uri": "hf://model/path"},
            {"storage_uri": "hf://model/path", "access_token": None},
        ),
    ],
)
def test_load_config(test_name, test_config, expected):
    """Test config loading with different configurations"""
    print(f"Running test: {test_name}")

    initializer = HuggingFace()

    # Stub the env-var parsing so load_config consumes the parametrized config.
    with patch.object(utils, "get_config_from_env", return_value=test_config):
        initializer.load_config()
        cfg = initializer.config
        assert cfg.storage_uri == expected["storage_uri"]
        assert cfg.access_token == expected["access_token"]

    print("Test execution completed")


@pytest.mark.parametrize(
    "test_name, test_case",
    [
        (
            "Successful download with token",
            {
                "config": {
                    "storage_uri": "hf://username/model-name",
                    "access_token": "test_token",
                },
                "should_login": True,
                "expected_repo_id": "username/model-name",
            },
        ),
        (
            "Successful download without token",
            {
                "config": {"storage_uri": "hf://org/model-v1", "access_token": None},
                "should_login": False,
                "expected_repo_id": "org/model-v1",
            },
        ),
    ],
)
def test_download_model(test_name, test_case):
    """Test model download with different configurations"""
    print(f"Running test: {test_name}")

    initializer = HuggingFace()
    initializer.config = MagicMock(**test_case["config"])

    with patch("huggingface_hub.login") as mock_login:
        with patch("huggingface_hub.snapshot_download") as mock_download:
            # Execute download
            initializer.download_model()

            # Login must happen exactly when a token is configured.
            if test_case["should_login"]:
                mock_login.assert_called_once_with(test_case["config"]["access_token"])
            else:
                mock_login.assert_not_called()

            # Verify download parameters
            mock_download.assert_called_once_with(
                repo_id=test_case["expected_repo_id"],
                local_dir=MODEL_PATH,
                allow_patterns=["*.json", "*.safetensors", "*.model"],
                ignore_patterns=["*.msgpack", "*.h5", "*.bin", ".pt", ".pth"],
            )
    print("Test execution completed")
71 changes: 71 additions & 0 deletions pkg/initializer_v2/model/main_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from unittest.mock import MagicMock, patch

import pytest

from pkg.initializer_v2.model.__main__ import main


@pytest.mark.parametrize(
    "test_name, test_case",
    [
        (
            "Successful download with HuggingFace provider",
            {
                "storage_uri": "hf://model/path",
                "access_token": "test_token",
                "expected_error": None,
            },
        ),
        (
            "Missing storage URI environment variable",
            {
                "storage_uri": None,
                "access_token": None,
                "expected_error": Exception,
            },
        ),
        (
            "Invalid storage URI scheme",
            {
                "storage_uri": "invalid://model/path",
                "access_token": None,
                "expected_error": Exception,
            },
        ),
    ],
)
def test_model_main(test_name, test_case, mock_env_vars):
    """Test main script with different scenarios"""
    print(f"Running test: {test_name}")

    # Arrange the environment the script reads its config from.
    mock_env_vars(
        STORAGE_URI=test_case["storage_uri"],
        ACCESS_TOKEN=test_case["access_token"],
    )

    # Setup mock HuggingFace instance
    hf_stub = MagicMock()

    with patch(
        "pkg.initializer_v2.model.__main__.HuggingFace",
        return_value=hf_stub,
    ) as mock_hf:
        if test_case["expected_error"]:
            # Failure paths must surface as an exception from main().
            with pytest.raises(test_case["expected_error"]):
                main()
        else:
            main()

            # Verify HuggingFace instance methods were called
            hf_stub.load_config.assert_called_once()
            hf_stub.download_model.assert_called_once()

        # Verify HuggingFace class instantiation
        if test_case["storage_uri"] and test_case["storage_uri"].startswith("hf://"):
            mock_hf.assert_called_once()

    print("Test execution completed")
35 changes: 35 additions & 0 deletions pkg/initializer_v2/utils/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pytest
from kubeflow.training import HuggingFaceDatasetConfig, HuggingFaceModelInputConfig

import pkg.initializer_v2.utils.utils as utils


@pytest.mark.parametrize(
    "config_class,env_vars,expected",
    [
        (
            HuggingFaceModelInputConfig,
            {"STORAGE_URI": "hf://test", "ACCESS_TOKEN": "token"},
            {"storage_uri": "hf://test", "access_token": "token"},
        ),
        (
            HuggingFaceModelInputConfig,
            {"STORAGE_URI": "hf://test"},
            {"storage_uri": "hf://test", "access_token": None},
        ),
        (
            HuggingFaceDatasetConfig,
            {"STORAGE_URI": "hf://test", "ACCESS_TOKEN": "token"},
            {"storage_uri": "hf://test", "access_token": "token"},
        ),
        (
            HuggingFaceDatasetConfig,
            {"STORAGE_URI": "hf://test"},
            {"storage_uri": "hf://test", "access_token": None},
        ),
    ],
)
def test_get_config_from_env(mock_env_vars, config_class, env_vars, expected):
    """get_config_from_env maps env vars onto the given config class."""
    mock_env_vars(**env_vars)
    parsed = utils.get_config_from_env(config_class)
    assert parsed == expected
Empty file added test/__init__.py
Empty file.
Empty file added test/integration/__init__.py
Empty file.
Empty file.
47 changes: 47 additions & 0 deletions test/integration/initializer_v2/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import os
import shutil
import tempfile

import pytest


@pytest.fixture
def setup_temp_path(monkeypatch):
    """Creates temporary directory and patches path constant.

    This fixture:
    1. Creates a temporary directory
    2. Allows configuration of path constant
    3. Handles automatic cleanup after tests
    4. Restores original environment state

    Args:
        monkeypatch: pytest fixture for modifying objects

    Returns:
        function: A configurator that accepts path_var (str) and returns temp_dir path
    """
    saved_environ = dict(os.environ)
    here = os.path.dirname(os.path.abspath(__file__))
    scratch_dir = tempfile.mkdtemp(dir=here)

    def configure_path(path_var: str):
        """Configure path variable in kubeflow.training"""
        import kubeflow.training as training

        # Point the named path constant (e.g. "MODEL_PATH") at the scratch dir.
        monkeypatch.setattr(training, path_var, scratch_dir)
        return scratch_dir

    yield configure_path

    # Cleanup: drop the scratch dir and restore the original environment.
    shutil.rmtree(scratch_dir, ignore_errors=True)
    os.environ.clear()
    os.environ.update(saved_environ)
74 changes: 74 additions & 0 deletions test/integration/initializer_v2/dataset_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import os
import runpy
from test.integration.initializer_v2.utils import verify_downloaded_files

import pytest

import pkg.initializer_v2.utils.utils as utils


class TestDatasetIntegration:
    """Integration tests for dataset initialization"""

    @pytest.fixture(autouse=True)
    def setup_teardown(self, setup_temp_path):
        # Redirect DATASET_PATH into a per-test temporary directory.
        self.temp_dir = setup_temp_path("DATASET_PATH")

    @pytest.mark.parametrize(
        "test_name, provider, test_case",
        [
            # Public HuggingFace dataset test
            (
                "HuggingFace - Public dataset",
                "huggingface",
                {
                    "storage_uri": "hf://karpathy/tiny_shakespeare",
                    "access_token": None,
                    "expected_files": ["tiny_shakespeare.py"],
                    "expected_error": None,
                },
            ),
            (
                "HuggingFace - Invalid dataset",
                "huggingface",
                {
                    "storage_uri": "hf://invalid/nonexistent-dataset",
                    "access_token": None,
                    "expected_files": None,
                    "expected_error": Exception,
                },
            ),
            (
                "HuggingFace - Login Failure",
                "huggingface",
                {
                    "storage_uri": "hf://karpathy/tiny_shakespeare",
                    "access_token": "invalid token",
                    "expected_files": None,
                    "expected_error": Exception,
                },
            ),
        ],
    )
    def test_dataset_download(self, test_name, provider, test_case):
        """Test end-to-end dataset download for different providers"""
        print(f"Running Integration test for {provider}: {test_name}")

        # Export the config the initializer script reads from the environment.
        os.environ[utils.STORAGE_URI_ENV] = test_case["storage_uri"]
        token = test_case.get("access_token")
        if token:
            os.environ["ACCESS_TOKEN"] = token

        # Run the initializer exactly as the container entrypoint would.
        module = "pkg.initializer_v2.dataset.__main__"
        if test_case["expected_error"]:
            with pytest.raises(test_case["expected_error"]):
                runpy.run_module(module, run_name="__main__")
        else:
            runpy.run_module(module, run_name="__main__")
            verify_downloaded_files(self.temp_dir, test_case.get("expected_files"))

        print("Test execution completed")
80 changes: 80 additions & 0 deletions test/integration/initializer_v2/model_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import os
import runpy
from test.integration.initializer_v2.utils import verify_downloaded_files

import pytest

import pkg.initializer_v2.utils.utils as utils


class TestModelIntegration:
    """Integration tests for model initialization"""

    @pytest.fixture(autouse=True)
    def setup_teardown(self, setup_temp_path):
        # Redirect MODEL_PATH into a per-test temporary directory.
        self.temp_dir = setup_temp_path("MODEL_PATH")

    @pytest.mark.parametrize(
        "test_name, provider, test_case",
        [
            # Public HuggingFace model test
            (
                "HuggingFace - Public model",
                "huggingface",
                {
                    "storage_uri": "hf://hf-internal-testing/tiny-random-bert",
                    "access_token": None,
                    "expected_files": [
                        "config.json",
                        "model.safetensors",
                        "tokenizer.json",
                        "tokenizer_config.json",
                    ],
                    "expected_error": None,
                },
            ),
            (
                "HuggingFace - Invalid model",
                "huggingface",
                {
                    "storage_uri": "hf://invalid/nonexistent-model",
                    "access_token": None,
                    "expected_files": None,
                    "expected_error": Exception,
                },
            ),
            (
                "HuggingFace - Login failure",
                "huggingface",
                {
                    "storage_uri": "hf://hf-internal-testing/tiny-random-bert",
                    "access_token": "invalid token",
                    "expected_files": None,
                    "expected_error": Exception,
                },
            ),
        ],
    )
    def test_model_download(self, test_name, provider, test_case):
        """Test end-to-end model download for different providers"""
        print(f"Running Integration test for {provider}: {test_name}")

        # Export the config the initializer script reads from the environment.
        os.environ[utils.STORAGE_URI_ENV] = test_case["storage_uri"]
        token = test_case.get("access_token")
        if token:
            os.environ["ACCESS_TOKEN"] = token

        # Run the initializer exactly as the container entrypoint would.
        module = "pkg.initializer_v2.model.__main__"
        if test_case["expected_error"]:
            with pytest.raises(test_case["expected_error"]):
                runpy.run_module(module, run_name="__main__")
        else:
            runpy.run_module(module, run_name="__main__")
            verify_downloaded_files(self.temp_dir, test_case.get("expected_files"))

        print("Test execution completed")
9 changes: 9 additions & 0 deletions test/integration/initializer_v2/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import os


def verify_downloaded_files(dir_path, expected_files):
    """Verify downloaded files"""
    if not expected_files:
        # Nothing to check — callers pass None/empty for error scenarios.
        return
    present = set(os.listdir(dir_path))
    missing = set(expected_files) - present
    assert not missing, f"Missing expected files: {missing}"

0 comments on commit e47d8f7

Please sign in to comment.