Skip to content

Commit

Permalink
[CodeBuild] changes from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
MarleneKress79789 committed Dec 19, 2023
1 parent cb4dc19 commit 8be8f4a
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 37 deletions.
7 changes: 3 additions & 4 deletions doc/changes/changes_0.7.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,15 @@ T.B.D

### Bug Fixes


### Bug Fixes
- n/a

### Refactorings

- #144: Extracted base_model_udf.load_models into separate class



### Documentation

- n/a

### Security

Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
import tempfile
from pathlib import Path
from typing import Protocol, runtime_checkable


from exasol_bucketfs_utils_python.bucketfs_location import BucketFSLocation

from exasol_transformers_extension.utils.model_factory_protocol import ModelFactoryProtocol
from exasol_transformers_extension.utils.bucketfs_model_uploader import BucketFSModelUploaderFactory
from exasol_transformers_extension.utils.temporary_directory_factory import TemporaryDirectoryFactory


@runtime_checkable
class ModelFactoryProtocol(Protocol):
    """
    Structural type for objects that provide a ``from_pretrained`` loader.

    Because the protocol is ``runtime_checkable``, an ``isinstance`` check only
    verifies that a ``from_pretrained`` attribute exists; the signature is not checked.
    """

    def from_pretrained(self, model_name: str, cache_dir: Path, use_auth_token: str):
        """
        Load the model identified by ``model_name``.

        :model_name: name of the model to load
        :cache_dir: path passed to the implementation as its cache directory
        :use_auth_token: token passed through to the implementation
        """
        ...


class HuggingFaceHubBucketFSModelTransfer:

