Calculate official hotpot EM and F1 scores (#292)
1 parent 6d85165 · commit 4f27455
Showing 3 changed files with 161 additions and 13 deletions.
@@ -1,14 +1,72 @@
from deepeval.metrics import BaseMetric, GEval
from deepeval.test_case import LLMTestCase, LLMTestCaseParams

from evals.official_hotpot_metrics import exact_match_score, f1_score

correctness_metric = GEval(
    name="Correctness",
    model="gpt-4o-mini",
    evaluation_params=[
        LLMTestCaseParams.ACTUAL_OUTPUT,
        LLMTestCaseParams.EXPECTED_OUTPUT
    ],
    evaluation_steps=[
        "Determine whether the actual output is factually correct based on the expected output."
    ]
)


class f1_score_metric(BaseMetric):

    """F1 score taken directly from the official hotpot benchmark
    implementation and wrapped into a deepeval metric."""

    def __init__(self, threshold: float = 0.5):
        self.threshold = threshold

    def measure(self, test_case: LLMTestCase):
        f1, precision, recall = f1_score(
            prediction=test_case.actual_output,
            ground_truth=test_case.expected_output,
        )
        self.score = f1
        self.success = self.score >= self.threshold
        return self.score

    # Reusing regular measure as async F1 score is not implemented
    async def a_measure(self, test_case: LLMTestCase):
        return self.measure(test_case)

    def is_successful(self):
        return self.success

    @property
    def __name__(self):
        return "Official hotpot F1 score"


class em_score_metric(BaseMetric):

    """Exact Match score taken directly from the official hotpot benchmark
    implementation and wrapped into a deepeval metric."""

    def __init__(self, threshold: float = 0.5):
        self.threshold = threshold

    def measure(self, test_case: LLMTestCase):
        self.score = exact_match_score(
            prediction=test_case.actual_output,
            ground_truth=test_case.expected_output,
        )
        self.success = self.score >= self.threshold
        return self.score

    # Reusing regular measure as async EM score is not implemented
    async def a_measure(self, test_case: LLMTestCase):
        return self.measure(test_case)

    def is_successful(self):
        return self.success

    @property
    def __name__(self):
        return "Official hotpot EM score"
evals/official_hotpot_metrics.py (new file)
@@ -0,0 +1,86 @@
""" | ||
These are the official evaluation metrics for HotpotQA taken from https://hotpotqa.github.io/ | ||
""" | ||
|
||
import re | ||
import string | ||
import sys | ||
from collections import Counter | ||
|
||
import ujson as json | ||
|
||
|
||
def normalize_answer(s): | ||
|
||
def remove_articles(text): | ||
return re.sub(r'\b(a|an|the)\b', ' ', text) | ||
|
||
def white_space_fix(text): | ||
return ' '.join(text.split()) | ||
|
||
def remove_punc(text): | ||
exclude = set(string.punctuation) | ||
return ''.join(ch for ch in text if ch not in exclude) | ||
|
||
def lower(text): | ||
return text.lower() | ||
|
||
return white_space_fix(remove_articles(remove_punc(lower(s)))) | ||
|
||
|
||
def f1_score(prediction, ground_truth): | ||
normalized_prediction = normalize_answer(prediction) | ||
normalized_ground_truth = normalize_answer(ground_truth) | ||
|
||
ZERO_METRIC = (0, 0, 0) | ||
|
||
if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: | ||
return ZERO_METRIC | ||
if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: | ||
return ZERO_METRIC | ||
|
||
prediction_tokens = normalized_prediction.split() | ||
ground_truth_tokens = normalized_ground_truth.split() | ||
common = Counter(prediction_tokens) & Counter(ground_truth_tokens) | ||
num_same = sum(common.values()) | ||
if num_same == 0: | ||
return ZERO_METRIC | ||
precision = 1.0 * num_same / len(prediction_tokens) | ||
recall = 1.0 * num_same / len(ground_truth_tokens) | ||
f1 = (2 * precision * recall) / (precision + recall) | ||
return f1, precision, recall | ||
|
||
|
||
def exact_match_score(prediction, ground_truth): | ||
return (normalize_answer(prediction) == normalize_answer(ground_truth)) | ||
|
||
def update_answer(metrics, prediction, gold): | ||
em = exact_match_score(prediction, gold) | ||
f1, prec, recall = f1_score(prediction, gold) | ||
metrics['em'] += float(em) | ||
metrics['f1'] += f1 | ||
metrics['prec'] += prec | ||
metrics['recall'] += recall | ||
return em, prec, recall | ||
|
||
def update_sp(metrics, prediction, gold): | ||
cur_sp_pred = set(map(tuple, prediction)) | ||
gold_sp_pred = set(map(tuple, gold)) | ||
tp, fp, fn = 0, 0, 0 | ||
for e in cur_sp_pred: | ||
if e in gold_sp_pred: | ||
tp += 1 | ||
else: | ||
fp += 1 | ||
for e in gold_sp_pred: | ||
if e not in cur_sp_pred: | ||
fn += 1 | ||
prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0 | ||
recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0 | ||
f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0 | ||
em = 1.0 if fp + fn == 0 else 0.0 | ||
metrics['sp_em'] += em | ||
metrics['sp_f1'] += f1 | ||
metrics['sp_prec'] += prec | ||
metrics['sp_recall'] += recall | ||
return em, prec, recall |
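
To make the scoring concrete, the short walk-through below (not part of the commit) applies the functions above to invented strings; it assumes they are in scope, for example imported from evals.official_hotpot_metrics.

prediction = "The Eiffel Tower is located in Paris."
gold = "Paris"

# normalize_answer lower-cases, strips punctuation and articles, and collapses
# whitespace, so the prediction becomes "eiffel tower is located in paris".
print(normalize_answer(prediction))

# Token-level F1: one shared token ("paris") out of 6 prediction tokens and
# 1 gold token -> precision = 1/6, recall = 1.0, F1 = 2/7 ≈ 0.286.
f1, prec, recall = f1_score(prediction, gold)
print(f1, prec, recall)

# Exact match only passes when the normalized strings are identical.
print(exact_match_score(prediction, gold))  # False
print(exact_match_score("Paris!", gold))    # True after normalization

# update_answer accumulates per-example answer scores into a running metrics
# dict; the official evaluation script divides these sums by the number of
# examples to report corpus-level EM/F1.
metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0}
update_answer(metrics, prediction, gold)
print(metrics)

# update_sp does the same for supporting-fact predictions, which are lists of
# (title, sentence_id) pairs in the official HotpotQA format.
sp_metrics = {'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0}
update_sp(sp_metrics, [["Paris", 0]], [["Paris", 0], ["Eiffel Tower", 1]])
print(sp_metrics)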