Fix loading of tokenizers in DPR (#2755)
bogdankostic authored Jul 4, 2022
1 parent 2a8b129 commit dc48c44
Showing 3 changed files with 8 additions and 15 deletions.
haystack/modeling/model/tokenization.py (1 addition, 8 deletions)
@@ -308,14 +308,7 @@ def _infer_tokenizer_class_from_string(pretrained_model_name_or_path):
     elif "dpr-ctx_encoder" in pretrained_model_name_or_path.lower():
         tokenizer_class = "DPRContextEncoderTokenizer"
     else:
-        raise ValueError(
-            f"Could not infer tokenizer_class from model config or "
-            f"name '{pretrained_model_name_or_path}'. Set arg `tokenizer_class` "
-            f"in Tokenizer.load() to one of: AlbertTokenizer, XLMRobertaTokenizer, "
-            f"RobertaTokenizer, DistilBertTokenizer, BertTokenizer, XLNetTokenizer, "
-            f"CamembertTokenizer, ElectraTokenizer, DPRQuestionEncoderTokenizer,"
-            f"DPRContextEncoderTokenizer."
-        )
+        tokenizer_class = "AutoTokenizer"
 
     return tokenizer_class

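Note (not part of the diff): a minimal sketch of the new fallback behavior, assuming the other name-based branches mirror the "dpr-ctx_encoder" branch shown above; only the two DPR branches are suggested by this hunk, and the real function checks several more model-name patterns.

def _infer_tokenizer_class_from_string(pretrained_model_name_or_path: str) -> str:
    name = pretrained_model_name_or_path.lower()
    if "dpr-question_encoder" in name:
        # assumed branch, mirroring the "dpr-ctx_encoder" branch visible in the hunk
        tokenizer_class = "DPRQuestionEncoderTokenizer"
    elif "dpr-ctx_encoder" in name:
        tokenizer_class = "DPRContextEncoderTokenizer"
    else:
        # new behavior in this commit: fall back to AutoTokenizer instead of raising a ValueError
        tokenizer_class = "AutoTokenizer"
    return tokenizer_class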
haystack/nodes/retriever/dense.py (1 addition, 1 deletion)
@@ -152,7 +152,7 @@ def __init__(
         )
 
         self.infer_tokenizer_classes = infer_tokenizer_classes
-        tokenizers_default_classes = {"query": "AutoTokenizer", "passage": "AutoTokenizer"}
+        tokenizers_default_classes = {"query": "DPRQuestionEncoderTokenizer", "passage": "DPRContextEncoderTokenizer"}
         if self.infer_tokenizer_classes:
             tokenizers_default_classes["query"] = None  # type: ignore
             tokenizers_default_classes["passage"] = None  # type: ignore
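Note (not part of the diff): a hedged usage sketch of what the restored defaults mean in practice, assuming the standard facebook/dpr-* checkpoints and a Haystack 1.x constructor; exact argument names may vary between versions.

from haystack.nodes import DensePassageRetriever

# With infer_tokenizer_classes left at its default (False), the retriever loads
# the DPR-specific tokenizers again instead of generic AutoTokenizer instances.
retriever = DensePassageRetriever(
    document_store=None,  # any DocumentStore; None is assumed to be accepted here
    query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
    passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
)
# Expected per the updated tests below:
#   retriever.query_tokenizer   -> DPRQuestionEncoderTokenizerFast
#   retriever.passage_tokenizer -> DPRContextEncoderTokenizerFast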
test/nodes/test_retriever.py (6 additions, 6 deletions)
@@ -17,7 +17,7 @@
 from haystack.document_stores import MilvusDocumentStore
 from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetriever, TableTextRetriever
 from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever
-from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast, PreTrainedTokenizerFast
+from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast
 
 from ..conftest import SAMPLES_PATH
@@ -277,8 +277,8 @@ def sum_params(model):
     assert loaded_retriever.processor.max_seq_len_query == 64
 
     # Tokenizer
-    assert isinstance(loaded_retriever.passage_tokenizer, PreTrainedTokenizerFast)
-    assert isinstance(loaded_retriever.query_tokenizer, PreTrainedTokenizerFast)
+    assert isinstance(loaded_retriever.passage_tokenizer, DPRContextEncoderTokenizerFast)
+    assert isinstance(loaded_retriever.query_tokenizer, DPRQuestionEncoderTokenizerFast)
     assert loaded_retriever.passage_tokenizer.do_lower_case == True
     assert loaded_retriever.query_tokenizer.do_lower_case == True
     assert loaded_retriever.passage_tokenizer.vocab_size == 30522
@@ -320,9 +320,9 @@ def sum_params(model):
     assert loaded_retriever.processor.max_seq_len_query == 64
 
     # Tokenizer
-    assert isinstance(loaded_retriever.passage_tokenizer, PreTrainedTokenizerFast)
-    assert isinstance(loaded_retriever.table_tokenizer, PreTrainedTokenizerFast)
-    assert isinstance(loaded_retriever.query_tokenizer, PreTrainedTokenizerFast)
+    assert isinstance(loaded_retriever.passage_tokenizer, DPRContextEncoderTokenizerFast)
+    assert isinstance(loaded_retriever.table_tokenizer, DPRContextEncoderTokenizerFast)
+    assert isinstance(loaded_retriever.query_tokenizer, DPRQuestionEncoderTokenizerFast)
     assert loaded_retriever.passage_tokenizer.do_lower_case == True
     assert loaded_retriever.table_tokenizer.do_lower_case == True
     assert loaded_retriever.query_tokenizer.do_lower_case == True
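Note (not part of the diff): the removed PreTrainedTokenizerFast assertions were strictly weaker, since the DPR fast tokenizers are subclasses of PreTrainedTokenizerFast. A standalone illustration using only transformers (checkpoint name assumed, not taken from this commit):

from transformers import AutoTokenizer, DPRQuestionEncoderTokenizerFast, PreTrainedTokenizerFast

tok = AutoTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
assert isinstance(tok, DPRQuestionEncoderTokenizerFast)  # the new, stricter check
assert isinstance(tok, PreTrainedTokenizerFast)          # the old check also passes
assert tok.do_lower_case == True
assert tok.vocab_size == 30522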
