diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py
index 28668c66b54..43bd0a5c327 100644
--- a/tests/unittests/text/test_bertscore.py
+++ b/tests/unittests/text/test_bertscore.py
@@ -220,8 +220,33 @@ def test_bertscore_most_similar(idf: bool, batch_size: int):
         assert score["f1"][i] <= score["f1"][max_target], \
             f"pair: {preds[i], targets[i]} does not have a lower score than {preds[max_target], targets[max_target]}\n{i=}{max_target=}"
 
+@skip_on_connection_issues()
+@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4")
+@pytest.mark.parametrize(
+    ["idf"],
+    [(False,), (True,)],
+)
+def test_bertscore_most_similar_separate_calls(idf: bool):
+    """Tests that BERTScore gives the highest score to self-similarity when each pair is scored in a separate call."""
+    short = "hello there"
+    long = "master kenobi"
+    longer = "general kenobi"
+
+    sentences = [short, long, longer]
+    preds, targets = zip(*product(sentences, sentences))
+    # Score each (pred, target) pair in its own call instead of one batched call.
+    score = {"f1": [bert_score([pred], [target], idf=idf, lang="en", rescale_with_baseline=False)["f1"].item()
+                    for pred, target in zip(preds, targets)]}
+    for i in range(len(preds)):
+        # Indices of the self-similarity pairs built from this pair's target and prediction.
+        max_pred = i % len(sentences) * (1 + len(sentences))
+        max_target = i // len(sentences) * (1 + len(sentences))
+        assert score["f1"][i] <= score["f1"][max_pred], \
+            f"pair: {preds[i], targets[i]} does not have a lower score than {preds[max_pred], targets[max_pred]}\n{i=}{max_pred=}"
+        assert score["f1"][i] <= score["f1"][max_target], \
+            f"pair: {preds[i], targets[i]} does not have a lower score than {preds[max_target], targets[max_target]}\n{i=}{max_target=}"
-
+
 
 @skip_on_connection_issues()
 @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4")
 @pytest.mark.parametrize(
@@ -251,7 +276,30 @@ def test_bertscore_symmetry(idf: bool, batch_size: int):
                     f"f1 score for {(preds[i], targets[i])} is not the same as {(preds[j], targets[j])}."
     pass
 
-
+@skip_on_connection_issues()
+@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4")
+@pytest.mark.parametrize(
+    ["idf"],
+    [(False,), (True,)],
+)
+def test_bertscore_symmetry_separate_calls(idf: bool):
+    """Tests that the BERTScore F1 score is symmetric in prediction and target when pairs are scored in separate calls."""
+    short = "hello there"
+    long = "master kenobi"
+    longer = "general kenobi"
+
+    sentences = [short, long, longer]
+    preds, targets = zip(*product(sentences, sentences))
+    # Score each (pred, target) pair in its own call instead of one batched call.
+    score = {"f1": [bert_score([pred], [target], idf=idf, lang="en", rescale_with_baseline=False)["f1"].item()
+                    for pred, target in zip(preds, targets)]}
+    for i in range(len(preds)):
+        for j in range(len(targets)):
+            if preds[i] == targets[j] and preds[j] == targets[i]:
+                # Swapping prediction and target must not change the F1 score.
+                assert score["f1"][i] == pytest.approx(score["f1"][j]), \
+                    f"f1 score for {(preds[i], targets[i])} is not the same as {(preds[j], targets[j])}."
+
 @skip_on_connection_issues()
 @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4")
 @pytest.mark.parametrize(