Skip to content

Commit

Permalink
added Huggingface model transfer class with save_pretrained in download
Browse files Browse the repository at this point in the history
  • Loading branch information
MarleneKress79789 committed Nov 29, 2023
1 parent 42ae75c commit c4f22e0
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import os
import tempfile
from pathlib import Path
from typing import Protocol, Union, runtime_checkable

import transformers
from exasol_bucketfs_utils_python.bucketfs_location import BucketFSLocation

from exasol_transformers_extension.utils.bucketfs_model_uploader import BucketFSModelUploaderFactory
from exasol_transformers_extension.utils.temporary_directory_factory import TemporaryDirectoryFactory


@runtime_checkable
class ModelFactoryProtocol(Protocol):
    """
    Structural type describing the model factory objects accepted by the
    model transfer.

    The methods carry no implementation; they only declare the call
    signatures an implementing object is expected to offer.
    """

    def from_pretrained(self, model_name: str, cache_dir: Path, use_auth_token: str) -> transformers.PreTrainedModel:
        """
        Signature of the loader.

        :param model_name: name of the model to load
        :param cache_dir: directory handed to the loader as its cache location
        :param use_auth_token: authentication token forwarded to the loader
        :return: the loaded model
        """
        ...

    def save_pretrained(self, save_directory: Union[str, Path]):
        """
        Signature of the saver.

        :param save_directory: directory the implementation is expected to write to
        """
        ...


class HuggingFaceHubBucketFSModelTransferSP:
    """
    Transfers a HuggingFace Hub model: it is loaded through a model factory,
    stored locally with ``save_pretrained`` and the download directory can be
    handed to a BucketFS model uploader.

    The constructor creates and enters a temporary directory obtained from
    ``temporary_directory_factory``. That directory is released either by
    leaving the ``with`` block (``__exit__``) or, as a fallback, when the
    object is finalized (``__del__``).
    """

    def __init__(self,
                 bucketfs_location: BucketFSLocation,
                 model_name: str,
                 model_path: Path,
                 local_model_save_path: Path,
                 token: str,
                 temporary_directory_factory: TemporaryDirectoryFactory = TemporaryDirectoryFactory(),
                 bucketfs_model_uploader_factory: BucketFSModelUploaderFactory = BucketFSModelUploaderFactory()):
        """
        :param bucketfs_location: BucketFS location passed to the uploader factory
        :param model_name: name of the model, used for loading and as the
            last component of the local save path
        :param model_path: model path passed to the uploader factory
        :param local_model_save_path: base directory for the locally saved model
        :param token: authentication token forwarded to ``from_pretrained``
        :param temporary_directory_factory: factory creating the temporary
            directory object used as download cache
        :param bucketfs_model_uploader_factory: factory creating the uploader
        """
        self._token = token
        self._model_name = model_name
        self._local_model_save_path = Path(local_model_save_path)
        self._temporary_directory_factory = temporary_directory_factory
        self._bucketfs_model_uploader = bucketfs_model_uploader_factory.create(
            model_path=model_path,
            bucketfs_location=bucketfs_location)
        self._tmpdir = temporary_directory_factory.create()
        # Entered here (not in __enter__) so the object is usable without a
        # ``with`` statement; the entered value is used as the cache directory.
        self._tmpdir_name = self._tmpdir.__enter__()

    def __enter__(self):
        return self

    def __del__(self):
        # The finalizer also runs for instances whose __init__ raised before
        # ``_tmpdir`` was assigned (e.g. the uploader factory failed). Without
        # this guard the finalizer itself would fail with an AttributeError.
        tmpdir = getattr(self, "_tmpdir", None)
        if tmpdir is not None:
            tmpdir.cleanup()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._tmpdir.__exit__(exc_type, exc_val, exc_tb)

    def download_from_huggingface_hub_sp(self, model_factory: ModelFactoryProtocol):
        """
        Download a model from HuggingFace Hub into a temporary directory and save it with save_pretrained
        at _local_model_save_path / _model_name for local storing

        :param model_factory: object whose ``from_pretrained`` loads the model
        """
        model = model_factory.from_pretrained(self._model_name, cache_dir=self._tmpdir_name, use_auth_token=self._token)
        path = self._local_model_save_path / self._model_name
        # TODO: save into the cache dir instead, on the assumption that the
        #  model will be uploaded and then deleted?
        model.save_pretrained(path)

    def upload_to_bucketfs(self) -> Path:
        """
        Upload the downloaded models into the BucketFS

        The temporary download directory is what gets passed to the uploader.

        :return: the value returned by the uploader's ``upload_directory``
        """
        return self._bucketfs_model_uploader.upload_directory(self._tmpdir_name)


class HuggingFaceHubBucketFSModelTransferSPFactory:
    """Factory producing HuggingFaceHubBucketFSModelTransferSP objects."""

    def create(self,
               bucketfs_location: BucketFSLocation,
               model_name: str,
               model_path: Path,
               local_model_save_path: Path,
               token: str) -> HuggingFaceHubBucketFSModelTransferSP:
        """
        Build a transfer object from the given arguments.

        All arguments are forwarded unchanged as keyword arguments; the
        temporary directory factory and the uploader factory are left at the
        defaults of the transfer class.
        """
        transfer = HuggingFaceHubBucketFSModelTransferSP(
            bucketfs_location=bucketfs_location,
            model_name=model_name,
            model_path=model_path,
            local_model_save_path=local_model_save_path,
            token=token,
        )
        return transfer
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import tempfile
from pathlib import Path
from typing import Union
from unittest.mock import create_autospec, MagicMock, call

