Skip to content

Commit

Permalink
Moved model loader to separate class, injected into basemodel udf and…
Browse files Browse the repository at this point in the history
… changed test setup accordingly
  • Loading branch information
MarleneKress79789 committed Dec 1, 2023
1 parent 360297c commit 987a1ef
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 14 deletions.
18 changes: 12 additions & 6 deletions exasol_transformers_extension/udfs/models/base_model_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ class BaseModelUDF(ABC):
- creates model pipeline through transformer api
- manages the creation of predictions and the preparation of results.
"""
# TODO: does the token connection change? (If yes, it needs to be given at function call, not at class creation.)
def __init__(self,
exa,
batch_size,
Expand All @@ -43,11 +42,7 @@ def __init__(self,
def run(self, ctx):
device_id = ctx.get_dataframe(1).iloc[0]['device_id']
self.device = device_management.get_torch_device(device_id)
self.model_loader = LoadModel(self.pipeline,
self.base_model,
self.tokenizer,
self.task_name,
self.device)
self.create_model_loader()
ctx.reset()

while True:
Expand All @@ -59,6 +54,17 @@ def run(self, ctx):

self.clear_device_memory()

def create_model_loader(self):
    """
    Create the model loader and store it in ``self.model_loader``.

    This lives in its own method so that it can be replaced in tests, since
    the pipeline creation does not work with dummy data.

    The loader is built from ``self.pipeline``, ``self.base_model``,
    ``self.tokenizer``, ``self.task_name`` and ``self.device``, so
    ``self.device`` has to be set before this method is called.
    """
    self.model_loader = LoadModel(self.pipeline,
                                  self.base_model,
                                  self.tokenizer,
                                  self.task_name,
                                  self.device)

def get_predictions_from_batch(self, batch_df: pd.DataFrame) -> pd.DataFrame:
"""
Perform separate predictions for each model in the dataframe.
Expand Down
43 changes: 37 additions & 6 deletions tests/unit_tests/udfs/base_model_dummy_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,37 @@
BaseModelUDF


class DummyModelLoader:
    """
    Dummy model loader for unit tests.

    It does not create a transformers Pipeline, since that fails with test
    data. It only calls ``from_pretrained`` on the injected model and
    tokenizer factories and keeps the returned objects.
    """

    def __init__(self,
                 base_model,
                 tokenizer,
                 task_name,
                 device
                 ):
        """
        Store the injected collaborators and reset the "last loaded" state.

        :param base_model: object whose ``from_pretrained`` yields the model
        :param tokenizer: object whose ``from_pretrained`` yields the tokenizer
        :param task_name: stored as given; not used by this dummy
        :param device: stored as given; not used by this dummy
        """
        self.base_model = base_model
        self.tokenizer = tokenizer
        self.task_name = task_name
        self.device = device
        self.last_loaded_model = None
        self.last_loaded_tokenizer = None
        # The two attributes below are initialised but never updated by this
        # dummy: no pipeline is built and no model key is recorded.
        self.last_created_pipeline = None
        self.last_loaded_model_key = None

    def load_models(self, model_name: str,
                    current_model_key,
                    cache_dir,
                    token_conn_obj) -> None:
        """
        Load model and tokenizer via ``from_pretrained`` and remember them.

        :param model_name: name passed on to both ``from_pretrained`` calls
        :param current_model_key: accepted but ignored by this dummy
        :param cache_dir: passed on as ``cache_dir`` to both calls
        :param token_conn_obj: accepted but ignored; no token is ever used
        """
        # NOTE(review): current_model_key is not written to
        # self.last_loaded_model_key here - confirm this is intended.
        # The dummy always loads without authentication.
        token = False
        self.last_loaded_model = self.base_model.from_pretrained(
            model_name, cache_dir=cache_dir, use_auth_token=token)
        self.last_loaded_tokenizer = self.tokenizer.from_pretrained(
            model_name, cache_dir=cache_dir, use_auth_token=token)


class DummyImplementationUDF(BaseModelUDF):
def __init__(self,
exa,
Expand Down Expand Up @@ -44,9 +75,9 @@ def create_dataframes_from_predictions(
results_df_list.append(result_df)
return results_df_list

def load_models(self, model_name: str, token_conn_name: str) -> None:
    """
    Load model and tokenizer via ``from_pretrained`` and keep them on ``self``.

    Results are stored in ``self.last_loaded_model`` and
    ``self.last_loaded_tokenizer``; ``self.cache_dir`` is used as cache dir.

    :param model_name: name passed on to both ``from_pretrained`` calls
    :param token_conn_name: accepted but ignored; no token is ever used
    """
    # Always load without authentication, whatever token_conn_name is.
    token = False
    self.last_loaded_model = self.base_model.from_pretrained(
        model_name, cache_dir=self.cache_dir, use_auth_token=token)
    self.last_loaded_tokenizer = self.tokenizer.from_pretrained(
        model_name, cache_dir=self.cache_dir, use_auth_token=token)
def create_model_loader(self):
    """
    Override the model loader creation with a dummy model loader.

    Sets ``self.model_loader`` to a ``DummyModelLoader`` built from
    ``self.base_model``, ``self.tokenizer``, ``self.task_name`` and
    ``self.device``.
    """
    self.model_loader = DummyModelLoader(self.base_model,
                                         self.tokenizer,
                                         self.task_name,
                                         self.device)
5 changes: 3 additions & 2 deletions tests/unit_tests/udfs/test_base_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from tests.unit_tests.utils_for_udf_tests import create_mock_exa_environment, create_mock_udf_context
from tests.unit_tests.udfs.base_model_dummy_implementation import DummyImplementationUDF
from exasol_transformers_extension.utils.huggingface_hub_bucketfs_model_transfer import ModelFactoryProtocol
from exasol_transformers_extension.utils.load_model import LoadModel
from tests.utils.mock_cast import mock_cast
import re

Expand Down Expand Up @@ -80,8 +81,8 @@ def setup_tests_and_run(bucketfs_conn_name, bucketfs_conn, sub_dir, model_name):
None)
mock_ctx = create_mock_udf_context(input_data, mock_meta)
udf = DummyImplementationUDF(exa=mock_exa,
base_model=mock_base_model_factory,
tokenizer=mock_tokenizer_factory)
base_model=mock_base_model_factory,
tokenizer=mock_tokenizer_factory)
udf.run(mock_ctx)
res = mock_ctx.output
return res, mock_meta
Expand Down

0 comments on commit 987a1ef

Please sign in to comment.