Skip to content

Commit

Permalink
Validate max_seq_length in SquadProcessor (#2740)
Browse files Browse the repository at this point in the history
* added max_seq_len validation in SquadProcessor

* fixed string formatting

* added tests for invalid max_seq_len

* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
francescocastelli and github-actions[bot] authored Jul 4, 2022
1 parent ffb7e4e commit 31dcd55
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 4 deletions.
9 changes: 5 additions & 4 deletions haystack/modeling/data_handler/data_silo.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,10 +475,11 @@ def _calculate_statistics(self):
logger.info("Proportion clipped: {}".format(clipped))
if clipped > 0.5:
logger.info(
"[Haystack Tip] {}% of your samples got cut down to {} tokens. "
"Consider increasing max_seq_len. "
"This will lead to higher memory consumption but is likely to "
"improve your model performance".format(round(clipped * 100, 1), max_seq_len)
f"[Haystack Tip] {round(clipped * 100, 1)}% of your samples got cut down to {max_seq_len} tokens. "
"Consider increasing max_seq_len "
f"(the maximum value allowed with the current model is max_seq_len={self.processor.tokenizer.model_max_length}, "
"if this is not enough consider splitting the document in smaller units or changing the model). "
"This will lead to higher memory consumption but is likely to improve your model performance"
)
elif "query_input_ids" in self.tensor_names and "passage_input_ids" in self.tensor_names:
logger.info(
Expand Down
14 changes: 14 additions & 0 deletions haystack/modeling/data_handler/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,13 @@ def __init__(
"""
self.ph_output_type = "per_token_squad"

# validate max_seq_len
assert max_seq_len <= tokenizer.model_max_length, (
"max_seq_len cannot be greater than the maximum sequence length handled by the model: "
f"got max_seq_len={max_seq_len}, while the model maximum length is {tokenizer.model_max_length}. "
"Please adjust max_seq_len accordingly or use another model "
)

assert doc_stride < (max_seq_len - max_query_length), (
"doc_stride ({}) is longer than max_seq_len ({}) minus space reserved for query tokens ({}). \nThis means that there will be gaps "
"as the passage windows slide, causing the model to skip over parts of the document.\n"
Expand Down Expand Up @@ -490,6 +497,13 @@ def convert_qa_input_dict(self, infer_dict: dict):
["text", "questions"] (api format). This function converts the latter into the former. It also converts the
is_impossible field to answer_type so that NQ and SQuAD dicts have the same format.
"""
# validate again max_seq_len
assert self.max_seq_len <= self.tokenizer.model_max_length, (
"max_seq_len cannot be greater than the maximum sequence length handled by the model: "
f"got max_seq_len={self.max_seq_len}, while the model maximum length is {self.tokenizer.model_max_length}. "
"Please adjust max_seq_len accordingly or use another model "
)

# check again for doc stride vs max_seq_len when. Parameters can be changed for already initialized models (e.g. in haystack)
assert self.doc_stride < (self.max_seq_len - self.max_query_length), (
"doc_stride ({}) is longer than max_seq_len ({}) minus space reserved for query tokens ({}). \nThis means that there will be gaps "
Expand Down
24 changes: 24 additions & 0 deletions test/nodes/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,22 @@ def test_top_k(reader, docs, top_k):
print("WARNING: Could not set `top_k_per_sample` in FARM. Please update FARM version.")


def test_farm_reader_invalid_params():
    """FARMReader construction must reject inconsistent sequence-length parameters."""
    # max_seq_len larger than the maximum sequence length the model can handle
    with pytest.raises(Exception):
        FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=False, max_seq_len=513)

    # doc_stride not strictly smaller than max_seq_len minus the space reserved for query tokens
    # NOTE(review): the failing condition is doc_stride >= (max_seq_len - max_query_length),
    # not max_seq_len >= doc_stride as the old comment suggested
    with pytest.raises(Exception):
        FARMReader(
            model_name_or_path="deepset/roberta-base-squad2", use_gpu=False, max_seq_len=129, doc_stride=128
        )

    # doc_stride far beyond (max_seq_len - max_query_length)
    with pytest.raises(Exception):
        FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=False, doc_stride=999)


def test_farm_reader_update_params(docs):
reader = FARMReader(
model_name_or_path="deepset/roberta-base-squad2", use_gpu=False, no_ans_boost=0, num_processes=0
Expand Down Expand Up @@ -219,6 +235,14 @@ def test_farm_reader_update_params(docs):
reader.update_parameters(context_window_size=6, no_ans_boost=-10, max_seq_len=99, doc_stride=128)
reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3)

# update max_seq_len with invalid value (greater than the model maximum sequence length)
with pytest.raises(Exception):
invalid_max_seq_len = reader.inferencer.processor.tokenizer.model_max_length + 1
reader.update_parameters(
context_window_size=100, no_ans_boost=-10, max_seq_len=invalid_max_seq_len, doc_stride=128
)
reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3)


@pytest.mark.parametrize("use_confidence_scores", [True, False])
def test_farm_reader_uses_same_sorting_as_QAPredictionHead(use_confidence_scores):
Expand Down

0 comments on commit 31dcd55

Please sign in to comment.