#144: Extracted load_models into separate class #161
Changes from 4 commits
360297c
987a1ef
bbf3505
aa871fe
3e751bc
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Transformers Extension 0.7.0, released T.B.D | ||
|
||
Code name: T.B.D | ||
|
||
|
||
## Summary | ||
|
||
T.B.D | ||
|
||
### Features | ||
|
||
### Bug Fixes | ||
|
||
### Refactorings | ||
|
||
- #144: Extracted base_udf.load_models into separate class | ||
|
||
|
||
### Documentation | ||
|
||
|
||
|
||
### Security |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
|
||
|
||
class LoadModel: | ||
def __init__(self, | ||
pipeline, | ||
base_model, | ||
tokenizer, | ||
task_name, | ||
device | ||
): | ||
Review comment: Adding type annotations and/or parameter descriptions would be helpful. |
||
self.pipeline = pipeline | ||
self.base_model = base_model | ||
self.tokenizer = tokenizer | ||
self.task_name = task_name | ||
self.device = device | ||
self.last_loaded_model = None | ||
self.last_loaded_tokenizer = None | ||
self.last_loaded_model_key = None | ||
|
||
def load_models(self, model_name: str, | ||
current_model_key, | ||
cache_dir, | ||
token_conn_obj) -> None: | ||
Review comment: Can we provide type annotations for all parameters?
Reply: Regarding the documentation strings: this whole class will be replaced by the one that will be created in #145, so I would rather spend the time on proper documentation for the new class and leave this one as is.
Reply: I would say the released code should be in a state of completeness. If the class morphs into something else, so will the documentation. Plus, it doesn't take long to get the docstrings sorted. |
||
""" | ||
Load model and tokenizer model from the cached location in bucketfs. | ||
Review comment: Maybe "Load the language model and tokenizer model" or "Load the model and tokenizer". |
||
If the desired model is not cached, this method will attempt to | ||
download the model to the read-only path /bucket/.. and cause an error. | ||
This error will be addressed in ticket | ||
https://github.com/exasol/transformers-extension/issues/43. | ||
|
||
:param model_name: The model name to be loaded | ||
""" | ||
Review comment: "The model name to be loaded" => "The name of the model to be loaded". What about descriptions for the other parameters? The description of the function doesn't reflect everything that the function is doing. For example, it doesn't say that it will create and return a pipeline. |
||
token = False | ||
if token_conn_obj: | ||
Review comment: Maybe …
Reply: New bug ticket added here: #163 |
||
token = token_conn_obj.password | ||
|
||
self.last_loaded_model = self.base_model.from_pretrained( | ||
model_name, cache_dir=cache_dir, use_auth_token=token) | ||
self.last_loaded_tokenizer = self.tokenizer.from_pretrained( | ||
model_name, cache_dir=cache_dir, use_auth_token=token) | ||
last_created_pipeline = self.pipeline( | ||
self.task_name, | ||
model=self.last_loaded_model, | ||
tokenizer=self.last_loaded_tokenizer, | ||
device=self.device, | ||
framework="pt") | ||
self.last_loaded_model_key = current_model_key | ||
return last_created_pipeline |
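
The review comments on this file ask for type annotations, parameter descriptions, and a docstring that mentions the returned pipeline. A minimal sketch of what that could look like follows. The behaviour is copied from the diff above; the concrete type choices (`str` for `current_model_key` and `cache_dir`, `Any` for `device` and the return value) are assumptions, not something the PR states.

```python
from typing import Any, Callable, Optional


class LoadModel:
    """Annotated variant of the loader from the diff; the types are assumptions."""

    def __init__(self,
                 pipeline: Callable[..., Any],
                 base_model: Any,
                 tokenizer: Any,
                 task_name: str,
                 device: Any) -> None:
        """
        :param pipeline: Callable that creates the pipeline
        :param base_model: Class providing from_pretrained for the model
        :param tokenizer: Class providing from_pretrained for the tokenizer
        :param task_name: Task name handed to the pipeline callable
        :param device: Device handed to the pipeline callable
        """
        self.pipeline = pipeline
        self.base_model = base_model
        self.tokenizer = tokenizer
        self.task_name = task_name
        self.device = device
        self.last_loaded_model = None
        self.last_loaded_tokenizer = None
        self.last_loaded_model_key = None

    def load_models(self,
                    model_name: str,
                    current_model_key: str,  # assumed to be a string key
                    cache_dir: str,          # assumed; could also be a path object
                    token_conn_obj: Optional[Any]) -> Any:
        """
        Load the model and tokenizer from the cached location in bucketfs,
        then create and return a pipeline built from both.

        :param model_name: The name of the model to be loaded
        :param current_model_key: Key stored as last_loaded_model_key
        :param cache_dir: Cache directory passed to from_pretrained
        :param token_conn_obj: Object whose password attribute holds the
            token, or None if no token is needed
        :return: The pipeline created for task_name
        """
        token = False
        if token_conn_obj:
            token = token_conn_obj.password

        self.last_loaded_model = self.base_model.from_pretrained(
            model_name, cache_dir=cache_dir, use_auth_token=token)
        self.last_loaded_tokenizer = self.tokenizer.from_pretrained(
            model_name, cache_dir=cache_dir, use_auth_token=token)
        last_created_pipeline = self.pipeline(
            self.task_name,
            model=self.last_loaded_model,
            tokenizer=self.last_loaded_tokenizer,
            device=self.device,
            framework="pt")
        self.last_loaded_model_key = current_model_key
        return last_created_pipeline
```

The one behavioural difference from the diff is the return annotation: the method returns the pipeline, so `-> None` would be misleading to callers and type checkers.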
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,38 @@ | |
BaseModelUDF | ||
|
||
|
||
class DummyModelLoader: | ||
""" | ||
Create a Dummy model loader that does not create a transformers Pipeline, | ||
since that fails with test data. | ||
""" | ||
def __init__(self, | ||
base_model, | ||
tokenizer, | ||
task_name, | ||
device | ||
): | ||
self.base_model = base_model | ||
self.tokenizer = tokenizer | ||
self.task_name = task_name | ||
self.device = device | ||
self.last_loaded_model = None | ||
self.last_loaded_tokenizer = None | ||
self.last_created_pipeline = None | ||
self.last_loaded_model_key = None | ||
|
||
def load_models(self, model_name: str, | ||
current_model_key, | ||
cache_dir, | ||
token_conn_obj) -> None: | ||
token = False | ||
self.last_loaded_model = self.base_model.from_pretrained( | ||
model_name, cache_dir=cache_dir, use_auth_token=token) | ||
self.last_loaded_tokenizer = self.tokenizer.from_pretrained( | ||
model_name, cache_dir=cache_dir, use_auth_token=token) | ||
return None | ||
|
||
|
||
Review comment: Why can't you use the actual ModelLoader with a pipeline-like function that does nothing and returns None?
Reply: Because the model loader is only initialized at run time of the UDF, since it gets input about the device that is only known then. I don't know of a way to change the function call at that point.
Reply: If all you need is to make a … But I am not sure you need it.
Reply: You are completely right, I misunderstood. Changed. |
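
The suggestion in this thread, reusing the real loader with a pipeline-like function that does nothing, can be sketched roughly as below. `no_op_pipeline` and `FakePretrained` are hypothetical names introduced for illustration, and the `LoadModel` class is a condensed copy of the one in this PR, repeated only so the sketch runs on its own.

```python
class LoadModel:
    """Condensed copy of the loader from this PR."""

    def __init__(self, pipeline, base_model, tokenizer, task_name, device):
        self.pipeline = pipeline
        self.base_model = base_model
        self.tokenizer = tokenizer
        self.task_name = task_name
        self.device = device
        self.last_loaded_model = None
        self.last_loaded_tokenizer = None
        self.last_loaded_model_key = None

    def load_models(self, model_name, current_model_key, cache_dir,
                    token_conn_obj):
        token = token_conn_obj.password if token_conn_obj else False
        self.last_loaded_model = self.base_model.from_pretrained(
            model_name, cache_dir=cache_dir, use_auth_token=token)
        self.last_loaded_tokenizer = self.tokenizer.from_pretrained(
            model_name, cache_dir=cache_dir, use_auth_token=token)
        created = self.pipeline(
            self.task_name,
            model=self.last_loaded_model,
            tokenizer=self.last_loaded_tokenizer,
            device=self.device,
            framework="pt")
        self.last_loaded_model_key = current_model_key
        return created


def no_op_pipeline(*args, **kwargs):
    """Hypothetical pipeline stand-in: accepts anything, builds nothing."""
    return None


class FakePretrained:
    """Hypothetical stand-in for a model or tokenizer class in tests."""

    @classmethod
    def from_pretrained(cls, model_name, cache_dir=None, use_auth_token=False):
        return f"{model_name}@{cache_dir}"


# The device is only known at run time, so the loader is built late;
# the no-op callable is simply passed in at that point.
loader = LoadModel(no_op_pipeline, FakePretrained, FakePretrained,
                   "dummy-task", "cpu")
created_pipeline = loader.load_models("dummy-model", "dummy-key",
                                      "/tmp/cache", None)
```

With this approach a test UDF would only need to override `create_model_loader` to pass the no-op callable, and no separate dummy loader class has to be kept in sync with the real one.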
||
class DummyImplementationUDF(BaseModelUDF): | ||
def __init__(self, | ||
exa, | ||
|
@@ -44,9 +76,9 @@ def create_dataframes_from_predictions( | |
results_df_list.append(result_df) | ||
return results_df_list | ||
|
||
def load_models(self, model_name: str, token_conn_name: str) -> None: | ||
token = False | ||
self.last_loaded_model = self.base_model.from_pretrained( | ||
model_name, cache_dir=self.cache_dir, use_auth_token=token) | ||
self.last_loaded_tokenizer = self.tokenizer.from_pretrained( | ||
model_name, cache_dir=self.cache_dir, use_auth_token=token) | ||
def create_model_loader(self): | ||
""" overwrite the model loader creation with dummy model loader creation""" | ||
self.model_loader = DummyModelLoader(self.base_model, | ||
self.tokenizer, | ||
self.task_name, | ||
self.device) |
Review comment: I would add a `clear` method to the ModelLoader setting these two variables to None. If `last_loaded_model` and `last_loaded_tokenizer` still need to be visible, I would make them properties. BTW, if the `last_created_pipeline` keeps references to these objects, they won't be garbage-collected, so you probably need to set it to None too.
Reply: Will also move the clear method from the base UDF.
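
A minimal sketch of the `clear` idea from this comment, assuming a reduced `ModelLoader` that only models the cached state. The `remember` helper, the `_Stub` class, and the decision to reset `last_loaded_model_key` as well are assumptions made for illustration; the PR does not specify them.

```python
import gc
import weakref


class ModelLoader:
    """Hypothetical reduced loader; only the caching state is modelled."""

    def __init__(self) -> None:
        self._last_loaded_model = None
        self._last_loaded_tokenizer = None
        self._last_created_pipeline = None
        self.last_loaded_model_key = None

    @property
    def last_loaded_model(self):
        """Read-only view of the cached model."""
        return self._last_loaded_model

    @property
    def last_loaded_tokenizer(self):
        """Read-only view of the cached tokenizer."""
        return self._last_loaded_tokenizer

    def remember(self, model, tokenizer, pipeline, model_key) -> None:
        """Stand-in for the bookkeeping done at the end of load_models."""
        self._last_loaded_model = model
        self._last_loaded_tokenizer = tokenizer
        self._last_created_pipeline = pipeline
        self.last_loaded_model_key = model_key

    def clear(self) -> None:
        """Drop every reference so the objects can be garbage-collected."""
        self._last_loaded_model = None
        self._last_loaded_tokenizer = None
        # The pipeline holds references to model and tokenizer as well.
        self._last_created_pipeline = None
        # Assumption: the key is reset too, so a stale key cannot suggest
        # that a model is still loaded.
        self.last_loaded_model_key = None


class _Stub:
    """Placeholder object; optionally holds references like a pipeline."""

    def __init__(self, *held) -> None:
        self.held = held


loader = ModelLoader()
model = _Stub()
loader.remember(model, _Stub(), _Stub(model), "some-key")

model_ref = weakref.ref(model)
del model
gc.collect()
alive_before_clear = model_ref() is not None  # loader still references it

loader.clear()
gc.collect()
alive_after_clear = model_ref() is not None
```

Exposing the model and tokenizer through properties keeps them readable from the UDF while making `clear` the only place that resets them.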