Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added class for loading locally saved model #179

Merged
merged 9 commits into from
Jan 30, 2024
1 change: 1 addition & 0 deletions doc/changes/changes_0.8.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ This release added the get_language_definition function to the LanguageContainer

- #174: Added get_language_definition to the language container deployer


### Bug Fixes

- n/a
Expand Down
2 changes: 1 addition & 1 deletion doc/changes/changes_0.9.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ T.B.D

### Features

- n/a
- #145: Added load function for loading local models

### Bug Fixes

Expand Down
66 changes: 66 additions & 0 deletions exasol_transformers_extension/utils/load_local_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import torch
import transformers.pipelines
from typing import Optional
from pathlib import Path
from exasol_transformers_extension.utils.model_factory_protocol import ModelFactoryProtocol


class LoadLocalModel:
"""
Class for loading locally saved models and tokenizers. Also stores information regarding the model and pipeline.

:pipeline_factory: a function to create a transformers pipeline
:task_name: name of the current task
:device: device to be used for pipeline creation
:base_model_factory: a ModelFactoryProtocol for creating the loaded model
:tokenizer_factory: a ModelFactoryProtocol for creating the loaded tokenizer
"""
def __init__(self,
pipeline_factory,
task_name: str,
device: str,
base_model_factory: ModelFactoryProtocol,
tokenizer_factory: ModelFactoryProtocol
):
self.pipeline_factory = pipeline_factory
tkilias marked this conversation as resolved.
Show resolved Hide resolved
tkilias marked this conversation as resolved.
Show resolved Hide resolved
self.task_name = task_name
self.device = device
self.base_model_factory = base_model_factory
self.tokenizer_factory = tokenizer_factory
tkilias marked this conversation as resolved.
Show resolved Hide resolved
self._loaded_model_key = None

@property
def loaded_model_key(self):
"""Get the current loaded_model_key."""
return self._loaded_model_key

def load_models(self,
model_path: Path,
current_model_key: str
) -> transformers.pipelines.Pipeline:
"""
Loads a locally saved model and tokenizer from "cache_dir / "pretrained" / model_name".
Returns new pipeline corresponding to the model and task.

:model_path: location of the saved model and tokenizer
:current_model_key: key of the model to be loaded
"""

loaded_model = self.base_model_factory.from_pretrained(str(model_path))
loaded_tokenizer = self.tokenizer_factory.from_pretrained(str(model_path))

last_created_pipeline = self.pipeline_factory(
self.task_name,
model=loaded_model,
tokenizer=loaded_tokenizer,
device=self.device,
framework="pt")
self._loaded_model_key = current_model_key
return last_created_pipeline

def clear_device_memory(self):
"""
Delete models and free device memory
"""
torch.cuda.empty_cache()

15 changes: 12 additions & 3 deletions exasol_transformers_extension/utils/model_factory_protocol.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Protocol, Union, runtime_checkable
from typing import Protocol, Union, runtime_checkable, Optional

import transformers

Expand All @@ -9,8 +9,17 @@ class ModelFactoryProtocol(Protocol):
"""
Protocol for better type hints.
"""
def from_pretrained(self, model_name: str, cache_dir: Path, use_auth_token: str) -> transformers.PreTrainedModel:
def from_pretrained(self, model_name: str, cache_dir: Optional[Path]=None, use_auth_token: Optional[str]=None) \
-> transformers.PreTrainedModel:
"""
Either downloads a model from Huggingface Hub(all parameters required),
or loads a locally saved model from file (only requires filepath)

:model_name: model name, or path to locally saved model files
:cache_dir: optional. Path where downloaded model should be cached
:use_auth_token: optional. token for Huggingface hub private models
"""
pass

def save_pretrained(self, save_directory: Union[str, Path]):
pass
pass
83 changes: 83 additions & 0 deletions tests/integration_tests/without_db/utils/test_load_local_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from pathlib import Path, PurePosixPath
from transformers import AutoModel, AutoTokenizer
import tarfile

from exasol_transformers_extension.utils.load_local_model import LoadLocalModel
from exasol_transformers_extension.utils.model_factory_protocol import ModelFactoryProtocol
from exasol_transformers_extension.utils.huggingface_hub_bucketfs_model_transfer_sp import \
HuggingFaceHubBucketFSModelTransferSPFactory
from exasol_bucketfs_utils_python.localfs_mock_bucketfs_location import \
LocalFSMockBucketFSLocation

from tests.utils.parameters import model_params

import tempfile


