-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added Huggingface model transfer class with save_pretrained in download
- Loading branch information
1 parent
42ae75c
commit c4f22e0
Showing
2 changed files
with
163 additions
and
0 deletions.
There are no files selected for viewing
78 changes: 78 additions & 0 deletions
78
exasol_transformers_extension/utils/huggingface_hub_bucketfs_model_transfer_sp.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import os | ||
import tempfile | ||
from pathlib import Path | ||
from typing import Protocol, Union, runtime_checkable | ||
|
||
import transformers | ||
from exasol_bucketfs_utils_python.bucketfs_location import BucketFSLocation | ||
|
||
from exasol_transformers_extension.utils.bucketfs_model_uploader import BucketFSModelUploaderFactory | ||
from exasol_transformers_extension.utils.temporary_directory_factory import TemporaryDirectoryFactory | ||
|
||
|
||
@runtime_checkable
class ModelFactoryProtocol(Protocol):
    """
    Structural type for objects offering ``from_pretrained`` and ``save_pretrained``.

    The protocol is decorated with ``runtime_checkable``, so it can be used
    with ``isinstance`` in addition to static type checking.
    """

    def from_pretrained(self, model_name: str, cache_dir: Path, use_auth_token: str) -> transformers.PreTrainedModel:
        """Load the model ``model_name`` using ``cache_dir`` and ``use_auth_token``."""
        ...

    def save_pretrained(self, save_directory: Union[str, Path]):
        """Store the model in ``save_directory``."""
        ...
|
||
|
||
class HuggingFaceHubBucketFSModelTransferSP:
    """
    Download a model from the HuggingFace Hub, save it locally with
    ``save_pretrained`` and upload the temporary cache directory to the BucketFS.

    A temporary directory is created and entered in the constructor and is used
    as the cache directory for the download. It is released when the ``with``
    block is left or, as a fallback, when the object is finalized.

    :param bucketfs_location: BucketFS location handed to the model uploader
    :param model_name: Name of the model to download
    :param model_path: Path handed to the model uploader
    :param local_model_save_path: Local base directory, the model is saved at
        ``local_model_save_path / model_name``
    :param token: Token passed as ``use_auth_token`` when loading the model
    :param temporary_directory_factory: Creates the temporary cache directory
    :param bucketfs_model_uploader_factory: Creates the BucketFS model uploader
    """

    def __init__(self,
                 bucketfs_location: BucketFSLocation,
                 model_name: str,
                 model_path: Path,
                 local_model_save_path: Union[str, Path],
                 token: str,
                 temporary_directory_factory: TemporaryDirectoryFactory = TemporaryDirectoryFactory(),
                 bucketfs_model_uploader_factory: BucketFSModelUploaderFactory = BucketFSModelUploaderFactory()):
        self._token = token
        self._model_name = model_name
        # Accept plain strings as well, the path is joined with "/" later on.
        self._local_model_save_path = Path(local_model_save_path)
        self._temporary_directory_factory = temporary_directory_factory
        self._bucketfs_model_uploader = bucketfs_model_uploader_factory.create(
            model_path=model_path,
            bucketfs_location=bucketfs_location)
        self._tmpdir = temporary_directory_factory.create()
        # Entered here and not in __enter__, so the directory is also available
        # when the object is used without a "with" statement.
        self._tmpdir_name = self._tmpdir.__enter__()

    def __enter__(self):
        return self

    def __del__(self):
        # __del__ also runs for objects whose __init__ raised before the
        # temporary directory was created, in that case there is nothing to clean up.
        tmpdir = getattr(self, "_tmpdir", None)
        if tmpdir is not None:
            tmpdir.cleanup()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._tmpdir.__exit__(exc_type, exc_val, exc_tb)

    def download_from_huggingface_hub_sp(self, model_factory: ModelFactoryProtocol):
        """
        Download a model from HuggingFace Hub into a temporary directory and save it with save_pretrained
        at _local_model_save_path / _model_name for local storing

        :param model_factory: Object whose ``from_pretrained`` loads the model
        """
        model = model_factory.from_pretrained(self._model_name, cache_dir=self._tmpdir_name, use_auth_token=self._token)
        path = self._local_model_save_path / self._model_name
        model.save_pretrained(path)  # todo save in cachedir in assumption will be uploaded and then deleted?

    def upload_to_bucketfs(self) -> Path:
        """
        Upload the downloaded models into the BucketFS

        :return: The result of the uploader's ``upload_directory`` call
        """
        return self._bucketfs_model_uploader.upload_directory(self._tmpdir_name)
|
||
|
||
class HuggingFaceHubBucketFSModelTransferSPFactory:
    """Creates HuggingFaceHubBucketFSModelTransferSP objects."""

    def create(self,
               bucketfs_location: BucketFSLocation,
               model_name: str,
               model_path: Path,
               local_model_save_path: Path,
               token: str) -> HuggingFaceHubBucketFSModelTransferSP:
        """
        Build a model transfer object from the given arguments.

        Only these five arguments are forwarded, all as keyword arguments.
        """
        transfer_arguments = dict(
            bucketfs_location=bucketfs_location,
            model_name=model_name,
            model_path=model_path,
            local_model_save_path=local_model_save_path,
            token=token,
        )
        return HuggingFaceHubBucketFSModelTransferSP(**transfer_arguments)
