Skip to content

Commit

Permalink
Merge branch 'main' into refactoring/159-prepare-deployment-migration
Browse files Browse the repository at this point in the history
  • Loading branch information
tkilias authored Dec 11, 2023
2 parents 4fbb826 + 4629152 commit 914f862
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 74 deletions.
24 changes: 24 additions & 0 deletions doc/changes/changes_0.7.0.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Transformers Extension 0.7.0, released T.B.D

Code name: T.B.D


## Summary

T.B.D

### Features

### Bug Fixes

### Refactorings

- #144: Extracted base_model_udf.load_models into separate class


### Documentation



### Security
- #144: Updated Cryptography to version 41.0.7
66 changes: 24 additions & 42 deletions exasol_transformers_extension/udfs/models/base_model_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from exasol_transformers_extension.deployment import constants
from exasol_transformers_extension.utils import device_management, \
bucketfs_operations, dataframe_operations
from exasol_transformers_extension.utils.load_model import LoadModel


class BaseModelUDF(ABC):
Expand All @@ -20,7 +21,6 @@ class BaseModelUDF(ABC):
- creates model pipeline through transformer api
- manages the creation of predictions and the preparation of results.
"""

def __init__(self,
exa,
batch_size,
Expand All @@ -36,15 +36,14 @@ def __init__(self,
self.task_name = task_name
self.device = None
self.cache_dir = None
self.last_loaded_model_key = None
self.last_loaded_model = None
self.last_loaded_tokenizer = None
self.model_loader = None
self.last_created_pipeline = None
self.new_columns = []

def run(self, ctx):
device_id = ctx.get_dataframe(1).iloc[0]['device_id']
self.device = device_management.get_torch_device(device_id)
self.create_model_loader()
ctx.reset()

while True:
Expand All @@ -54,7 +53,17 @@ def run(self, ctx):
predictions_df = self.get_predictions_from_batch(batch_df)
ctx.emit(predictions_df)

self.clear_device_memory()
self.model_loader.clear_device_memory()

def create_model_loader(self):
"""
Creates the model_loader.
"""
self.model_loader = LoadModel(self.pipeline,
self.base_model,
self.tokenizer,
self.task_name,
self.device)

def get_predictions_from_batch(self, batch_df: pd.DataFrame) -> pd.DataFrame:
"""
Expand Down Expand Up @@ -171,11 +180,17 @@ def check_cache(self, model_df: pd.DataFrame) -> None:
token_conn = model_df["token_conn"].iloc[0]

current_model_key = (bucketfs_conn, sub_dir, model_name, token_conn)
if self.last_loaded_model_key != current_model_key:
if self.model_loader.last_loaded_model_key != current_model_key:
self.set_cache_dir(model_name, bucketfs_conn, sub_dir)
self.clear_device_memory()
self.load_models(model_name, token_conn)
self.last_loaded_model_key = current_model_key
self.model_loader.clear_device_memory()
if token_conn:
token_conn_obj = self.exa.get_connection(token_conn)
else:
token_conn_obj = None
self.last_created_pipeline = self.model_loader.load_models(model_name,
current_model_key,
self.cache_dir,
token_conn_obj)

