diff --git a/evals/deepeval_metrics.py b/evals/deepeval_metrics.py
index 03f9f6dba..b07d2e1ac 100644
--- a/evals/deepeval_metrics.py
+++ b/evals/deepeval_metrics.py
@@ -1,14 +1,72 @@
-from deepeval.metrics import GEval
-from deepeval.test_case import LLMTestCaseParams
+from deepeval.metrics import BaseMetric, GEval
+from deepeval.test_case import LLMTestCase, LLMTestCaseParams
+
+from evals.official_hotpot_metrics import exact_match_score, f1_score
 
 correctness_metric = GEval(
-        name="Correctness",
-        model="gpt-4o-mini",
-        evaluation_params=[
-            LLMTestCaseParams.ACTUAL_OUTPUT,
-            LLMTestCaseParams.EXPECTED_OUTPUT
-        ],
-        evaluation_steps=[
-            "Determine whether the actual output is factually correct based on the expected output."
-        ]
-    )
+    name="Correctness",
+    model="gpt-4o-mini",
+    evaluation_params=[
+        LLMTestCaseParams.ACTUAL_OUTPUT,
+        LLMTestCaseParams.EXPECTED_OUTPUT
+    ],
+    evaluation_steps=[
+        "Determine whether the actual output is factually correct based on the expected output."
+    ]
+)
+
+
+class f1_score_metric(BaseMetric):
+
+    """F1 score taken directly from the official hotpot benchmark
+    implementation and wrapped into a deepeval metric."""
+
+    def __init__(self, threshold: float = 0.5):
+        self.threshold = threshold
+
+    def measure(self, test_case: LLMTestCase):
+        f1, precision, recall = f1_score(
+            prediction=test_case.actual_output,
+            ground_truth=test_case.expected_output,
+        )
+        self.score = f1
+        self.success = self.score >= self.threshold
+        return self.score
+
+    # Reusing regular measure as async F1 score is not implemented
+    async def a_measure(self, test_case: LLMTestCase):
+        return self.measure(test_case)
+
+    def is_successful(self):
+        return self.success
+
+    @property
+    def __name__(self):
+        return "Official hotpot F1 score"
+
+class em_score_metric(BaseMetric):
+
+    """Exact Match score taken directly from the official hotpot benchmark
+    implementation and wrapped into a deepeval metric."""
+
+    def __init__(self, threshold: float = 0.5):
+        self.threshold = threshold
+
+    def measure(self, test_case: LLMTestCase):
+        self.score = exact_match_score(
+            prediction=test_case.actual_output,
+            ground_truth=test_case.expected_output,
+        )
+        self.success = self.score >= self.threshold
+        return self.score
+
+    # Reusing regular measure as async EM score is not implemented
+    async def a_measure(self, test_case: LLMTestCase):
+        return self.measure(test_case)
+
+    def is_successful(self):
+        return self.success
+
+    @property
+    def __name__(self):
+        return "Official hotpot EM score"
\ No newline at end of file
diff --git a/evals/llm_as_a_judge.py b/evals/eval_on_hotpot.py
similarity index 93%
rename from evals/llm_as_a_judge.py
rename to evals/eval_on_hotpot.py
index 4deee7d7d..e07e80e0c 100644
--- a/evals/llm_as_a_judge.py
+++ b/evals/eval_on_hotpot.py
@@ -111,7 +111,9 @@ async def eval_on_hotpotQA(answer_provider, num_samples, eval_metric):
 
     parser.add_argument("--with_cognee", action="store_true")
    parser.add_argument("--num_samples", type=int, default=500)
-    parser.add_argument("--metric", type=str, default="correctness_metric")
+    parser.add_argument("--metric", type=str, default="correctness_metric",
+                        help="Valid options are Deepeval metrics (e.g. AnswerRelevancyMetric) \
+                              and metrics defined in evals/deepeval_metrics.py, e.g. f1_score_metric")
 
     args = parser.parse_args()
 
@@ -120,6 +122,8 @@ async def eval_on_hotpotQA(answer_provider, num_samples, eval_metric):
         metric = metric_cls()
     except AttributeError:
         metric = getattr(evals.deepeval_metrics, args.metric)
+        if isinstance(metric, type):
+            metric = metric()
 
     if args.with_cognee:
         answer_provider = answer_with_cognee
diff --git a/evals/official_hotpot_metrics.py b/evals/official_hotpot_metrics.py
new file mode 100644
index 000000000..b598e90d3
--- /dev/null
+++ b/evals/official_hotpot_metrics.py
@@ -0,0 +1,86 @@
+"""
+These are the official evaluation metrics for HotpotQA taken from https://hotpotqa.github.io/
+"""
+
+import re
+import string
+import sys
+from collections import Counter
+
+import ujson as json
+
+
+def normalize_answer(s):
+
+    def remove_articles(text):
+        return re.sub(r'\b(a|an|the)\b', ' ', text)
+
+    def white_space_fix(text):
+        return ' '.join(text.split())
+
+    def remove_punc(text):
+        exclude = set(string.punctuation)
+        return ''.join(ch for ch in text if ch not in exclude)
+
+    def lower(text):
+        return text.lower()
+
+    return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def f1_score(prediction, ground_truth):
+    normalized_prediction = normalize_answer(prediction)
+    normalized_ground_truth = normalize_answer(ground_truth)
+
+    ZERO_METRIC = (0, 0, 0)
+
+    if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
+        return ZERO_METRIC
+    if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
+        return ZERO_METRIC
+
+    prediction_tokens = normalized_prediction.split()
+    ground_truth_tokens = normalized_ground_truth.split()
+    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+    num_same = sum(common.values())
+    if num_same == 0:
+        return ZERO_METRIC
+    precision = 1.0 * num_same / len(prediction_tokens)
+    recall = 1.0 * num_same / len(ground_truth_tokens)
+    f1 = (2 * precision * recall) / (precision + recall)
+    return f1, precision, recall
+
+
+def exact_match_score(prediction, ground_truth):
+    return (normalize_answer(prediction) == normalize_answer(ground_truth))
+
+def update_answer(metrics, prediction, gold):
+    em = exact_match_score(prediction, gold)
+    f1, prec, recall = f1_score(prediction, gold)
+    metrics['em'] += float(em)
+    metrics['f1'] += f1
+    metrics['prec'] += prec
+    metrics['recall'] += recall
+    return em, prec, recall
+
+def update_sp(metrics, prediction, gold):
+    cur_sp_pred = set(map(tuple, prediction))
+    gold_sp_pred = set(map(tuple, gold))
+    tp, fp, fn = 0, 0, 0
+    for e in cur_sp_pred:
+        if e in gold_sp_pred:
+            tp += 1
+        else:
+            fp += 1
+    for e in gold_sp_pred:
+        if e not in cur_sp_pred:
+            fn += 1
+    prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
+    recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
+    f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
+    em = 1.0 if fp + fn == 0 else 0.0
+    metrics['sp_em'] += em
+    metrics['sp_f1'] += f1
+    metrics['sp_prec'] += prec
+    metrics['sp_recall'] += recall
+    return em, prec, recall