85 changes: 85 additions & 0 deletions
85
tests/unit_tests/utils/test_huggingface_hub_bucketfs__model_transfer_sp.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import tempfile | ||
from pathlib import Path | ||
from typing import Union | ||
from unittest.mock import create_autospec, MagicMock, call | ||
|
||
from exasol_bucketfs_utils_python.bucketfs_location import BucketFSLocation | ||
from transformers import AutoModel, PreTrainedModel | ||
|
||
from exasol_transformers_extension.utils.bucketfs_model_uploader import BucketFSModelUploader, \ | ||
BucketFSModelUploaderFactory | ||
from exasol_transformers_extension.utils.huggingface_hub_bucketfs_model_transfer_sp import ModelFactoryProtocol, \ | ||
HuggingFaceHubBucketFSModelTransferSP | ||
from exasol_transformers_extension.utils.temporary_directory_factory import TemporaryDirectoryFactory | ||
from tests.utils.mock_cast import mock_cast | ||
|
||
from tests.utils.parameters import model_params | ||
|
||
class TestSetup:
    """
    Test helper bundling the mocks and the HuggingFaceHubBucketFSModelTransferSP
    object built from them.

    :param local_model_save_path: Passed to the object under test as the local
        base directory for saved models
    """

    # The name starts with "Test", but this is a helper with a constructor and
    # not a test class. Tell pytest not to collect it.
    __test__ = False

    def __init__(self, local_model_save_path: Union[str, Path] = "downloaded_models_test"):
        self.bucketfs_location_mock: Union[BucketFSLocation, MagicMock] = create_autospec(BucketFSLocation)
        self.model_factory_mock: Union[ModelFactoryProtocol, MagicMock] = create_autospec(ModelFactoryProtocol)
        self.temporary_directory_factory_mock: Union[TemporaryDirectoryFactory, MagicMock] = \
            create_autospec(TemporaryDirectoryFactory)
        self.bucketfs_model_uploader_factory_mock: Union[BucketFSModelUploaderFactory, MagicMock] = \
            create_autospec(BucketFSModelUploaderFactory)
        self.bucketfs_model_uploader_mock: Union[BucketFSModelUploader, MagicMock] = \
            create_autospec(BucketFSModelUploader)
        # The uploader factory hands out the uploader mock exactly once.
        mock_cast(self.bucketfs_model_uploader_factory_mock.create).side_effect = [self.bucketfs_model_uploader_mock]

        self.token = "token"
        self.model_name = model_params.tiny_model
        self.model_path = Path("test_model_path")
        self.downloader = HuggingFaceHubBucketFSModelTransferSP(
            bucketfs_location=self.bucketfs_location_mock,
            model_path=self.model_path,
            model_name=self.model_name,
            local_model_save_path=local_model_save_path,
            token=self.token,
            temporary_directory_factory=self.temporary_directory_factory_mock,
            bucketfs_model_uploader_factory=self.bucketfs_model_uploader_factory_mock
        )

    def reset_mocks(self):
        """Reset the recorded calls of the location, temp dir, model factory and uploader mocks."""
        self.bucketfs_location_mock.reset_mock()
        self.temporary_directory_factory_mock.reset_mock()
        self.model_factory_mock.reset_mock()
        self.bucketfs_model_uploader_mock.reset_mock()
|
||
|
||
def test_init():
    """Check which mocks are called while the object under test is constructed."""
    setup = TestSetup()
    expected_tmpdir_calls = [call.create(), call.create().__enter__()]
    expected_uploader_factory_calls = [
        call.create(model_path=setup.model_path, bucketfs_location=setup.bucketfs_location_mock)
    ]
    assert setup.temporary_directory_factory_mock.mock_calls == expected_tmpdir_calls
    assert setup.model_factory_mock.mock_calls == []
    assert setup.bucketfs_location_mock.mock_calls == []
    assert mock_cast(setup.bucketfs_model_uploader_factory_mock.create).mock_calls \
        == expected_uploader_factory_calls
|
||
|
||
def test_download_function_call():
    """Check that the download loads the model with the cache dir and token and then saves it."""
    setup = TestSetup()
    setup.downloader.download_from_huggingface_hub_sp(model_factory=setup.model_factory_mock)
    expected_cache_dir = setup.temporary_directory_factory_mock.create().__enter__()
    expected_save_path = setup.downloader._local_model_save_path / setup.model_name
    expected_calls = [
        call.from_pretrained(setup.model_name, cache_dir=expected_cache_dir,
                             use_auth_token=setup.token),
        call.from_pretrained().save_pretrained(expected_save_path),
    ]
    assert setup.model_factory_mock.mock_calls == expected_calls
|
||
|
||
# todo add test for model already downloaded? | ||
|
||
def test_download_with_model():
    """Run the download with AutoModel and check that the saved model can be loaded again."""
    with tempfile.TemporaryDirectory() as folder:
        save_dir = Path(folder) / "downloaded_models"
        setup = TestSetup(local_model_save_path=save_dir)
        setup.downloader.download_from_huggingface_hub_sp(model_factory=AutoModel)
        assert AutoModel.from_pretrained(save_dir / setup.model_name)
        setup.downloader.__del__()
        #todo delete model
|
||
|