From 58e56b362db0c682cad57ef895e3b959f9ae3a17 Mon Sep 17 00:00:00 2001 From: Malte Pietsch Date: Mon, 17 Feb 2020 09:15:33 +0100 Subject: [PATCH 1/6] add no_answer option to aggregation of paragraphs level preds --- haystack/reader/farm.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py index 4f346cfec0..e51de7a976 100644 --- a/haystack/reader/farm.py +++ b/haystack/reader/farm.py @@ -53,6 +53,7 @@ def __init__( self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering") self.inferencer.model.prediction_heads[0].context_window_size = context_window_size self.inferencer.model.prediction_heads[0].no_ans_threshold = no_ans_threshold + self.no_ans_threshold = no_ans_threshold self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_passage def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None, @@ -183,16 +184,20 @@ def predict(self, question, paragrahps, meta_data_paragraphs=None, top_k=None, m } input_dicts.append(cur) - # get answers from QA model (Top 5 per input paragraph) + # get answers from QA model (Default: top 5 per input paragraph) predictions = self.inferencer.inference_from_dicts( dicts=input_dicts, rest_api_schema=True, max_processes=max_processes ) # assemble answers from all the different paragraphs & format them + # for the "no answer" option, we choose the no_answer score from the paragraph with the best "real answer" + # the score of this "no answer" is then "boosted" with the no_ans_gap answers = [] + best_score_answer = 0 for pred in predictions: for a in pred["predictions"][0]["answers"]: - if a["answer"]: #skip "no answer" + # skip "no answers" here + if a["answer"]: cur = {"answer": a["answer"], "score": a["score"], "probability": float(expit(np.asarray([a["score"]]) / 8)), #just a pseudo prob for now @@ -201,14 +206,28 @@ def predict(self, question, 
paragrahps, meta_data_paragraphs=None, top_k=None, m "offset_end": a["offset_answer_end"] - a["offset_context_start"], "document_id": a["document_id"]} answers.append(cur) + # if cur answer is the best, we store the gap to "no answer" in this paragraph + if a["score"] > best_score_answer: + best_score_answer = a["score"] + no_ans_gap = pred["predictions"][0]["no_ans_gap"] + no_ans_score = (best_score_answer+no_ans_gap)-self.no_ans_threshold + + # add no answer option from the paragraph with the best answer + cur = {"answer": "", + "score": no_ans_score, + "probability": float(expit(np.asarray(no_ans_score) / 8)), # just a pseudo prob for now + "context": "", + "offset_start": -1, + "offset_end": -1, + "document_id": None} + answers.append(cur) # sort answers by their `probability` and select top-k answers = sorted( answers, key=lambda k: k["probability"], reverse=True ) answers = answers[:top_k] - result = {"question": question, "answers": answers} - return result + return result \ No newline at end of file From 85fbf502cadf46ab2ad0904927ff372bfb8a50d8 Mon Sep 17 00:00:00 2001 From: Malte Pietsch Date: Mon, 17 Feb 2020 11:06:15 +0100 Subject: [PATCH 2/6] change offsets to FARM default --- haystack/reader/farm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py index e51de7a976..d7940bb038 100644 --- a/haystack/reader/farm.py +++ b/haystack/reader/farm.py @@ -217,8 +217,8 @@ def predict(self, question, paragrahps, meta_data_paragraphs=None, top_k=None, m "score": no_ans_score, "probability": float(expit(np.asarray(no_ans_score) / 8)), # just a pseudo prob for now "context": "", - "offset_start": -1, - "offset_end": -1, + "offset_start": 0, + "offset_end": 0, "document_id": None} answers.append(cur) From dc9188361c3cba7dfc445311e66ee7088d09b956 Mon Sep 17 00:00:00 2001 From: timoeller Date: Wed, 19 Feb 2020 12:57:35 +0100 Subject: [PATCH 3/6] Add ranking of no ans relative to positive answers --- 
haystack/reader/farm.py | 48 +++++++++++++++++------- tutorials/Tutorial1_Basic_QA_Pipeline.py | 4 +- 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py index d7940bb038..ab127ff308 100644 --- a/haystack/reader/farm.py +++ b/haystack/reader/farm.py @@ -27,7 +27,7 @@ def __init__( self, model_name_or_path, context_window_size=30, - no_ans_threshold=-100, + no_ans_boost=-100, batch_size=16, use_gpu=True, n_candidates_per_passage=2): @@ -40,20 +40,24 @@ def __init__( .... See https://huggingface.co/models for full list of available models. :param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer. - :param no_ans_threshold: How much greater the no_answer logit needs to be over the pos_answer in order to be chosen. - The higher the value, the more `uncertain` answers are accepted + :param no_ans_boost: How much the no_answer logit is boosted/increased. + The higher the value, the more likely a "no answer possible" is returned by the model :param batch_size: Number of samples the model receives in one batch for inference :param use_gpu: Whether to use GPU (if available) :param n_candidates_per_passage: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`). Note: This is not the number of "final answers" you will receive (see `top_k` in FARMReader.predict() or Finder.get_answers() for that) + # TODO adjust farm. n_cand = 2 returns no answer + highest positive answer + # should return no answer + 2 best positive answers + # drawback: answers from a single paragraph might be very similar in text and score + # we need to have more varied answers (by excluding overlapping answers?) 
""" self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering") self.inferencer.model.prediction_heads[0].context_window_size = context_window_size - self.inferencer.model.prediction_heads[0].no_ans_threshold = no_ans_threshold - self.no_ans_threshold = no_ans_threshold + self.inferencer.model.prediction_heads[0].no_ans_threshold = no_ans_boost # TODO adjust naming and concept in FARM + self.no_ans_boost = no_ans_boost self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_passage def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None, @@ -189,10 +193,11 @@ def predict(self, question, paragrahps, meta_data_paragraphs=None, top_k=None, m dicts=input_dicts, rest_api_schema=True, max_processes=max_processes ) - # assemble answers from all the different paragraphs & format them - # for the "no answer" option, we choose the no_answer score from the paragraph with the best "real answer" - # the score of this "no answer" is then "boosted" with the no_ans_gap + # assemble answers from all the different paragraphs & format them. 
+ # For the "no answer" option, we collect all no_ans_gaps and decide how likely + # a no answer is based on all no_ans_gaps values across all documents answers = [] + no_ans_gaps = [] best_score_answer = 0 for pred in predictions: for a in pred["predictions"][0]["answers"]: @@ -206,13 +211,29 @@ def predict(self, question, paragrahps, meta_data_paragraphs=None, top_k=None, m "offset_end": a["offset_answer_end"] - a["offset_context_start"], "document_id": a["document_id"]} answers.append(cur) - # if cur answer is the best, we store the gap to "no answer" in this paragraph + no_ans_gaps.append(pred["predictions"][0]["no_ans_gap"]) if a["score"] > best_score_answer: best_score_answer = a["score"] - no_ans_gap = pred["predictions"][0]["no_ans_gap"] - no_ans_score = (best_score_answer+no_ans_gap)-self.no_ans_threshold - # add no answer option from the paragraph with the best answer + # adjust no_ans_gaps + no_ans_gaps = np.array(no_ans_gaps) + no_ans_gaps_adjusted = no_ans_gaps + self.no_ans_boost + + # We want to heuristically rank how likely or unlikely the "no answer" option is. 
+ + # case: all documents return no answer, then all no_ans_gaps are positive + if np.sum(no_ans_gaps_adjusted < 0) == 0: + # to rank we add the smallest no_ans_gap (a document where an answer would be nearly as likely as the no answer) + # to the highest answer score we found + no_ans_score = best_score_answer + min(no_ans_gaps_adjusted) + # case: documents where answers are preferred over no answer, the no_ans_gap is negative + else: + # the lowest (highest negative) no_ans_gap would be needed as positive no_ans_boost for the + # model to return "no answer" on all documents + # we subtract this value from the best answer score to rank our "no answer" option + # magically this is the same equation as used for the case above : ) + no_ans_score = best_score_answer + min(no_ans_gaps_adjusted) + cur = {"answer": "", "score": no_ans_score, "probability": float(expit(np.asarray(no_ans_score) / 8)), # just a pseudo prob for now @@ -228,6 +249,7 @@ def predict(self, question, paragrahps, meta_data_paragraphs=None, top_k=None, m ) answers = answers[:top_k] result = {"question": question, - "answers": answers} + "answers": answers, + "min_ans_gap": min(no_ans_gaps_adjusted)} return result \ No newline at end of file diff --git a/tutorials/Tutorial1_Basic_QA_Pipeline.py b/tutorials/Tutorial1_Basic_QA_Pipeline.py index 52ed87b00b..fb8c5c8376 100755 --- a/tutorials/Tutorial1_Basic_QA_Pipeline.py +++ b/tutorials/Tutorial1_Basic_QA_Pipeline.py @@ -37,7 +37,7 @@ # Reader use more powerful but slower deep learning models # You can select a local model or any of the QA models published on huggingface's model hub (https://huggingface.co/models) # here: a medium sized BERT QA model trained via FARM on Squad 2.0 -reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2", use_gpu=False) +reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2", use_gpu=False, no_ans_boost=0) # OR: use alternatively a reader from huggingface's transformers package 
(https://github.com/huggingface/transformers) # reader = TransformersReader(model="distilbert-base-uncased-distilled-squad", tokenizer="distilbert-base-uncased", use_gpu=-1) @@ -48,7 +48,7 @@ ## Voilá! Ask a question! # You can configure how many candidates the reader and retriever shall return # The higher top_k_retriever, the better (but also the slower) your answers. -prediction = finder.get_answers(question="Who is the father of Arya Stark?", top_k_retriever=10, top_k_reader=5) +prediction = finder.get_answers(question="Who is the daughter of Arya Stark?", top_k_retriever=10, top_k_reader=5) #prediction = finder.get_answers(question="Who created the Dothraki vocabulary?", top_k_reader=5) #prediction = finder.get_answers(question="Who is the sister of Sansa?", top_k_reader=5) From c6d9da8827887b58b413a570d261e2625946f3ed Mon Sep 17 00:00:00 2001 From: timoeller Date: Wed, 19 Feb 2020 13:02:51 +0100 Subject: [PATCH 4/6] Add doc for no answer boosting --- tutorials/Tutorial1_Basic_QA_Pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tutorials/Tutorial1_Basic_QA_Pipeline.py b/tutorials/Tutorial1_Basic_QA_Pipeline.py index fb8c5c8376..eaa84c8eda 100755 --- a/tutorials/Tutorial1_Basic_QA_Pipeline.py +++ b/tutorials/Tutorial1_Basic_QA_Pipeline.py @@ -37,6 +37,7 @@ # Reader use more powerful but slower deep learning models # You can select a local model or any of the QA models published on huggingface's model hub (https://huggingface.co/models) # here: a medium sized BERT QA model trained via FARM on Squad 2.0 +# You can adjust the model to return "no answer possible" with the no_ans_boost. 
Higher values mean the model prefers "no answer possible" reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2", use_gpu=False, no_ans_boost=0) # OR: use alternatively a reader from huggingface's transformers package (https://github.com/huggingface/transformers) From dface98eca143aad55a1121b37110b3e5d5e9094 Mon Sep 17 00:00:00 2001 From: timoeller Date: Wed, 19 Feb 2020 14:48:50 +0100 Subject: [PATCH 5/6] Adjust printing of all details --- haystack/reader/farm.py | 6 +++--- haystack/utils.py | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py index ab127ff308..df534096f7 100644 --- a/haystack/reader/farm.py +++ b/haystack/reader/farm.py @@ -234,7 +234,7 @@ def predict(self, question, paragrahps, meta_data_paragraphs=None, top_k=None, m # magically this is the same equation as used for the case above : ) no_ans_score = best_score_answer + min(no_ans_gaps_adjusted) - cur = {"answer": "", + cur = {"answer": "[computer says no answer is likely]", "score": no_ans_score, "probability": float(expit(np.asarray(no_ans_score) / 8)), # just a pseudo prob for now "context": "", @@ -249,7 +249,7 @@ def predict(self, question, paragrahps, meta_data_paragraphs=None, top_k=None, m ) answers = answers[:top_k] result = {"question": question, - "answers": answers, - "min_ans_gap": min(no_ans_gaps_adjusted)} + "adjust_no_ans_boost": -min(no_ans_gaps_adjusted), + "answers": answers} return result \ No newline at end of file diff --git a/haystack/utils.py b/haystack/utils.py index ad6e80b4a9..5a4ca2d7fa 100644 --- a/haystack/utils.py +++ b/haystack/utils.py @@ -22,8 +22,10 @@ def print_answers(results, details="all"): for key in keys_to_drop: if key in a: del a[key] - # print them - pp.pprint(answers) + + pp.pprint(answers) + else: + pp.pprint(results) def convert_labels_to_squad(labels_file): From 840b368732cf969e5f345fa59979dcf248f75f2b Mon Sep 17 00:00:00 2001 From: timoeller Date: Wed, 19 Feb 
2020 14:51:12 +0100 Subject: [PATCH 6/6] Add no ans example --- tutorials/Tutorial1_Basic_QA_Pipeline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tutorials/Tutorial1_Basic_QA_Pipeline.py b/tutorials/Tutorial1_Basic_QA_Pipeline.py index eaa84c8eda..1b117e7c7f 100755 --- a/tutorials/Tutorial1_Basic_QA_Pipeline.py +++ b/tutorials/Tutorial1_Basic_QA_Pipeline.py @@ -49,8 +49,9 @@ ## Voilá! Ask a question! # You can configure how many candidates the reader and retriever shall return # The higher top_k_retriever, the better (but also the slower) your answers. -prediction = finder.get_answers(question="Who is the daughter of Arya Stark?", top_k_retriever=10, top_k_reader=5) +prediction = finder.get_answers(question="Who is the father of Arya Stark?", top_k_retriever=10, top_k_reader=5) +#prediction = finder.get_answers(question="Who is the daughter of Arya Stark?", top_k_reader=5) # impossible question test #prediction = finder.get_answers(question="Who created the Dothraki vocabulary?", top_k_reader=5) #prediction = finder.get_answers(question="Who is the sister of Sansa?", top_k_reader=5)