From 987a1ef83ab450d6eb4761f31137930c43dc76dd Mon Sep 17 00:00:00 2001
From: MarleneKress79789
Date: Fri, 1 Dec 2023 13:47:14 +0100
Subject: [PATCH] Moved model loader to separate class, injected into base
 model UDF and changed test setup accordingly

---
 .../udfs/models/base_model_udf.py           | 18 +++++---
 .../udfs/base_model_dummy_implementation.py | 43 ++++++++++++++++---
 tests/unit_tests/udfs/test_base_udf.py      |  5 ++-
 3 files changed, 52 insertions(+), 14 deletions(-)

diff --git a/exasol_transformers_extension/udfs/models/base_model_udf.py b/exasol_transformers_extension/udfs/models/base_model_udf.py
index 772e3b85..e757c309 100644
--- a/exasol_transformers_extension/udfs/models/base_model_udf.py
+++ b/exasol_transformers_extension/udfs/models/base_model_udf.py
@@ -21,7 +21,6 @@ class BaseModelUDF(ABC):
     - creates model pipeline through transformer api
     - manages the creation of predictions and the preparation of results.
     """
-    # todo does the token con change? (if yes need to be give at function call not class creation)
     def __init__(self,
                  exa,
                  batch_size,
@@ -43,11 +42,7 @@ def __init__(self,
     def run(self, ctx):
         device_id = ctx.get_dataframe(1).iloc[0]['device_id']
         self.device = device_management.get_torch_device(device_id)
-        self.model_loader = LoadModel(self.pipeline,
-                                      self.base_model,
-                                      self.tokenizer,
-                                      self.task_name,
-                                      self.device)
+        self.create_model_loader()
         ctx.reset()
 
         while True:
@@ -59,6 +54,17 @@ def run(self, ctx):
 
         self.clear_device_memory()
 
+    def create_model_loader(self):
+        """
+        Creates the model loader. Kept in a separate method so tests can replace it, since the pipeline
+        creation does not work with dummy data.
+        """
+        self.model_loader = LoadModel(self.pipeline,
+                                      self.base_model,
+                                      self.tokenizer,
+                                      self.task_name,
+                                      self.device)
+
     def get_predictions_from_batch(self, batch_df: pd.DataFrame) -> pd.DataFrame:
         """
         Perform separate predictions for each model in the dataframe.
diff --git a/tests/unit_tests/udfs/base_model_dummy_implementation.py b/tests/unit_tests/udfs/base_model_dummy_implementation.py
index 801549ab..5d103411 100644
--- a/tests/unit_tests/udfs/base_model_dummy_implementation.py
+++ b/tests/unit_tests/udfs/base_model_dummy_implementation.py
@@ -6,6 +6,37 @@
     BaseModelUDF
 
 
+class DummyModelLoader:
+    """
+    Dummy model loader that does not create a transformers Pipeline,
+    since that fails with test data.
+    """
+    def __init__(self,
+                 base_model,
+                 tokenizer,
+                 task_name,
+                 device
+                 ):
+        self.base_model = base_model
+        self.tokenizer = tokenizer
+        self.task_name = task_name
+        self.device = device
+        self.last_loaded_model = None
+        self.last_loaded_tokenizer = None
+        self.last_created_pipeline = None
+        self.last_loaded_model_key = None
+
+    def load_models(self, model_name: str,
+                    current_model_key,
+                    cache_dir,
+                    token_conn_obj) -> None:
+        token = False
+        self.last_loaded_model = self.base_model.from_pretrained(
+            model_name, cache_dir=cache_dir, use_auth_token=token)
+        self.last_loaded_tokenizer = self.tokenizer.from_pretrained(
+            model_name, cache_dir=cache_dir, use_auth_token=token)
+
+
 class DummyImplementationUDF(BaseModelUDF):
     def __init__(self,
                  exa,
@@ -44,9 +75,9 @@ def create_dataframes_from_predictions(
             results_df_list.append(result_df)
         return results_df_list
 
-    def load_models(self, model_name: str, token_conn_name: str) -> None:
-        token = False
-        self.last_loaded_model = self.base_model.from_pretrained(
-            model_name, cache_dir=self.cache_dir, use_auth_token=token)
-        self.last_loaded_tokenizer = self.tokenizer.from_pretrained(
-            model_name, cache_dir=self.cache_dir, use_auth_token=token)
\ No newline at end of file
+    def create_model_loader(self):
+        """ overwrite the model loader creation with dummy model loader creation"""
+        self.model_loader = DummyModelLoader(self.base_model,
+                                             self.tokenizer,
+                                             self.task_name,
+                                             self.device)
diff --git a/tests/unit_tests/udfs/test_base_udf.py b/tests/unit_tests/udfs/test_base_udf.py
index 6f8328fc..6e8dc4c1 100644
--- a/tests/unit_tests/udfs/test_base_udf.py
+++ b/tests/unit_tests/udfs/test_base_udf.py
@@ -10,6 +10,7 @@
 from tests.unit_tests.utils_for_udf_tests import create_mock_exa_environment, create_mock_udf_context
 from tests.unit_tests.udfs.base_model_dummy_implementation import DummyImplementationUDF
 from exasol_transformers_extension.utils.huggingface_hub_bucketfs_model_transfer import ModelFactoryProtocol
+from exasol_transformers_extension.utils.load_model import LoadModel
 from tests.utils.mock_cast import mock_cast
 import re
 
@@ -80,8 +81,8 @@ def setup_tests_and_run(bucketfs_conn_name, bucketfs_conn, sub_dir, model_name):
                                               None)
     mock_ctx = create_mock_udf_context(input_data, mock_meta)
     udf = DummyImplementationUDF(exa=mock_exa,
-                                  base_model=mock_base_model_factory,
-                                  tokenizer=mock_tokenizer_factory)
+                                 base_model=mock_base_model_factory,
+                                 tokenizer=mock_tokenizer_factory)
     udf.run(mock_ctx)
     res = mock_ctx.output
     return res, mock_meta
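
The patch above introduces a factory-method seam: model-loader construction moves out of run() into create_model_loader(), which the test-only DummyImplementationUDF overrides so no transformers pipeline is built against dummy data. Below is a minimal, self-contained Python sketch of that pattern under invented names (Worker, RealLoader, FakeLoader, DummyWorker are illustrative only and are not classes of the transformers extension).

# Illustrative sketch of the factory-method seam used in this patch.
# Only the create_model_loader() override pattern mirrors the patch;
# all class names here are made up for the example.

class RealLoader:
    def load_models(self, model_name: str) -> str:
        # Stand-in for the expensive real model/pipeline creation.
        return f"real model for {model_name}"


class FakeLoader:
    def load_models(self, model_name: str) -> str:
        # Cheap test double; no pipeline is created.
        return f"dummy model for {model_name}"


class Worker:
    def __init__(self):
        self.model_loader = None

    def run(self, model_name: str) -> str:
        # The loader is created via a dedicated method instead of inline,
        # so subclasses (e.g. test doubles) can swap it out.
        self.create_model_loader()
        return self.model_loader.load_models(model_name)

    def create_model_loader(self):
        self.model_loader = RealLoader()


class DummyWorker(Worker):
    def create_model_loader(self):
        # Tests override only the factory method; run() stays untouched.
        self.model_loader = FakeLoader()


if __name__ == "__main__":
    print(Worker().run("bert-base"))       # production path
    print(DummyWorker().run("bert-base"))  # unit-test path

Overriding only the factory method keeps the production control flow in run() identical in tests while avoiding the pipeline construction that fails with dummy data.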