-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: ONNX FARMReader model conversion is broken #3211
Conversation
@@ -626,8 +626,8 @@ def convert_to_onnx( | |||
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained | |||
:return: None. | |||
""" | |||
language_model_class = LanguageModel.get_language_model_class(model_name) | |||
if language_model_class not in ["Bert", "Roberta", "XLMRoberta"]: | |||
language_model = get_language_model(model_name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You might use here this function:
def _get_model_type( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will try, thanks Bogdan
@bogdankostic when I move the onnx test to be executed after test_retrieval (mdr+memory) then both tests pass. |
Create new onnx tests, add first test
@bogdankostic with your suggested change this PR started to pass 🚀 |
I added dedicated onnx test file so we can start testing this part of the code base even better in the future |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just added some minor comments about imports that don't seem to be used.
test/nodes/test_retriever.py
Outdated
@@ -1,4 +1,5 @@ | |||
import logging | |||
import os |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this import is not used either.
test/nodes/test_retriever.py
Outdated
@@ -11,6 +12,7 @@ | |||
from elasticsearch import Elasticsearch | |||
|
|||
from haystack.document_stores import WeaviateDocumentStore | |||
from haystack.nodes import FARMReader |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here.
test/nodes/test_reader.py
Outdated
@@ -1,4 +1,7 @@ | |||
import math | |||
import os | |||
import tempfile |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this import is not used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Aha, that's because we don't check imports for tests, will fix - thanks @bogdankostic
Related Issues
Proposed Changes:
During a recent language and tokenization refactoring, we broke FarmReader ONNX conversion support. A simple ONNX conversion of a QA model fails; we don't have a test case for it.
How did you test it?
Added a unit test for FarmReader ONNX conversion
Notes for the reviewer
Is there a more straightforward way to figure out the model type (aside from loading it using
get_language_model
)? See PR changes.Checklist