Skip to content

Commit

Permalink
moved last_created_pipeline back
Browse files Browse the repository at this point in the history
  • Loading branch information
MarleneKress79789 committed Dec 1, 2023
1 parent 987a1ef commit bbf3505
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
11 changes: 9 additions & 2 deletions exasol_transformers_extension/udfs/models/base_model_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self,
self.device = None
self.cache_dir = None
self.model_loader = None
self.last_created_pipeline = None
self.new_columns = []

def run(self, ctx):
Expand Down Expand Up @@ -183,7 +184,14 @@ def check_cache(self, model_df: pd.DataFrame) -> None:
if self.model_loader.last_loaded_model_key != current_model_key:
self.set_cache_dir(model_name, bucketfs_conn, sub_dir)
self.clear_device_memory()
self.model_loader.load_models(model_name, current_model_key, self.cache_dir, self.exa.get_connection(token_conn))
if token_conn:
token_conn_obj = self.exa.get_connection(token_conn)
else:
token_conn_obj = None
self.last_created_pipeline = self.model_loader.load_models(model_name,
current_model_key,
self.cache_dir,
token_conn_obj)

def set_cache_dir(
self, model_name: str, bucketfs_conn_name: str,
Expand All @@ -203,7 +211,6 @@ def set_cache_dir(
self.cache_dir = bucketfs_operations.get_local_bucketfs_path(
bucketfs_location=bucketfs_location, model_path=str(model_path))

# todo move this also?
def clear_device_memory(self):
"""
Delete models and free device memory
Expand Down
4 changes: 2 additions & 2 deletions exasol_transformers_extension/utils/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def __init__(self,
self.device = device
self.last_loaded_model = None
self.last_loaded_tokenizer = None
self.last_created_pipeline = None
self.last_loaded_model_key = None

def load_models(self, model_name: str,
Expand All @@ -39,10 +38,11 @@ def load_models(self, model_name: str,
model_name, cache_dir=cache_dir, use_auth_token=token)
self.last_loaded_tokenizer = self.tokenizer.from_pretrained(
model_name, cache_dir=cache_dir, use_auth_token=token)
self.last_created_pipeline = self.pipeline(
last_created_pipeline = self.pipeline(
self.task_name,
model=self.last_loaded_model,
tokenizer=self.last_loaded_tokenizer,
device=self.device,
framework="pt")
self.last_loaded_model_key = current_model_key
return last_created_pipeline
1 change: 1 addition & 0 deletions tests/unit_tests/udfs/base_model_dummy_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def load_models(self, model_name: str,
model_name, cache_dir=cache_dir, use_auth_token=token)
self.last_loaded_tokenizer = self.tokenizer.from_pretrained(
model_name, cache_dir=cache_dir, use_auth_token=token)
return None


class DummyImplementationUDF(BaseModelUDF):
Expand Down

0 comments on commit bbf3505

Please sign in to comment.