diff --git a/exasol_transformers_extension/udfs/models/base_model_udf.py b/exasol_transformers_extension/udfs/models/base_model_udf.py
index 675ca10d..772e3b85 100644
--- a/exasol_transformers_extension/udfs/models/base_model_udf.py
+++ b/exasol_transformers_extension/udfs/models/base_model_udf.py
@@ -7,6 +7,7 @@
 from exasol_transformers_extension.deployment import constants
 from exasol_transformers_extension.utils import device_management, \
     bucketfs_operations, dataframe_operations
+from exasol_transformers_extension.utils.load_model import LoadModel
 
 
 class BaseModelUDF(ABC):
@@ -20,7 +21,7 @@ class BaseModelUDF(ABC):
         - creates model pipeline through transformer api
         - manages the creation of predictions and the preparation of results.
     """
-
+# todo: does the token conn change? (if yes, it needs to be given at the function call, not at class creation)
     def __init__(self,
                  exa,
                  batch_size,
@@ -36,15 +37,17 @@ def __init__(self,
         self.task_name = task_name
         self.device = None
         self.cache_dir = None
-        self.last_loaded_model_key = None
-        self.last_loaded_model = None
-        self.last_loaded_tokenizer = None
-        self.last_created_pipeline = None
+        self.model_loader = None
         self.new_columns = []
 
     def run(self, ctx):
         device_id = ctx.get_dataframe(1).iloc[0]['device_id']
         self.device = device_management.get_torch_device(device_id)
+        self.model_loader = LoadModel(self.pipeline,
+                                      self.base_model,
+                                      self.tokenizer,
+                                      self.task_name,
+                                      self.device)
         ctx.reset()
         while True:
@@ -171,11 +174,10 @@ def check_cache(self, model_df: pd.DataFrame) -> None:
         token_conn = model_df["token_conn"].iloc[0]
 
         current_model_key = (bucketfs_conn, sub_dir, model_name, token_conn)
-        if self.last_loaded_model_key != current_model_key:
+        if self.model_loader.last_loaded_model_key != current_model_key:
             self.set_cache_dir(model_name, bucketfs_conn, sub_dir)
             self.clear_device_memory()
-            self.load_models(model_name, token_conn)
-            self.last_loaded_model_key = current_model_key
+            self.model_loader.load_models(model_name, current_model_key, self.cache_dir, self.exa.get_connection(token_conn))
 
     def set_cache_dir(
             self, model_name: str, bucketfs_conn_name: str,
@@ -195,40 +197,15 @@ def set_cache_dir(
         self.cache_dir = bucketfs_operations.get_local_bucketfs_path(
             bucketfs_location=bucketfs_location, model_path=str(model_path))
 
+    # todo: move this also?
     def clear_device_memory(self):
         """
         Delete models and free device memory
         """
-        self.last_loaded_model = None
-        self.last_loaded_tokenizer = None
+        self.model_loader.last_loaded_model = None
+        self.model_loader.last_loaded_tokenizer = None
         torch.cuda.empty_cache()
 
-    def load_models(self, model_name: str, token_conn_name: str) -> None:
-        """
-        Load model and tokenizer model from the cached location in bucketfs.
-        If the desired model is not cached, this method will attempt to
-        download the model to the read-only path /bucket/.. and cause an error.
-        This error will be addressed in ticket
-        https://github.com/exasol/transformers-extension/issues/43.
-
-        :param model_name: The model name to be loaded
-        """
-        token = False
-        if token_conn_name:
-            token_conn_obj = self.exa.get_connection(token_conn_name)
-            token = token_conn_obj.password
-
-        self.last_loaded_model = self.base_model.from_pretrained(
-            model_name, cache_dir=self.cache_dir, use_auth_token=token)
-        self.last_loaded_tokenizer = self.tokenizer.from_pretrained(
-            model_name, cache_dir=self.cache_dir, use_auth_token=token)
-        self.last_created_pipeline = self.pipeline(
-            self.task_name,
-            model=self.last_loaded_model,
-            tokenizer=self.last_loaded_tokenizer,
-            device=self.device,
-            framework="pt")
-
     def get_prediction(self, model_df: pd.DataFrame) -> pd.DataFrame:
         """
         Perform prediction of the given model and preparation of the prediction
diff --git a/exasol_transformers_extension/utils/load_model.py b/exasol_transformers_extension/utils/load_model.py
new file mode 100644
index 00000000..226d2e5f
--- /dev/null
+++ b/exasol_transformers_extension/utils/load_model.py
@@ -0,0 +1,48 @@
+
+
+class LoadModel:
+    def __init__(self,
+                 pipeline,
+                 base_model,
+                 tokenizer,
+                 task_name,
+                 device
+                 ):
+        self.pipeline = pipeline
+        self.base_model = base_model
+        self.tokenizer = tokenizer
+        self.task_name = task_name
+        self.device = device
+        self.last_loaded_model = None
+        self.last_loaded_tokenizer = None
+        self.last_created_pipeline = None
+        self.last_loaded_model_key = None
+
+    def load_models(self, model_name: str,
+                    current_model_key,
+                    cache_dir,
+                    token_conn_obj) -> None:
+        """
+        Load model and tokenizer model from the cached location in bucketfs.
+        If the desired model is not cached, this method will attempt to
+        download the model to the read-only path /bucket/.. and cause an error.
+        This error will be addressed in ticket
+        https://github.com/exasol/transformers-extension/issues/43.
+
+        :param model_name: The model name to be loaded
+        """
+        token = False
+        if token_conn_obj:
+            token = token_conn_obj.password
+
+        self.last_loaded_model = self.base_model.from_pretrained(
+            model_name, cache_dir=cache_dir, use_auth_token=token)
+        self.last_loaded_tokenizer = self.tokenizer.from_pretrained(
+            model_name, cache_dir=cache_dir, use_auth_token=token)
+        self.last_created_pipeline = self.pipeline(
+            self.task_name,
+            model=self.last_loaded_model,
+            tokenizer=self.last_loaded_tokenizer,
+            device=self.device,
+            framework="pt")
+        self.last_loaded_model_key = current_model_key
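
For illustration only, a minimal sketch of how the extracted LoadModel class is driven, mirroring the arguments the UDF now passes in check_cache(). It is not part of the diff: the MagicMock objects stand in for the transformers pipeline/model/tokenizer classes, the SimpleNamespace stands in for the Exasol connection object returned by exa.get_connection(token_conn), and names such as "some-model" and the cache path are invented for the example.

    from types import SimpleNamespace
    from unittest.mock import MagicMock

    from exasol_transformers_extension.utils.load_model import LoadModel

    # Stand-ins for transformers.pipeline / model class / tokenizer class so
    # the sketch runs without downloading anything.
    pipeline = MagicMock()
    base_model = MagicMock()
    tokenizer = MagicMock()

    loader = LoadModel(pipeline, base_model, tokenizer,
                       task_name="text-classification", device="cpu")

    # Mimics the connection object the UDF obtains via exa.get_connection(...)
    token_conn_obj = SimpleNamespace(password="hf_dummy_token")
    current_model_key = ("bfs_conn", "sub_dir", "some-model", "token_conn")

    loader.load_models("some-model", current_model_key,
                       cache_dir="/tmp/bucketfs_cache",
                       token_conn_obj=token_conn_obj)

    # The loading state that used to live on the UDF is now held by the loader.
    assert loader.last_loaded_model_key == current_model_key
    base_model.from_pretrained.assert_called_once_with(
        "some-model", cache_dir="/tmp/bucketfs_cache",
        use_auth_token="hf_dummy_token")

The UDF only calls load_models() when the (bucketfs_conn, sub_dir, model_name, token_conn) key differs from the loader's last_loaded_model_key, which is why the loader records the key after a successful load.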