class TestSetup:
def __init__(self):

self.base_model_factory: ModelFactoryProtocol = AutoModel
self.tokenizer_factory: ModelFactoryProtocol = AutoTokenizer

self.token = "token"
model_params_ = model_params.tiny_model
self.model_name = model_params_

self.mock_current_model_key = None
mock_pipeline = lambda task_name, model, tokenizer, device, framework: None
tkilias marked this conversation as resolved.
Show resolved Hide resolved
self.loader = LoadLocalModel(
mock_pipeline,
task_name="test_task",
device=0,
base_model_factory=self.base_model_factory,
tokenizer_factory=self.tokenizer_factory
)


def download_model_with_huggingface_transfer(test_setup, mock_bucketfs_location):
model_transfer_factory = HuggingFaceHubBucketFSModelTransferSPFactory()
downloader = model_transfer_factory.create(bucketfs_location=mock_bucketfs_location,
model_name=test_setup.model_name,
model_path=Path("cached_files"),
token="")
downloader.download_from_huggingface_hub(test_setup.base_model_factory)
downloader.download_from_huggingface_hub(test_setup.tokenizer_factory)
bucketfs_model_path = downloader.upload_to_bucketfs()

with tarfile.open(mock_bucketfs_location.base_path / bucketfs_model_path) as tar:
tar.extractall(path=mock_bucketfs_location.base_path / bucketfs_model_path.parent)
return mock_bucketfs_location.base_path / bucketfs_model_path.parent


def test_load_local_model():
test_setup = TestSetup()

with tempfile.TemporaryDirectory() as dir:
dir_p = Path(dir)
model_save_path = dir_p / "pretrained" / test_setup.model_name
# download a model
model = AutoModel.from_pretrained(test_setup.model_name)
tokenizer = AutoTokenizer.from_pretrained(test_setup.model_name)
model.save_pretrained(model_save_path)
tokenizer.save_pretrained(model_save_path)
MarleneKress79789 marked this conversation as resolved.
Show resolved Hide resolved

test_setup.loader.load_models(current_model_key=test_setup.mock_current_model_key,
model_path=dir_p / "pretrained" / test_setup.model_name)


def test_load_local_model_with_huggingface_model_transfer():
test_setup = TestSetup()

with tempfile.TemporaryDirectory() as dire:
dir_p = Path(dire)

mock_bucketfs_location = LocalFSMockBucketFSLocation(
PurePosixPath(dir_p / "bucket"))

# download a model
downloaded_model_path = download_model_with_huggingface_transfer(
test_setup, mock_bucketfs_location)

test_setup.loader.load_models(current_model_key=test_setup.mock_current_model_key,
model_path=downloaded_model_path)
45 changes: 45 additions & 0 deletions tests/unit_tests/utils/test_load_local_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import tempfile
from pathlib import Path
from typing import Union
from unittest.mock import create_autospec, MagicMock, call

from exasol_transformers_extension.utils.model_factory_protocol import ModelFactoryProtocol
from exasol_transformers_extension.utils.load_local_model import LoadLocalModel

from tests.utils.parameters import model_params


class TestSetup:
def __init__(self):

self.model_factory_mock: Union[ModelFactoryProtocol, MagicMock] = create_autospec(ModelFactoryProtocol)
self.tokenizer_factory_mock: Union[ModelFactoryProtocol, MagicMock] = create_autospec(ModelFactoryProtocol)
self.token = "token"
model_params_ = model_params.tiny_model
self.model_name = model_params_

mock_pipeline = lambda task_name, model, tokenizer, device, framework: None
tkilias marked this conversation as resolved.
Show resolved Hide resolved
self.loader = LoadLocalModel(
mock_pipeline,
task_name="test_task",
device=0,
base_model_factory=self.model_factory_mock,
tokenizer_factory=self.tokenizer_factory_mock)


def test_load_function_call():
test_setup = TestSetup()
mock_current_model_key = "some_key"
with tempfile.TemporaryDirectory() as dir:
tkilias marked this conversation as resolved.
Show resolved Hide resolved
dir_p = Path(dir)
cache_dir = dir_p
model_save_path = Path(cache_dir) / "pretrained" / test_setup.model_name

test_setup.loader.load_models(current_model_key=mock_current_model_key,
model_path=model_save_path)

assert test_setup.model_factory_mock.mock_calls == [
call.from_pretrained(str(model_save_path))]
assert test_setup.tokenizer_factory_mock.mock_calls == [
call.from_pretrained(str(model_save_path))]

Loading