diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py
index dfd6d60a0e5..28668c66b54 100644
--- a/tests/unittests/text/test_bertscore.py
+++ b/tests/unittests/text/test_bertscore.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 from functools import partial
+from itertools import product
 from typing import Sequence

 import pytest
@@ -190,4 +191,103 @@ def test_bertscore_sorting(idf: bool):
     score = metric(preds, targets)

     # First index should be the self-comparison - sorting by length should not shuffle this
     assert score["f1"][0] > score["f1"][1]
+
+
+@skip_on_connection_issues()
+@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>=4.4")
+@pytest.mark.parametrize(
+    ["idf", "batch_size"],
+    [(False, 1), (False, 9), (True, 1), (True, 9)],
+)
+def test_bertscore_most_similar(idf: bool, batch_size: int):
+    """Test that BERTScore gives the highest F1 score to the self-comparison of each sentence."""
+    short = "hello there"
+    long = "master kenobi"
+    longer = "general kenobi"
+
+    sentences = [short, long, longer]
+    preds, targets = zip(*product(sentences, sentences))
+    score = bert_score(preds, targets, idf=idf, lang="en", rescale_with_baseline=False, batch_size=batch_size)
+
+    for i in range(len(preds)):
+        # Pair i is (sentences[i // n], sentences[i % n]), so the self-comparison of the sentence
+        # at index k sits at position k * (n + 1) in the flattened product.
+        max_pred = i // len(sentences) * (1 + len(sentences))
+        max_target = i % len(sentences) * (1 + len(sentences))
+        assert score["f1"][i] <= score["f1"][max_pred], (
+            f"pair {preds[i], targets[i]} does not have a lower score than {preds[max_pred], targets[max_pred]}"
+            f"\n{i=}, {max_pred=}"
+        )
+        assert score["f1"][i] <= score["f1"][max_target], (
+            f"pair {preds[i], targets[i]} does not have a lower score than {preds[max_target], targets[max_target]}"
+            f"\n{i=}, {max_target=}"
+        )
+
+
+@skip_on_connection_issues()
+@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>=4.4")
+@pytest.mark.parametrize(
+    ["idf", "batch_size"],
+    [(False, 1), (False, 9), (True, 1), (True, 9)],
+)
+def test_bertscore_symmetry(idf: bool, batch_size: int):
+    """Test that the BERTScore F1 score is symmetric between prediction and target.
+
+    Swapping prediction and target swaps precision and recall, so their harmonic mean (F1) stays the same.
+    """
+    short = "hello there"
+    long = "master kenobi"
+    longer = "general kenobi"
+
+    sentences = [short, long, longer]
+    preds, targets = zip(*product(sentences, sentences))
+    score = bert_score(preds, targets, idf=idf, lang="en", rescale_with_baseline=False, batch_size=batch_size)
+
+    for i in range(len(preds)):
+        for j in range(len(targets)):
+            if preds[i] == targets[j] and preds[j] == targets[i]:
+                assert score["f1"][i] == pytest.approx(score["f1"][j]), (
+                    f"f1 score for {(preds[i], targets[i])} is not the same as for {(preds[j], targets[j])}."
+                )
+
+
+@skip_on_connection_issues()
+@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>=4.4")
+@pytest.mark.parametrize(
+    ["idf", "batch_size"],
+    [(False, 1), (False, 3)],
+)
+def test_bertscore_additional_sentence(idf: bool, batch_size: int):
+    """Test that appending sentences to the input lists does not change the scores of the earlier pairs.
+
+    This only holds for idf=False, since idf=True recomputes the idf weights over the whole input corpus.
+    """
+    short = "hello there"
+    long = "master kenobi"
+    longer = "general kenobi"
+
+    preds = [long, long]
+    targets = [long, short]
+
+    score = bert_score(preds, targets, idf=idf, lang="en", rescale_with_baseline=False, batch_size=batch_size)
+
+    longlong = score["f1"][0]
+    longshort = score["f1"][1]
+    # First index should be the self-comparison - sorting by length should not shuffle this
+    assert longlong > longshort
+
+    preds = preds + [short, longer]
+    targets = targets + [longer, long]
+
+    score = bert_score(preds, targets, idf=idf, lang="en", rescale_with_baseline=False, batch_size=batch_size)
+
+    # First two indices should be exactly as in the previous call
+    assert score["f1"][0] == pytest.approx(longlong)
+    assert score["f1"][1] == pytest.approx(longshort)
+    # The scores at indices 1 and 2 should also stay below the self-comparison at index 0
+    assert score["f1"][0] > score["f1"][1]
+    assert score["f1"][0] > score["f1"][2]
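Reviewer note (not part of the patch): the index arithmetic in test_bertscore_most_similar can be sanity-checked standalone with the standard library alone. A minimal sketch, assuming the same three placeholder sentences as the test, verifying that k * (n + 1) is the flattened-product index of each sentence's self-comparison:

    from itertools import product

    # Placeholder sentences mirroring the test inputs.
    sentences = ["hello there", "master kenobi", "general kenobi"]
    n = len(sentences)
    # product orders pairs with the first element varying slowest,
    # so preds[i] is sentences[i // n] and targets[i] is sentences[i % n].
    preds, targets = zip(*product(sentences, sentences))

    for i in range(len(preds)):
        max_pred = i // n * (1 + n)
        max_target = i % n * (1 + n)
        # Both indices must point at rows where prediction and target coincide.
        assert preds[max_pred] == targets[max_pred] == preds[i]
        assert preds[max_target] == targets[max_target] == targets[i]
    print("index arithmetic verified")

Since both self-comparison indices are asserted against in the test, swapping the two formulas would not change what the test checks, only which index each failure message reports.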