Match answer sorting in QuestionAnsweringHead with FARMReader (#2414)

* match no_answer confidence

* Update Documentation & Code Style

* test added

* Update Documentation & Code Style

* fix tests

* Update Documentation & Code Style

* apply penalties of scores to confidences too

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
tstadel and github-actions[bot] authored Apr 21, 2022
1 parent 4bf4702 commit 25475a6
Showing 7 changed files with 105 additions and 30 deletions.
4 changes: 3 additions & 1 deletion docs/_src/api/api/reader.md
@@ -430,7 +430,7 @@ or use the Reader's device by default.
#### eval

```python
def eval(document_store: BaseDocumentStore, device: Optional[Union[str, torch.device]] = None, label_index: str = "label", doc_index: str = "eval_document", label_origin: str = "gold-label", calibrate_conf_scores: bool = False)
def eval(document_store: BaseDocumentStore, device: Optional[Union[str, torch.device]] = None, label_index: str = "label", doc_index: str = "eval_document", label_origin: str = "gold-label", calibrate_conf_scores: bool = False, use_no_answer_legacy_confidence=False)
```

Performs evaluation on evaluation documents in the DocumentStore.
@@ -450,6 +450,8 @@ or use the Reader's device by default.
- `doc_index`: Index/Table name where documents that are used for evaluation are stored
- `label_origin`: Field name where the gold labels are stored
- `calibrate_conf_scores`: Whether to calibrate the temperature for temperature scaling of the confidence scores
- `use_no_answer_legacy_confidence`: Whether to use the legacy confidence definition for no_answer: difference between the best overall answer confidence and the no_answer gap confidence.
Otherwise we use the no_answer score normalized to a range of [0,1] by an expit function (default).

<a id="farm.FARMReader.calibrate_confidence_scores"></a>

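For reference, a minimal sketch of the two `no_answer` confidence definitions documented above, assuming the intermediate values are already computed as in `QuestionAnsweringHead.reduce_preds`:

```python
import numpy as np
from scipy.special import expit  # logistic sigmoid: 1 / (1 + exp(-x))


def no_answer_confidence(
    best_overall_positive_confidence: float,
    no_ans_gap_confidence: float,
    best_overall_positive_score: float,
    no_ans_gap: float,
    use_no_answer_legacy_confidence: bool = False,
) -> float:
    if use_no_answer_legacy_confidence:
        # Legacy definition: difference between the best positive answer
        # confidence and the no_answer gap confidence (may leave [0, 1]).
        return best_overall_positive_confidence - no_ans_gap_confidence
    # Default: squash the unbounded no_answer score into [0, 1] with expit.
    return float(expit(np.asarray(best_overall_positive_score - no_ans_gap) / 8))
```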
12 changes: 11 additions & 1 deletion haystack/modeling/evaluation/eval.py
@@ -33,17 +33,27 @@ def __init__(self, data_loader: torch.utils.data.DataLoader, tasks, device: torc
self.report = report

def eval(
self, model: AdaptiveModel, return_preds_and_labels: bool = False, calibrate_conf_scores: bool = False
self,
model: AdaptiveModel,
return_preds_and_labels: bool = False,
calibrate_conf_scores: bool = False,
use_confidence_scores_for_ranking=True,
use_no_answer_legacy_confidence=False,
) -> List[Dict]:
"""
Performs evaluation on a given model.
:param model: The model on which to perform evaluation
:param return_preds_and_labels: Whether to add preds and labels in the returned dicts of the
:param calibrate_conf_scores: Whether to calibrate the temperature for temperature scaling of the confidence scores
:param use_confidence_scores_for_ranking: Whether to sort answers by confidence score (normalized between 0 and 1)(default) or by standard score (unbounded).
:param use_no_answer_legacy_confidence: Whether to use the legacy confidence definition for no_answer: difference between the best overall answer confidence and the no_answer gap confidence.
Otherwise we use the no_answer score normalized to a range of [0,1] by an expit function (default).
:return: all_results: A list of dictionaries, one for each prediction head. Each dictionary contains the metrics
and reports generated during evaluation.
"""
model.prediction_heads[0].use_confidence_scores_for_ranking = use_confidence_scores_for_ranking
model.prediction_heads[0].use_no_answer_legacy_confidence = use_no_answer_legacy_confidence
model.eval()

# init empty lists per prediction head
19 changes: 15 additions & 4 deletions haystack/modeling/model/prediction_head.py
@@ -10,6 +10,7 @@
from torch import optim
from torch.nn import CrossEntropyLoss, NLLLoss
from transformers import AutoModelForQuestionAnswering
from scipy.special import expit

from haystack.modeling.data_handler.samples import SampleBasket
from haystack.modeling.model.predictions import QACandidate, QAPred
@@ -234,7 +235,8 @@ def __init__(
n_best_per_sample: Optional[int] = None,
duplicate_filtering: int = -1,
temperature_for_confidence: float = 1.0,
use_confidence_scores_for_ranking: bool = False,
use_confidence_scores_for_ranking: bool = True,
use_no_answer_legacy_confidence: bool = False,
**kwargs,
):
"""
@@ -250,7 +252,9 @@
:param duplicate_filtering: Answers are filtered based on their position. Both start and end position of the answers are considered.
The higher the value, answers that are more apart are filtered out. 0 corresponds to exact duplicates. -1 turns off duplicate removal.
:param temperature_for_confidence: The divisor that is used to scale logits to calibrate confidence scores
:param use_confidence_scores_for_ranking: Whether to sort answers by confidence score (normalized between 0 and 1) or by standard score (unbounded)(default).
:param use_confidence_scores_for_ranking: Whether to sort answers by confidence score (normalized between 0 and 1)(default) or by standard score (unbounded).
:param use_no_answer_legacy_confidence: Whether to use the legacy confidence definition for no_answer: difference between the best overall answer confidence and the no_answer gap confidence.
Otherwise we use the no_answer score normalized to a range of [0,1] by an expit function (default).
"""
super(QuestionAnsweringHead, self).__init__()
if len(kwargs) > 0:
@@ -279,6 +283,7 @@ def __init__(
self.generate_config()
self.temperature_for_confidence = nn.Parameter(torch.ones(1) * temperature_for_confidence)
self.use_confidence_scores_for_ranking = use_confidence_scores_for_ranking
self.use_no_answer_legacy_confidence = use_no_answer_legacy_confidence

@classmethod
def load(cls, pretrained_model_name_or_path: Union[str, Path], revision: Optional[str] = None, **kwargs): # type: ignore
@@ -520,7 +525,11 @@ def get_top_candidates(self, sorted_candidates, start_end_matrix, sample_idx: in
if self.duplicate_filtering > -1 and (start_idx in start_idx_candidates or end_idx in end_idx_candidates):
continue
score = start_end_matrix[start_idx, end_idx].item()
confidence = (start_matrix_softmax_start[start_idx].item() + end_matrix_softmax_end[end_idx].item()) / 2
confidence = (
(start_matrix_softmax_start[start_idx].item() + end_matrix_softmax_end[end_idx].item()) / 2
if score > -500
else np.exp(score / 10) # disqualify answers according to scores in logits_to_preds()
)
top_candidates.append(
QACandidate(
offset_answer_start=start_idx,
@@ -795,7 +804,9 @@ def reduce_preds(self, preds):
aggregation_level="document",
passage_id=None,
n_passages_in_doc=n_samples,
confidence=best_overall_positive_confidence - no_ans_gap_confidence,
confidence=best_overall_positive_confidence - no_ans_gap_confidence
if self.use_no_answer_legacy_confidence
else float(expit(np.asarray(best_overall_positive_score - no_ans_gap) / 8)),
)

# Add no answer to positive answers, sort the order and return the n_best
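The `get_top_candidates` change above keeps the score penalty applied in `logits_to_preds()` visible in the confidence as well, so ranking by confidence stays consistent with ranking by score. A minimal sketch of that logic, with the softmax probabilities passed in as plain floats (assumed names):

```python
import numpy as np


def candidate_confidence(score: float, start_prob: float, end_prob: float) -> float:
    if score > -500:
        # Regular candidate: average of start and end softmax probabilities.
        return (start_prob + end_prob) / 2
    # Candidates disqualified in logits_to_preds() carry a large negative score
    # penalty; map it to a near-zero confidence instead of the softmax average.
    return float(np.exp(score / 10))
```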
21 changes: 12 additions & 9 deletions haystack/nodes/reader/farm.py
@@ -143,14 +143,9 @@ def __init__(
self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
self.inferencer.model.prediction_heads[0].n_best = top_k_per_candidate + 1 # including possible no_answer
try:
self.inferencer.model.prediction_heads[0].n_best_per_sample = top_k_per_sample
except:
logger.warning("Could not set `top_k_per_sample` in FARM. Please update FARM version.")
try:
self.inferencer.model.prediction_heads[0].duplicate_filtering = duplicate_filtering
except:
logger.warning("Could not set `duplicate_filtering` in FARM. Please update FARM version.")
self.inferencer.model.prediction_heads[0].n_best_per_sample = top_k_per_sample
self.inferencer.model.prediction_heads[0].duplicate_filtering = duplicate_filtering
self.inferencer.model.prediction_heads[0].use_confidence_scores_for_ranking = use_confidence_scores
self.max_seq_len = max_seq_len
self.progress_bar = progress_bar
self.use_confidence_scores = use_confidence_scores
@@ -846,6 +841,7 @@ def eval(
doc_index: str = "eval_document",
label_origin: str = "gold-label",
calibrate_conf_scores: bool = False,
use_no_answer_legacy_confidence=False,
):
"""
Performs evaluation on evaluation documents in the DocumentStore.
@@ -862,6 +858,8 @@
:param doc_index: Index/Table name where documents that are used for evaluation are stored
:param label_origin: Field name where the gold labels are stored
:param calibrate_conf_scores: Whether to calibrate the temperature for temperature scaling of the confidence scores
:param use_no_answer_legacy_confidence: Whether to use the legacy confidence definition for no_answer: difference between the best overall answer confidence and the no_answer gap confidence.
Otherwise we use the no_answer score normalized to a range of [0,1] by an expit function (default).
"""
if device is None:
device = self.devices[0]
@@ -968,7 +966,12 @@ def eval(

evaluator = Evaluator(data_loader=data_loader, tasks=self.inferencer.processor.tasks, device=device)

eval_results = evaluator.eval(self.inferencer.model, calibrate_conf_scores=calibrate_conf_scores)
eval_results = evaluator.eval(
self.inferencer.model,
calibrate_conf_scores=calibrate_conf_scores,
use_confidence_scores_for_ranking=self.use_confidence_scores,
use_no_answer_legacy_confidence=use_no_answer_legacy_confidence,
)
toc = perf_counter()
reader_time = toc - tic
results = {
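A usage sketch of the extended `FARMReader.eval` signature; the document store setup, file path, and index names below are illustrative assumptions (Haystack 1.x import paths assumed), not part of this change:

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import FARMReader

# Any BaseDocumentStore holding evaluation documents and gold labels works;
# add_eval_data() loads a SQuAD-format file (hypothetical path below).
document_store = InMemoryDocumentStore()
document_store.add_eval_data(
    filename="squad/tiny.json", doc_index="eval_document", label_index="label"
)

reader = FARMReader(
    model_name_or_path="deepset/roberta-base-squad2",
    use_gpu=False,
    return_no_answer=True,
)

eval_results = reader.eval(
    document_store=document_store,
    doc_index="eval_document",
    label_index="label",
    calibrate_conf_scores=False,
    # Opt back into the legacy no_answer confidence if needed; by default the
    # expit-normalized no_answer score is used.
    use_no_answer_legacy_confidence=False,
)
```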
19 changes: 14 additions & 5 deletions test/test_eval.py
@@ -128,25 +128,34 @@ def test_add_eval_data(document_store, batch_size):

@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus1"], indirect=True)
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
def test_eval_reader(reader, document_store: BaseDocumentStore):
@pytest.mark.parametrize("use_confidence_scores", [True, False])
def test_eval_reader(reader, document_store: BaseDocumentStore, use_confidence_scores):
# add eval data (SQUAD format)
document_store.add_eval_data(
filename=SAMPLES_PATH / "squad" / "tiny.json",
doc_index="haystack_test_eval_document",
label_index="haystack_test_feedback",
)
assert document_store.get_document_count(index="haystack_test_eval_document") == 2

reader.use_confidence_scores = use_confidence_scores

# eval reader
reader_eval_results = reader.eval(
document_store=document_store,
label_index="haystack_test_feedback",
doc_index="haystack_test_eval_document",
device="cpu",
)
assert reader_eval_results["f1"] > 66.65
assert reader_eval_results["f1"] < 66.67
assert reader_eval_results["EM"] == 50
assert reader_eval_results["top_n_accuracy"] == 100.0

if use_confidence_scores:
assert reader_eval_results["f1"] == 50
assert reader_eval_results["EM"] == 50
assert reader_eval_results["top_n_accuracy"] == 100.0
else:
assert 66.67 > reader_eval_results["f1"] > 66.65
assert reader_eval_results["EM"] == 50
assert reader_eval_results["top_n_accuracy"] == 100.0


@pytest.mark.elasticsearch
24 changes: 14 additions & 10 deletions test/test_modeling_question_answering.py
@@ -64,17 +64,8 @@ def test_span_inference_result_ranking_by_confidence(bert_base_squad2, caplog=No
questions=Question("Who counted the game among the best ever made?", uid="best_id_ever"),
)
]
result = bert_base_squad2.inference_from_objects(obj_input, return_json=False)[0]

# by default, result is sorted by score and not by confidence
assert all(result.prediction[i].score >= result.prediction[i + 1].score for i in range(len(result.prediction) - 1))
assert not all(
result.prediction[i].confidence >= result.prediction[i + 1].confidence
for i in range(len(result.prediction) - 1)
)

# ranking can be adjusted so that result is sorted by confidence
bert_base_squad2.model.prediction_heads[0].use_confidence_scores_for_ranking = True
# by default, result is sorted by confidence and not by score
result_ranked_by_confidence = bert_base_squad2.inference_from_objects(obj_input, return_json=False)[0]
assert all(
result_ranked_by_confidence.prediction[i].confidence >= result_ranked_by_confidence.prediction[i + 1].confidence
@@ -85,6 +76,18 @@
for i in range(len(result_ranked_by_confidence.prediction) - 1)
)

# ranking can be adjusted so that result is sorted by score
bert_base_squad2.model.prediction_heads[0].use_confidence_scores_for_ranking = False
result_ranked_by_score = bert_base_squad2.inference_from_objects(obj_input, return_json=False)[0]
assert all(
result_ranked_by_score.prediction[i].score >= result_ranked_by_score.prediction[i + 1].score
for i in range(len(result_ranked_by_score.prediction) - 1)
)
assert not all(
result_ranked_by_score.prediction[i].confidence >= result_ranked_by_score.prediction[i + 1].confidence
for i in range(len(result_ranked_by_score.prediction) - 1)
)


def test_inference_objs(span_inference_result, caplog=None):
if caplog:
@@ -226,6 +229,7 @@ def test_no_duplicate_answer_filtering(bert_base_squad2):
bert_base_squad2.model.prediction_heads[0].n_best = 5
bert_base_squad2.model.prediction_heads[0].n_best_per_sample = 5
bert_base_squad2.model.prediction_heads[0].duplicate_filtering = -1
bert_base_squad2.model.prediction_heads[0].no_ans_boost = -100.0

result = bert_base_squad2.inference_from_dicts(dicts=qa_input)
offset_answer_starts = []
36 changes: 36 additions & 0 deletions test/test_reader.py
@@ -1,6 +1,7 @@
import math

import pytest
from haystack.modeling.data_handler.inputs import QAInput, Question

from haystack.schema import Document, Answer
from haystack.nodes.reader.base import BaseReader
@@ -169,3 +170,38 @@ def test_farm_reader_update_params(test_docs_xs):
with pytest.raises(Exception):
reader.update_parameters(context_window_size=6, no_ans_boost=-10, max_seq_len=99, doc_stride=128)
reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3)


@pytest.mark.parametrize("use_confidence_scores", [True, False])
def test_farm_reader_uses_same_sorting_as_QAPredictionHead(use_confidence_scores):
reader = FARMReader(
model_name_or_path="deepset/roberta-base-squad2",
use_gpu=False,
num_processes=0,
return_no_answer=True,
use_confidence_scores=use_confidence_scores,
)

text = """Beer is one of the oldest[1][2][3] and most widely consumed[4] alcoholic drinks in the world, and the third most popular drink overall after water and tea.[5] It is produced by the brewing and fermentation of starches, mainly derived from cereal grains—most commonly from malted barley, though wheat, maize (corn), rice, and oats are also used. During the brewing process, fermentation of the starch sugars in the wort produces ethanol and carbonation in the resulting beer.[6] Most modern beer is brewed with hops, which add bitterness and other flavours and act as a natural preservative and stabilizing agent. Other flavouring agents such as gruit, herbs, or fruits may be included or used instead of hops. In commercial brewing, the natural carbonation effect is often removed during processing and replaced with forced carbonation.[7]
Some of humanity's earliest known writings refer to the production and distribution of beer: the Code of Hammurabi included laws regulating beer and beer parlours,[8] and "The Hymn to Ninkasi", a prayer to the Mesopotamian goddess of beer, served as both a prayer and as a method of remembering the recipe for beer in a culture with few literate people.[9][10]
Beer is distributed in bottles and cans and is also commonly available on draught, particularly in pubs and bars. The brewing industry is a global business, consisting of several dominant multinational companies and many thousands of smaller producers ranging from brewpubs to regional breweries. The strength of modern beer is usually around 4% to 6% alcohol by volume (ABV), although it may vary between 0.5% and 20%, with some breweries creating examples of 40% ABV and above.[11]
Beer forms part of the culture of many nations and is associated with social traditions such as beer festivals, as well as a rich pub culture involving activities like pub crawling, pub quizzes and pub games.
When beer is distilled, the resulting liquor is a form of whisky.[12]
"""

docs = [Document(text)]
query = "What is the third most popular drink?"

reader_predictions = reader.predict(query=query, documents=docs, top_k=5)

farm_input = [QAInput(doc_text=d.content, questions=Question(query)) for d in docs]
inferencer_predictions = reader.inferencer.inference_from_objects(farm_input, return_json=False)

for answer, qa_cand in zip(reader_predictions["answers"], inferencer_predictions[0].prediction):
assert answer.answer == ("" if qa_cand.answer_type == "no_answer" else qa_cand.answer)
assert answer.offsets_in_document[0].start == qa_cand.offset_answer_start
assert answer.offsets_in_document[0].end == qa_cand.offset_answer_end
if use_confidence_scores:
assert answer.score == qa_cand.confidence
else:
assert answer.score == qa_cand.score
