Calculate official hotpot EM and F1 scores (#292)
alekszievr authored Dec 10, 2024
1 parent 6d85165 commit 4f27455
Showing 3 changed files with 161 additions and 13 deletions.
82 changes: 70 additions & 12 deletions evals/deepeval_metrics.py
@@ -1,14 +1,72 @@
-from deepeval.metrics import GEval
-from deepeval.test_case import LLMTestCaseParams
+from deepeval.metrics import BaseMetric, GEval
+from deepeval.test_case import LLMTestCase, LLMTestCaseParams
 
+from evals.official_hotpot_metrics import exact_match_score, f1_score
+
 correctness_metric = GEval(
-    name="Correctness",
-    model="gpt-4o-mini",
-    evaluation_params=[
-        LLMTestCaseParams.ACTUAL_OUTPUT,
-        LLMTestCaseParams.EXPECTED_OUTPUT
-    ],
-    evaluation_steps=[
-        "Determine whether the actual output is factually correct based on the expected output."
-    ]
-)
+    name="Correctness",
+    model="gpt-4o-mini",
+    evaluation_params=[
+        LLMTestCaseParams.ACTUAL_OUTPUT,
+        LLMTestCaseParams.EXPECTED_OUTPUT
+    ],
+    evaluation_steps=[
+        "Determine whether the actual output is factually correct based on the expected output."
+    ]
+)
+
+
+class f1_score_metric(BaseMetric):
+
+    """F1 score taken directly from the official hotpot benchmark
+    implementation and wrapped into a deepeval metric."""
+
+    def __init__(self, threshold: float = 0.5):
+        self.threshold = threshold
+
+    def measure(self, test_case: LLMTestCase):
+        f1, precision, recall = f1_score(
+            prediction=test_case.actual_output,
+            ground_truth=test_case.expected_output,
+        )
+        self.score = f1
+        self.success = self.score >= self.threshold
+        return self.score
+
+    # Reusing regular measure as async F1 score is not implemented
+    async def a_measure(self, test_case: LLMTestCase):
+        return self.measure(test_case)
+
+    def is_successful(self):
+        return self.success
+
+    @property
+    def __name__(self):
+        return "Official hotpot F1 score"
+
+class em_score_metric(BaseMetric):
+
+    """Exact Match score taken directly from the official hotpot benchmark
+    implementation and wrapped into a deepeval metric."""
+
+    def __init__(self, threshold: float = 0.5):
+        self.threshold = threshold
+
+    def measure(self, test_case: LLMTestCase):
+        self.score = exact_match_score(
+            prediction=test_case.actual_output,
+            ground_truth=test_case.expected_output,
+        )
+        self.success = self.score >= self.threshold
+        return self.score
+
+    # Reusing regular measure as async F1 score is not implemented
+    async def a_measure(self, test_case: LLMTestCase):
+        return self.measure(test_case)
+
+    def is_successful(self):
+        return self.success
+
+    @property
+    def __name__(self):
+        return "Official hotpot EM score"
6 changes: 5 additions & 1 deletion evals/llm_as_a_judge.py → evals/eval_on_hotpot.py
@@ -111,7 +111,9 @@ async def eval_on_hotpotQA(answer_provider, num_samples, eval_metric):

parser.add_argument("--with_cognee", action="store_true")
parser.add_argument("--num_samples", type=int, default=500)
parser.add_argument("--metric", type=str, default="correctness_metric")
parser.add_argument("--metric", type=str, default="correctness_metric",
help="Valid options are Deepeval metrics (e.g. AnswerRelevancyMetric) \
and metrics defined in evals/deepeval_metrics.py, e.g. f1_score_metric")

args = parser.parse_args()

@@ -120,6 +122,8 @@ async def eval_on_hotpotQA(answer_provider, num_samples, eval_metric):
     metric = metric_cls()
 except AttributeError:
     metric = getattr(evals.deepeval_metrics, args.metric)
+    if isinstance(metric, type):
+        metric = metric()
 
 if args.with_cognee:
     answer_provider = answer_with_cognee
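The --metric argument is resolved by name: the script first looks the name up among deepeval's built-in metrics and, failing that, in evals/deepeval_metrics.py. The new isinstance check covers the fact that correctness_metric is already a GEval instance while f1_score_metric and em_score_metric are classes that still need to be instantiated. A standalone sketch of that lookup (resolve_metric is a hypothetical helper name, not part of the commit):

# Sketch of the metric-resolution step above; resolve_metric is a hypothetical
# helper name used only for illustration.
import deepeval.metrics

import evals.deepeval_metrics


def resolve_metric(name: str):
    # Prefer a built-in deepeval metric class, e.g. "AnswerRelevancyMetric".
    try:
        metric = getattr(deepeval.metrics, name)()
    except AttributeError:
        # Fall back to evals/deepeval_metrics.py, where the entry may be an
        # instance (correctness_metric) or a class (f1_score_metric).
        metric = getattr(evals.deepeval_metrics, name)
        if isinstance(metric, type):
            metric = metric()
    return metric


# e.g. resolve_metric("f1_score_metric") or resolve_metric("correctness_metric")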
86 changes: 86 additions & 0 deletions evals/official_hotpot_metrics.py
@@ -0,0 +1,86 @@
"""
These are the official evaluation metrics for HotpotQA taken from https://hotpotqa.github.io/
"""

import re
import string
import sys
from collections import Counter

import ujson as json


def normalize_answer(s):

def remove_articles(text):
return re.sub(r'\b(a|an|the)\b', ' ', text)

def white_space_fix(text):
return ' '.join(text.split())

def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)

def lower(text):
return text.lower()

return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth):
normalized_prediction = normalize_answer(prediction)
normalized_ground_truth = normalize_answer(ground_truth)

ZERO_METRIC = (0, 0, 0)

if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
return ZERO_METRIC
if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
return ZERO_METRIC

prediction_tokens = normalized_prediction.split()
ground_truth_tokens = normalized_ground_truth.split()
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return ZERO_METRIC
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1, precision, recall


def exact_match_score(prediction, ground_truth):
return (normalize_answer(prediction) == normalize_answer(ground_truth))

def update_answer(metrics, prediction, gold):
em = exact_match_score(prediction, gold)
f1, prec, recall = f1_score(prediction, gold)
metrics['em'] += float(em)
metrics['f1'] += f1
metrics['prec'] += prec
metrics['recall'] += recall
return em, prec, recall

def update_sp(metrics, prediction, gold):
cur_sp_pred = set(map(tuple, prediction))
gold_sp_pred = set(map(tuple, gold))
tp, fp, fn = 0, 0, 0
for e in cur_sp_pred:
if e in gold_sp_pred:
tp += 1
else:
fp += 1
for e in gold_sp_pred:
if e not in cur_sp_pred:
fn += 1
prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
em = 1.0 if fp + fn == 0 else 0.0
metrics['sp_em'] += em
metrics['sp_f1'] += f1
metrics['sp_prec'] += prec
metrics['sp_recall'] += recall
return em, prec, recall
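As a quick sanity check on the metric functions above, a small worked example with hand-computed values (illustrative strings only; it assumes evals.official_hotpot_metrics is importable, which requires ujson to be installed since the module imports it):

# Worked example for the official hotpot metrics (illustrative strings only).
from evals.official_hotpot_metrics import exact_match_score, f1_score, normalize_answer

print(normalize_answer("The Eiffel Tower!"))                   # "eiffel tower"
print(exact_match_score("the Eiffel Tower", "Eiffel Tower"))   # True

f1, precision, recall = f1_score("Eiffel Tower in Paris", "the Eiffel Tower")
# 2 shared tokens out of 4 predicted and 2 gold:
# precision = 0.5, recall = 1.0, f1 = 2 * 0.5 * 1.0 / 1.5 ≈ 0.667
print(f1, precision, recall)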
