From bbf3505336af1a9efe6fbadac18e67c223d11e95 Mon Sep 17 00:00:00 2001 From: MarleneKress79789 Date: Fri, 1 Dec 2023 15:04:06 +0100 Subject: [PATCH] moved last_created_pipeline back --- .../udfs/models/base_model_udf.py | 11 +++++++++-- exasol_transformers_extension/utils/load_model.py | 4 ++-- .../udfs/base_model_dummy_implementation.py | 1 + 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/exasol_transformers_extension/udfs/models/base_model_udf.py b/exasol_transformers_extension/udfs/models/base_model_udf.py index e757c309..b8706cf8 100644 --- a/exasol_transformers_extension/udfs/models/base_model_udf.py +++ b/exasol_transformers_extension/udfs/models/base_model_udf.py @@ -37,6 +37,7 @@ def __init__(self, self.device = None self.cache_dir = None self.model_loader = None + self.last_created_pipeline = None self.new_columns = [] def run(self, ctx): @@ -183,7 +184,14 @@ def check_cache(self, model_df: pd.DataFrame) -> None: if self.model_loader.last_loaded_model_key != current_model_key: self.set_cache_dir(model_name, bucketfs_conn, sub_dir) self.clear_device_memory() - self.model_loader.load_models(model_name, current_model_key, self.cache_dir, self.exa.get_connection(token_conn)) + if token_conn: + token_conn_obj = self.exa.get_connection(token_conn) + else: + token_conn_obj = None + self.last_created_pipeline = self.model_loader.load_models(model_name, + current_model_key, + self.cache_dir, + token_conn_obj) def set_cache_dir( self, model_name: str, bucketfs_conn_name: str, @@ -203,7 +211,6 @@ def set_cache_dir( self.cache_dir = bucketfs_operations.get_local_bucketfs_path( bucketfs_location=bucketfs_location, model_path=str(model_path)) - # todo move this also? def clear_device_memory(self): """ Delete models and free device memory diff --git a/exasol_transformers_extension/utils/load_model.py b/exasol_transformers_extension/utils/load_model.py index 226d2e5f..ac4033a2 100644 --- a/exasol_transformers_extension/utils/load_model.py +++ b/exasol_transformers_extension/utils/load_model.py @@ -15,7 +15,6 @@ def __init__(self, self.device = device self.last_loaded_model = None self.last_loaded_tokenizer = None - self.last_created_pipeline = None self.last_loaded_model_key = None def load_models(self, model_name: str, @@ -39,10 +38,11 @@ def load_models(self, model_name: str, model_name, cache_dir=cache_dir, use_auth_token=token) self.last_loaded_tokenizer = self.tokenizer.from_pretrained( model_name, cache_dir=cache_dir, use_auth_token=token) - self.last_created_pipeline = self.pipeline( + last_created_pipeline = self.pipeline( self.task_name, model=self.last_loaded_model, tokenizer=self.last_loaded_tokenizer, device=self.device, framework="pt") self.last_loaded_model_key = current_model_key + return last_created_pipeline diff --git a/tests/unit_tests/udfs/base_model_dummy_implementation.py b/tests/unit_tests/udfs/base_model_dummy_implementation.py index 5d103411..a16e3abe 100644 --- a/tests/unit_tests/udfs/base_model_dummy_implementation.py +++ b/tests/unit_tests/udfs/base_model_dummy_implementation.py @@ -35,6 +35,7 @@ def load_models(self, model_name: str, model_name, cache_dir=cache_dir, use_auth_token=token) self.last_loaded_tokenizer = self.tokenizer.from_pretrained( model_name, cache_dir=cache_dir, use_auth_token=token) + return None class DummyImplementationUDF(BaseModelUDF):