from exasol_bucketfs_utils_python.bucketfs_location import BucketFSLocation
from transformers import AutoModel, PreTrainedModel

from exasol_transformers_extension.utils.bucketfs_model_uploader import BucketFSModelUploader, \
BucketFSModelUploaderFactory
from exasol_transformers_extension.utils.huggingface_hub_bucketfs_model_transfer_sp import ModelFactoryProtocol, \
HuggingFaceHubBucketFSModelTransferSP
from exasol_transformers_extension.utils.temporary_directory_factory import TemporaryDirectoryFactory
from tests.utils.mock_cast import mock_cast

from tests.utils.parameters import model_params

class TestSetup:
    """
    Bundles the mocks and the HuggingFaceHubBucketFSModelTransferSP instance
    shared by the tests in this module.
    """

    # The name matches pytest's ``Test*`` class pattern although this is only
    # a helper with an __init__; opting out avoids the collection warning.
    __test__ = False

    def __init__(self, local_model_save_path: Union[str, Path] = "downloaded_models_test"):
        """
        :param local_model_save_path: base directory passed to the transfer
            object as its local save path
        """
        self.bucketfs_location_mock: Union[BucketFSLocation, MagicMock] = create_autospec(BucketFSLocation)
        self.model_factory_mock: Union[ModelFactoryProtocol, MagicMock] = create_autospec(ModelFactoryProtocol)
        self.temporary_directory_factory_mock: Union[TemporaryDirectoryFactory, MagicMock] = \
            create_autospec(TemporaryDirectoryFactory)
        self.bucketfs_model_uploader_factory_mock: Union[BucketFSModelUploaderFactory, MagicMock] = \
            create_autospec(BucketFSModelUploaderFactory)
        self.bucketfs_model_uploader_mock: Union[BucketFSModelUploader, MagicMock] = \
            create_autospec(BucketFSModelUploader)
        # A one-element side_effect list: the factory hands out the uploader
        # mock exactly once, a second create() call would raise StopIteration.
        mock_cast(self.bucketfs_model_uploader_factory_mock.create).side_effect = [self.bucketfs_model_uploader_mock]

        self.token = "token"
        self.model_name = model_params.tiny_model
        self.model_path = Path("test_model_path")
        self.downloader = HuggingFaceHubBucketFSModelTransferSP(
            bucketfs_location=self.bucketfs_location_mock,
            model_path=self.model_path,
            model_name=self.model_name,
            local_model_save_path=local_model_save_path,
            token=self.token,
            temporary_directory_factory=self.temporary_directory_factory_mock,
            bucketfs_model_uploader_factory=self.bucketfs_model_uploader_factory_mock
        )

    def reset_mocks(self):
        """
        Reset the BucketFS location, temporary directory factory, model
        factory and uploader mocks. The uploader factory mock is not reset.
        """
        self.bucketfs_location_mock.reset_mock()
        self.temporary_directory_factory_mock.reset_mock()
        self.model_factory_mock.reset_mock()
        self.bucketfs_model_uploader_mock.reset_mock()


def test_init():
    """Constructing the transfer object must only call the two factories."""
    setup = TestSetup()

    expected_tmpdir_factory_calls = [call.create(), call.create().__enter__()]
    expected_uploader_factory_calls = [
        call.create(model_path=setup.model_path, bucketfs_location=setup.bucketfs_location_mock)
    ]

    # One assert per mock so a failure names the mock that deviated.
    assert setup.temporary_directory_factory_mock.mock_calls == expected_tmpdir_factory_calls
    assert setup.model_factory_mock.mock_calls == []
    assert setup.bucketfs_location_mock.mock_calls == []
    assert mock_cast(setup.bucketfs_model_uploader_factory_mock.create).mock_calls == \
        expected_uploader_factory_calls


def test_download_function_call():
    """The download must load the model with the temporary directory as cache and save it locally."""
    setup = TestSetup()
    setup.downloader.download_from_huggingface_hub_sp(model_factory=setup.model_factory_mock)

    # Calling the mocks again yields the same return values the transfer
    # object received in its constructor.
    expected_cache_dir = setup.temporary_directory_factory_mock.create().__enter__()
    expected_save_path = setup.downloader._local_model_save_path / setup.model_name
    expected_calls = [
        call.from_pretrained(setup.model_name, cache_dir=expected_cache_dir,
                             use_auth_token=setup.token),
        call.from_pretrained().save_pretrained(expected_save_path),
    ]

    assert setup.model_factory_mock.mock_calls == expected_calls


# todo add test for model already downloaded?

def test_download_with_model():
    """Run the download with the real AutoModel class and reload the saved model from disk."""
    with tempfile.TemporaryDirectory() as tmp_root:
        save_root = Path(tmp_root) / "downloaded_models"
        setup = TestSetup(local_model_save_path=save_root)
        real_model_factory: ModelFactoryProtocol = AutoModel
        setup.downloader.download_from_huggingface_hub_sp(model_factory=real_model_factory)
        # Loading from the local path succeeds only if save_pretrained wrote the model there.
        reloaded_model = AutoModel.from_pretrained(save_root / setup.model_name)
        assert reloaded_model
        setup.downloader.__del__()
        # TODO: delete model


0 comments on commit c4f22e0

Please sign in to comment.