-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into feature/#167_copy_version_check
- Loading branch information
Showing
7 changed files
with
256 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
10 changes: 2 additions & 8 deletions
10
exasol_transformers_extension/utils/huggingface_hub_bucketfs_model_transfer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
89 changes: 89 additions & 0 deletions
89
exasol_transformers_extension/utils/huggingface_hub_bucketfs_model_transfer_sp.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
from pathlib import Path | ||
|
||
from exasol_bucketfs_utils_python.bucketfs_location import BucketFSLocation | ||
|
||
from exasol_transformers_extension.utils.model_factory_protocol import ModelFactoryProtocol | ||
from exasol_transformers_extension.utils.bucketfs_model_uploader import BucketFSModelUploaderFactory | ||
from exasol_transformers_extension.utils.temporary_directory_factory import TemporaryDirectoryFactory | ||
|
||
|
||
|
||
|
||
|
||
class HuggingFaceHubBucketFSModelTransferSP: | ||
""" | ||
Class for downloading a model using the Huggingface Transformers API, and loading it into the BucketFS | ||
using save_pretrained. | ||
:bucketfs_location: BucketFSLocation the model should be loaded to | ||
:model_name: Name of the model to be downloaded using Huggingface Transformers API | ||
:model_path: Path the model will be loaded into the BucketFS at | ||
:token: Huggingface token, only needed for private models | ||
:temporary_directory_factory: Optional. Default is TemporaryDirectoryFactory. Mainly change for testing. | ||
:bucketfs_model_uploader_factory: Optional. Default is BucketFSModelUploaderFactory. Mainly change for testing. | ||
""" | ||
def __init__(self, | ||
bucketfs_location: BucketFSLocation, | ||
model_name: str, | ||
model_path: Path, | ||
token: str, | ||
temporary_directory_factory: TemporaryDirectoryFactory = TemporaryDirectoryFactory(), | ||
bucketfs_model_uploader_factory: BucketFSModelUploaderFactory = BucketFSModelUploaderFactory()): | ||
self._token = token | ||
self._model_name = model_name | ||
self._temporary_directory_factory = temporary_directory_factory | ||
self._bucketfs_model_uploader = bucketfs_model_uploader_factory.create( | ||
model_path=model_path, | ||
bucketfs_location=bucketfs_location) | ||
self._tmpdir = temporary_directory_factory.create() | ||
self._tmpdir_name = Path(self._tmpdir.__enter__()) | ||
|
||
def __enter__(self): | ||
return self | ||
|
||
def __del__(self): | ||
self._tmpdir.cleanup() | ||
|
||
def __exit__(self, exc_type, exc_val, exc_tb): | ||
self._tmpdir.__exit__(exc_type, exc_val, exc_tb) | ||
|
||
def download_from_huggingface_hub(self, model_factory: ModelFactoryProtocol): | ||
""" | ||
Download a model from HuggingFace Hub into a temporary directory and save it with save_pretrained | ||
in temporary directory / pretrained . | ||
""" | ||
model = model_factory.from_pretrained(self._model_name, cache_dir=self._tmpdir_name / "cache", use_auth_token=self._token) | ||
model.save_pretrained(self._tmpdir_name / "pretrained" / self._model_name) | ||
|
||
def upload_to_bucketfs(self) -> Path: | ||
""" | ||
Upload the downloaded models into the BucketFS. | ||
returns: Path of the uploaded model in the BucketFS | ||
""" | ||
return self._bucketfs_model_uploader.upload_directory(self._tmpdir_name / "pretrained" / self._model_name) | ||
|
||
|
||
class HuggingFaceHubBucketFSModelTransferSPFactory: | ||
""" | ||
Class for creating a HuggingFaceHubBucketFSModelTransferSP object. | ||
""" | ||
def create(self, | ||
bucketfs_location: BucketFSLocation, | ||
model_name: str, | ||
model_path: Path, | ||
token: str) -> HuggingFaceHubBucketFSModelTransferSP: | ||
""" | ||
Creates a HuggingFaceHubBucketFSModelTransferSP object. | ||
:bucketfs_location: BucketFSLocation the model should be loaded to | ||
:model_name: Name of the model to be downloaded using Huggingface Transformers API | ||
:model_path: Path the model will be loaded into the BucketFS at | ||
:token: Huggingface token, only needed for private models | ||
returns: The created HuggingFaceHubBucketFSModelTransferSP object. | ||
""" | ||
return HuggingFaceHubBucketFSModelTransferSP(bucketfs_location=bucketfs_location, | ||
model_name=model_name, | ||
model_path=model_path, | ||
token=token) |
16 changes: 16 additions & 0 deletions
16
exasol_transformers_extension/utils/model_factory_protocol.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from pathlib import Path | ||
from typing import Protocol, Union, runtime_checkable | ||
|
||
import transformers | ||
|
||
|
||
@runtime_checkable | ||
class ModelFactoryProtocol(Protocol): | ||
""" | ||
Protocol for better type hints. | ||
""" | ||
def from_pretrained(self, model_name: str, cache_dir: Path, use_auth_token: str) -> transformers.PreTrainedModel: | ||
pass | ||
|
||
def save_pretrained(self, save_directory: Union[str, Path]): | ||
pass |
64 changes: 64 additions & 0 deletions
64
tests/integration_tests/without_db/utils/test_huggingface_hub_bucketfs_model_transfer_sp.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import pytest | ||
import tempfile | ||
from pathlib import Path | ||
from typing import Union | ||
from unittest.mock import create_autospec, MagicMock | ||
|
||
from transformers import AutoModel | ||
|
||
from exasol_transformers_extension.utils.bucketfs_model_uploader import BucketFSModelUploader, \ | ||
BucketFSModelUploaderFactory | ||
from exasol_transformers_extension.utils.huggingface_hub_bucketfs_model_transfer_sp import ModelFactoryProtocol, \ | ||
HuggingFaceHubBucketFSModelTransferSP | ||
from exasol_transformers_extension.utils.temporary_directory_factory import TemporaryDirectoryFactory | ||
from tests.utils.mock_cast import mock_cast | ||
|
||
from tests.utils.parameters import model_params | ||
|
||
|
||
class TestSetup: | ||
def __init__(self, bucketfs_location): | ||
self.bucketfs_location = bucketfs_location | ||
self.model_factory_mock: Union[ModelFactoryProtocol, MagicMock] = create_autospec(ModelFactoryProtocol) | ||
self.temporary_directory_factory = TemporaryDirectoryFactory() | ||
self.bucketfs_model_uploader_factory_mock: Union[BucketFSModelUploaderFactory, MagicMock] = \ | ||
create_autospec(BucketFSModelUploaderFactory) | ||
self.bucketfs_model_uploader_mock: Union[BucketFSModelUploader, MagicMock] = \ | ||
create_autospec(BucketFSModelUploader) | ||
mock_cast(self.bucketfs_model_uploader_factory_mock.create).side_effect = [self.bucketfs_model_uploader_mock] | ||
|
||
self.token = "token" | ||
model_params_ = model_params.tiny_model | ||
self.model_name = model_params_ | ||
self.model_path = Path("test_model_path") | ||
self.downloader = HuggingFaceHubBucketFSModelTransferSP( | ||
bucketfs_location=self.bucketfs_location, | ||
model_path=self.model_path, | ||
model_name=self.model_name, | ||
token=self.token, | ||
temporary_directory_factory=self.temporary_directory_factory, | ||
bucketfs_model_uploader_factory=self.bucketfs_model_uploader_factory_mock | ||
) | ||
|
||
def reset_mocks(self): | ||
self.model_factory_mock.reset_mock() | ||
self.bucketfs_model_uploader_mock.reset_mock() | ||
|
||
|
||
def test_download_with_model(bucketfs_location): | ||
with tempfile.TemporaryDirectory() as folder: | ||
test_setup = TestSetup(bucketfs_location) | ||
base_model_factory: ModelFactoryProtocol = AutoModel | ||
test_setup.downloader.download_from_huggingface_hub(model_factory=base_model_factory) | ||
assert AutoModel.from_pretrained(test_setup.downloader._tmpdir_name / "pretrained" / test_setup.model_name) | ||
del test_setup.downloader | ||
|
||
|
||
def test_download_with_duplicate_model(bucketfs_location): | ||
with tempfile.TemporaryDirectory() as folder: | ||
test_setup = TestSetup(bucketfs_location) | ||
base_model_factory: ModelFactoryProtocol = AutoModel | ||
test_setup.downloader.download_from_huggingface_hub(model_factory=base_model_factory) | ||
test_setup.downloader.download_from_huggingface_hub(model_factory=base_model_factory) | ||
assert AutoModel.from_pretrained(test_setup.downloader._tmpdir_name / "pretrained" / test_setup.model_name) | ||
del test_setup.downloader |
File renamed without changes.
81 changes: 81 additions & 0 deletions
81
tests/unit_tests/utils/test_huggingface_hub_bucketfs_model_transfer_sp.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from pathlib import Path | ||
from typing import Union | ||
from unittest.mock import create_autospec, MagicMock, call | ||
|
||
from exasol_bucketfs_utils_python.bucketfs_location import BucketFSLocation | ||
|
||
from exasol_transformers_extension.utils.bucketfs_model_uploader import BucketFSModelUploader, \ | ||
BucketFSModelUploaderFactory | ||
from exasol_transformers_extension.utils.huggingface_hub_bucketfs_model_transfer_sp import ModelFactoryProtocol, \ | ||
HuggingFaceHubBucketFSModelTransferSP | ||
from exasol_transformers_extension.utils.temporary_directory_factory import TemporaryDirectoryFactory | ||
from tests.utils.mock_cast import mock_cast | ||
|
||
from tests.utils.parameters import model_params | ||
|
||
|
||
class TestSetup: | ||
def __init__(self): | ||
self.bucketfs_location_mock: Union[BucketFSLocation, MagicMock] = create_autospec(BucketFSLocation) | ||
self.model_factory_mock: Union[ModelFactoryProtocol, MagicMock] = create_autospec(ModelFactoryProtocol) | ||
self.temporary_directory_factory_mock: Union[TemporaryDirectoryFactory, MagicMock] = \ | ||
create_autospec(TemporaryDirectoryFactory) | ||
self.bucketfs_model_uploader_factory_mock: Union[BucketFSModelUploaderFactory, MagicMock] = \ | ||
create_autospec(BucketFSModelUploaderFactory) | ||
self.bucketfs_model_uploader_mock: Union[BucketFSModelUploader, MagicMock] = \ | ||
create_autospec(BucketFSModelUploader) | ||
mock_cast(self.bucketfs_model_uploader_factory_mock.create).side_effect = [self.bucketfs_model_uploader_mock] | ||
|
||
|
||
self.token = "token" | ||
model_params_ = model_params.tiny_model | ||
self.model_name = model_params_ | ||
self.model_path = Path("test_model_path") | ||
self.downloader = HuggingFaceHubBucketFSModelTransferSP( | ||
bucketfs_location=self.bucketfs_location_mock, | ||
model_path=self.model_path, | ||
model_name=self.model_name, | ||
token=self.token, | ||
temporary_directory_factory=self.temporary_directory_factory_mock, | ||
bucketfs_model_uploader_factory=self.bucketfs_model_uploader_factory_mock | ||
) | ||
|
||
def reset_mocks(self): | ||
self.bucketfs_location_mock.reset_mock() | ||
self.temporary_directory_factory_mock.reset_mock() | ||
self.model_factory_mock.reset_mock() | ||
self.bucketfs_model_uploader_mock.reset_mock() | ||
self.bucketfs_model_uploader_factory_mock.reset_mock() | ||
|
||
|
||
def test_init(): | ||
test_setup = TestSetup() | ||
assert test_setup.temporary_directory_factory_mock.mock_calls == [call.create(), | ||
call.create().__enter__(), | ||
call.create().__enter__().__fspath__()] \ | ||
and test_setup.model_factory_mock.mock_calls == [] \ | ||
and test_setup.bucketfs_location_mock.mock_calls == [] \ | ||
and mock_cast(test_setup.bucketfs_model_uploader_factory_mock.create).mock_calls == [ | ||
call.create(model_path=test_setup.model_path, bucketfs_location=test_setup.bucketfs_location_mock) | ||
] | ||
|
||
|
||
def test_download_function_call(): | ||
test_setup = TestSetup() | ||
test_setup.downloader.download_from_huggingface_hub(model_factory=test_setup.model_factory_mock) | ||
cache_dir = mock_cast(test_setup.temporary_directory_factory_mock.create().__enter__).return_value | ||
model_save_path = Path(cache_dir) / "pretrained" / test_setup.model_name | ||
assert test_setup.model_factory_mock.mock_calls == [ | ||
call.from_pretrained(test_setup.model_name, cache_dir=Path(cache_dir)/"cache", | ||
use_auth_token=test_setup.token), | ||
call.from_pretrained().save_pretrained(model_save_path)] | ||
|
||
|
||
def test_upload_function_call(): | ||
test_setup = TestSetup() | ||
test_setup.downloader.download_from_huggingface_hub(model_factory=test_setup.model_factory_mock) | ||
test_setup.reset_mocks() | ||
cache_dir = mock_cast(test_setup.temporary_directory_factory_mock.create().__enter__).return_value | ||
model_save_path = Path(cache_dir) / "pretrained" / test_setup.model_name | ||
test_setup.downloader.upload_to_bucketfs() | ||
assert mock_cast(test_setup.bucketfs_model_uploader_mock.upload_directory).mock_calls == [call(model_save_path)] |