def set_cache_dir(
self, model_name: str, bucketfs_conn_name: str,
Expand All @@ -195,39 +210,6 @@ def set_cache_dir(
self.cache_dir = bucketfs_operations.get_local_bucketfs_path(
bucketfs_location=bucketfs_location, model_path=str(model_path))

def clear_device_memory(self):
"""
Delete models and free device memory
"""
self.last_loaded_model = None
self.last_loaded_tokenizer = None
torch.cuda.empty_cache()

def load_models(self, model_name: str, token_conn_name: str) -> None:
"""
Load model and tokenizer model from the cached location in bucketfs.
If the desired model is not cached, this method will attempt to
download the model to the read-only path /bucket/.. and cause an error.
This error will be addressed in ticket
https://github.com/exasol/transformers-extension/issues/43.
:param model_name: The model name to be loaded
"""
token = False
if token_conn_name:
token_conn_obj = self.exa.get_connection(token_conn_name)
token = token_conn_obj.password

self.last_loaded_model = self.base_model.from_pretrained(
model_name, cache_dir=self.cache_dir, use_auth_token=token)
self.last_loaded_tokenizer = self.tokenizer.from_pretrained(
model_name, cache_dir=self.cache_dir, use_auth_token=token)
self.last_created_pipeline = self.pipeline(
self.task_name,
model=self.last_loaded_model,
tokenizer=self.last_loaded_tokenizer,
device=self.device,
framework="pt")

def get_prediction(self, model_df: pd.DataFrame) -> pd.DataFrame:
"""
Expand Down
56 changes: 56 additions & 0 deletions exasol_transformers_extension/utils/load_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import torch

class LoadModel:
def __init__(self,
pipeline,
base_model,
tokenizer,
task_name,
device
):
self.pipeline = pipeline
self.base_model = base_model
self.tokenizer = tokenizer
self.task_name = task_name
self.device = device
self.last_loaded_model = None
self.last_loaded_tokenizer = None
self.last_loaded_model_key = None

def load_models(self, model_name: str,
current_model_key,
cache_dir,
token_conn_obj) -> None:
"""
Load model and tokenizer model from the cached location in bucketfs.
If the desired model is not cached, this method will attempt to
download the model to the read-only path /bucket/.. and cause an error.
This error will be addressed in ticket
https://github.com/exasol/transformers-extension/issues/43.
:param model_name: The model name to be loaded
"""
token = False
if token_conn_obj:
token = token_conn_obj.password

self.last_loaded_model = self.base_model.from_pretrained(
model_name, cache_dir=cache_dir, use_auth_token=token)
self.last_loaded_tokenizer = self.tokenizer.from_pretrained(
model_name, cache_dir=cache_dir, use_auth_token=token)
last_created_pipeline = self.pipeline(
self.task_name,
model=self.last_loaded_model,
tokenizer=self.last_loaded_tokenizer,
device=self.device,
framework="pt")
self.last_loaded_model_key = current_model_key
return last_created_pipeline

def clear_device_memory(self):
"""
Delete models and free device memory
"""
self.last_loaded_model = None
self.last_loaded_tokenizer = None
torch.cuda.empty_cache()
48 changes: 24 additions & 24 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 0 additions & 6 deletions tests/unit_tests/udfs/base_model_dummy_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,3 @@ def create_dataframes_from_predictions(
results_df_list.append(result_df)
return results_df_list

def load_models(self, model_name: str, token_conn_name: str) -> None:
token = False
self.last_loaded_model = self.base_model.from_pretrained(
model_name, cache_dir=self.cache_dir, use_auth_token=token)
self.last_loaded_tokenizer = self.tokenizer.from_pretrained(
model_name, cache_dir=self.cache_dir, use_auth_token=token)
8 changes: 6 additions & 2 deletions tests/unit_tests/udfs/test_base_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from tests.unit_tests.utils_for_udf_tests import create_mock_exa_environment, create_mock_udf_context
from tests.unit_tests.udfs.base_model_dummy_implementation import DummyImplementationUDF
from exasol_transformers_extension.utils.huggingface_hub_bucketfs_model_transfer import ModelFactoryProtocol
from exasol_transformers_extension.utils.load_model import LoadModel
from tests.utils.mock_cast import mock_cast
import re

Expand Down Expand Up @@ -78,10 +79,13 @@ def setup_tests_and_run(bucketfs_conn_name, bucketfs_conn, sub_dir, model_name):
mock_meta,
'',
None)

mock_pipeline = lambda task_name, model, tokenizer, device, framework: None
mock_ctx = create_mock_udf_context(input_data, mock_meta)
udf = DummyImplementationUDF(exa=mock_exa,
base_model=mock_base_model_factory,
tokenizer=mock_tokenizer_factory)
base_model=mock_base_model_factory,
tokenizer=mock_tokenizer_factory,
pipeline=mock_pipeline)
udf.run(mock_ctx)
res = mock_ctx.output
return res, mock_meta
Expand Down

0 comments on commit 914f862

Please sign in to comment.