def __init__(self,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,28 +1,19 @@
from pathlib import Path
from typing import Protocol, Union, runtime_checkable

import transformers
from exasol_bucketfs_utils_python.bucketfs_location import BucketFSLocation

from exasol_transformers_extension.utils.model_factory_protocol import ModelFactoryProtocol
from exasol_transformers_extension.utils.bucketfs_model_uploader import BucketFSModelUploaderFactory
from exasol_transformers_extension.utils.temporary_directory_factory import TemporaryDirectoryFactory


@runtime_checkable
class ModelFactoryProtocol(Protocol):
    """
    Protocol for better type hints.

    Describes objects offering ``from_pretrained`` and ``save_pretrained``.
    Because the protocol is ``runtime_checkable``, an ``isinstance`` check only
    verifies that both attributes exist; signatures are not checked.
    """

    def from_pretrained(self, model_name: str, cache_dir: Path, use_auth_token: str) -> transformers.PreTrainedModel:
        """
        Load the model identified by ``model_name``.

        :model_name: name of the model to load
        :cache_dir: path passed to the implementation as its cache directory
        :use_auth_token: token passed through to the implementation
        returns: the loaded model
        """
        ...

    # NOTE(review): callers appear to invoke save_pretrained on the model returned by
    # from_pretrained rather than on the factory itself -- confirm it belongs here.
    def save_pretrained(self, save_directory: Union[str, Path]):
        """
        Save to ``save_directory``.

        :save_directory: target directory, given as ``str`` or ``Path``
        """
        ...


class HuggingFaceHubBucketFSModelTransferSP:
"""
Class for downloading a model using the Huggingface Transformers API, and loading it into the BucketFS.
Class for downloading a model using the Huggingface Transformers API, and loading it into the BucketFS
using save_pretrained.
:bucketfs_location: BucketFSLocation the model should be loaded to
:model_name: Name of the model to be downloaded using Huggingface Transformers API
Expand Down Expand Up @@ -61,16 +52,16 @@ def download_from_huggingface_hub(self, model_factory: ModelFactoryProtocol):
Download a model from the HuggingFace Hub into a temporary directory and save it with save_pretrained
under <temporary directory>/pretrained/<model name>.
"""
model = model_factory.from_pretrained(self._model_name, cache_dir=self._tmpdir_name/"cache", use_auth_token=self._token)
model.save_pretrained(self._tmpdir_name/"pretrained"/self._model_name)
model = model_factory.from_pretrained(self._model_name, cache_dir=self._tmpdir_name / "cache", use_auth_token=self._token)
model.save_pretrained(self._tmpdir_name / "pretrained" / self._model_name)

def upload_to_bucketfs(self) -> Path:
"""
Upload the downloaded models into the BucketFS.
returns: Path of the uploaded model in the BucketFS
"""
return self._bucketfs_model_uploader.upload_directory(self._tmpdir_name/"pretrained"/self._model_name)
return self._bucketfs_model_uploader.upload_directory(self._tmpdir_name / "pretrained" / self._model_name)


class HuggingFaceHubBucketFSModelTransferSPFactory:
Expand Down
16 changes: 16 additions & 0 deletions exasol_transformers_extension/utils/model_factory_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from pathlib import Path
from typing import Protocol, Union, runtime_checkable

import transformers


@runtime_checkable
class ModelFactoryProtocol(Protocol):
    """
    Protocol for better type hints.

    Describes objects offering ``from_pretrained`` and ``save_pretrained``.
    Because the protocol is ``runtime_checkable``, an ``isinstance`` check only
    verifies that both attributes exist; signatures are not checked.
    """

    def from_pretrained(self, model_name: str, cache_dir: Path, use_auth_token: str) -> transformers.PreTrainedModel:
        """
        Load the model identified by ``model_name``.

        :model_name: name of the model to load
        :cache_dir: path passed to the implementation as its cache directory
        :use_auth_token: token passed through to the implementation
        returns: the loaded model
        """
        ...

    # NOTE(review): callers appear to invoke save_pretrained on the model returned by
    # from_pretrained rather than on the factory itself -- confirm it belongs here.
    def save_pretrained(self, save_directory: Union[str, Path]):
        """
        Save to ``save_directory``.

        :save_directory: target directory, given as ``str`` or ``Path``
        """
        ...
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "exasol-transformers-extension"
version = "0.6.0"
version = "0.7.0"
description = "An Exasol extension to use state-of-the-art pretrained machine learning models via the transformers api."

authors = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def __init__(self, bucketfs_location):

self.token = "token"
model_params_ = model_params.tiny_model
print(model_params_)
self.model_name = model_params_
self.model_path = Path("test_model_path")
self.downloader = HuggingFaceHubBucketFSModelTransferSP(
Expand All @@ -51,8 +50,8 @@ def test_download_with_model(bucketfs_location):
test_setup = TestSetup(bucketfs_location)
base_model_factory: ModelFactoryProtocol = AutoModel
test_setup.downloader.download_from_huggingface_hub(model_factory=base_model_factory)
assert AutoModel.from_pretrained(test_setup.downloader._tmpdir_name/"pretrained"/test_setup.model_name)
test_setup.downloader.__del__()
assert AutoModel.from_pretrained(test_setup.downloader._tmpdir_name / "pretrained" / test_setup.model_name)
del test_setup.downloader


def test_download_with_duplicate_model(bucketfs_location):
Expand All @@ -61,5 +60,5 @@ def test_download_with_duplicate_model(bucketfs_location):
base_model_factory: ModelFactoryProtocol = AutoModel
test_setup.downloader.download_from_huggingface_hub(model_factory=base_model_factory)
test_setup.downloader.download_from_huggingface_hub(model_factory=base_model_factory)
assert AutoModel.from_pretrained(test_setup.downloader._tmpdir_name/"pretrained"/test_setup.model_name)
test_setup.downloader.__del__()
assert AutoModel.from_pretrained(test_setup.downloader._tmpdir_name / "pretrained" / test_setup.model_name)
del test_setup.downloader
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def __init__(self):
create_autospec(BucketFSModelUploader)
mock_cast(self.bucketfs_model_uploader_factory_mock.create).side_effect = [self.bucketfs_model_uploader_mock]


self.token = "token"
model_params_ = model_params.tiny_model
print(model_params_)
self.model_name = model_params_
self.model_path = Path("test_model_path")
self.downloader = HuggingFaceHubBucketFSModelTransferSP(
Expand All @@ -45,6 +45,7 @@ def reset_mocks(self):
self.temporary_directory_factory_mock.reset_mock()
self.model_factory_mock.reset_mock()
self.bucketfs_model_uploader_mock.reset_mock()
self.bucketfs_model_uploader_factory_mock.reset_mock()


def test_init():
Expand All @@ -62,8 +63,8 @@ def test_init():
def test_download_function_call():
test_setup = TestSetup()
test_setup.downloader.download_from_huggingface_hub(model_factory=test_setup.model_factory_mock)
cache_dir = test_setup.temporary_directory_factory_mock.create().__enter__()
model_save_path = (test_setup.downloader._tmpdir_name/"pretrained"/test_setup.model_name)
cache_dir = mock_cast(test_setup.temporary_directory_factory_mock.create().__enter__).return_value
model_save_path = Path(cache_dir) / "pretrained" / test_setup.model_name
assert test_setup.model_factory_mock.mock_calls == [
call.from_pretrained(test_setup.model_name, cache_dir=Path(cache_dir)/"cache",
use_auth_token=test_setup.token),
Expand All @@ -74,6 +75,7 @@ def test_upload_function_call():
test_setup = TestSetup()
test_setup.downloader.download_from_huggingface_hub(model_factory=test_setup.model_factory_mock)
test_setup.reset_mocks()
model_save_path = (test_setup.downloader._tmpdir_name/"pretrained"/test_setup.model_name)
cache_dir = mock_cast(test_setup.temporary_directory_factory_mock.create().__enter__).return_value
model_save_path = Path(cache_dir) / "pretrained" / test_setup.model_name
test_setup.downloader.upload_to_bucketfs()
assert mock_cast(test_setup.bucketfs_model_uploader_mock.upload_directory).mock_calls == [call(model_save_path)]

0 comments on commit 8be8f4a

Please sign in to comment.