Skip to content

Commit

Permalink
fix(metrics): fixed NDCG calculation and updated previous tests
Browse files Browse the repository at this point in the history
Signed-off-by: Alexey Rodriguez Yakushev <[email protected]>
  • Loading branch information
alexeyrodriguez committed Dec 10, 2024
1 parent 4ddcf78 commit 9a57dda
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -375,13 +375,13 @@ def compute(
discounted_gain(rel=docid in expected_set, i=i, mode=mode)
for i, docid in enumerate(retrieved_ids, start=1)
)

idcg = sum(
discounted_gain(rel=True, i=i, mode=mode)
for i in range(1, len(retrieved_ids) + 1)
for i in range(1, len(expected_ids) + 1)
)

ndcg_score = dcg / idcg

return RetrievalMetricResult(score=ndcg_score)


Expand Down
15 changes: 7 additions & 8 deletions llama-index-core/tests/evaluation/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,14 @@ def test_ap(expected_ids, retrieved_ids, expected_result):
["id3", "id1", "id2", "id4"],
"linear",
(1 / log2(1 + 1) + 1 / log2(2 + 1) + 1 / log2(3 + 1))
/ (1 / log2(1 + 1) + 1 / log2(2 + 1) + 1 / log2(3 + 1) + 1 / log2(4 + 1)),
/ (1 / log2(1 + 1) + 1 / log2(2 + 1) + 1 / log2(3 + 1)),
),
(
["id1", "id2", "id3", "id4"],
["id5", "id1"],
"linear",
(1 / log2(2 + 1)) / (1 / log2(1 + 1) + 1 / log2(2 + 1)),
(1 / log2(2 + 1))
/ (1 / log2(1 + 1) + 1 / log2(2 + 1) + 1 / log2(3 + 1) + 1 / log2(4 + 1)),
),
(
["id1", "id2"],
Expand All @@ -170,29 +171,27 @@ def test_ap(expected_ids, retrieved_ids, expected_result):
["id1", "id2"],
["id2", "id1", "id7"],
"linear",
(1 / log2(1 + 1) + 1 / log2(2 + 1))
/ (1 / log2(1 + 1) + 1 / log2(2 + 1) + 1 / log2(3 + 1)),
(1 / log2(1 + 1) + 1 / log2(2 + 1)) / (1 / log2(1 + 1) + 1 / log2(2 + 1)),
),
(
["id1", "id2", "id3"],
["id3", "id1", "id2", "id4"],
"exponential",
(1 / log2(1 + 1) + 1 / log2(2 + 1) + 1 / log2(3 + 1))
/ (1 / log2(1 + 1) + 1 / log2(2 + 1) + 1 / log2(3 + 1) + 1 / log2(4 + 1)),
/ (1 / log2(1 + 1) + 1 / log2(2 + 1) + 1 / log2(3 + 1)),
),
(
["id1", "id2", "id3", "id4"],
["id1", "id2", "id5"],
"exponential",
(1 / log2(1 + 1) + 1 / log2(2 + 1))
/ (1 / log2(1 + 1) + 1 / log2(2 + 1) + 1 / log2(3 + 1)),
/ (1 / log2(1 + 1) + 1 / log2(2 + 1) + 1 / log2(3 + 1) + 1 / log2(4 + 1)),
),
(
["id1", "id2"],
["id1", "id7", "id15", "id2"],
"exponential",
(1 / log2(1 + 1) + 1 / log2(4 + 1))
/ (1 / log2(1 + 1) + 1 / log2(2 + 1) + 1 / log2(3 + 1) + 1 / log2(4 + 1)),
(1 / log2(1 + 1) + 1 / log2(4 + 1)) / (1 / log2(1 + 1) + 1 / log2(2 + 1)),
),
],
)
Expand Down

0 comments on commit 9a57dda

Please sign in to comment.