Skip to content

Commit

Permalink
optimize ndcg
Browse files Browse the repository at this point in the history
donglihe-hub committed Jan 1, 2024

Verified

This commit was signed with the committer’s verified signature.
renovate-bot Mend Renovate
1 parent ec603a8 commit 198a58e
Showing 1 changed file with 24 additions and 7 deletions.
31 changes: 24 additions & 7 deletions libmultilabel/nn/metrics.py
Original file line number Diff line number Diff line change
@@ -46,6 +46,8 @@ class NDCG(Metric):
Please find the formal definition here:
https://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-ranked-retrieval-results-1.html
This is an optimized version of NDCG for multilabel classification. The target has to be a binary tensor.
Args:
top_k (int): the top k relevant labels to evaluate.
"""
@@ -60,18 +62,33 @@ class NDCG(Metric):
def __init__(self, top_k):
# Register the metric states for NDCG at cutoff `top_k`.
super().__init__()
self.top_k = top_k
# NOTE(review): the diff's +/- markers are missing in this view. The "ndcg" list state
# below appears to be the pre-commit implementation (one entry per sample, concatenated
# across processes) -- confirm against the repository before relying on it.
self.add_state("ndcg", default=[], dist_reduce_fx="cat")
# Running sum of per-sample NDCG values; a float64 scalar reduced with "sum".
self.add_state("score", default=torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum")
# Number of samples seen so far; an int64 scalar reduced with "sum".
self.add_state("num_sample", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum")
# Rank discounts 1 / log2(i + 2) for i = 0 .. top_k - 1, registered without a reduce function.
self.add_state("discount", default=1.0 / torch.log2(torch.arange(self.top_k) + 2.0))

def update(self, preds, target):
# Fold one batch into the metric state; preds and target must have identical shapes.
assert preds.shape == target.shape
# NOTE(review): the diff's +/- markers are missing in this view. The next two lines (the
# reminder comment and the per-sample list append through self._metric) appear to be the
# pre-commit implementation -- confirm against the repository.
# implement batch-wise calculations instead of storing results of all batches
self.ndcg += [self._metric(p, t) for p, t in zip(preds, target)]
# Batch-wise path: per-row DCG and ideal DCG, then add the sum of their ratios to the
# running score.
dcg = self._dcg(preds, target)
idcg = self._idcg(target)
self.score += (dcg / idcg).sum()
# The first dimension of preds is counted as the number of samples in this batch.
self.num_sample += preds.shape[0]

def compute(self):
# Return the NDCG averaged over all samples accumulated so far.
# NOTE(review): the diff's +/- markers are missing in this view. The stacked-list return
# and the `_metric` helper below appear to be the pre-commit implementation, with the
# `score / num_sample` lines being the replacement body of compute() -- confirm against
# the repository.
return torch.stack(self.ndcg).mean()

def _metric(self, preds, target):
# Delegates a single sample to retrieval_normalized_dcg with cutoff k=self.top_k.
return retrieval_normalized_dcg(preds, target, k=self.top_k)
# Mean NDCG: accumulated score divided by the accumulated sample count.
score = self.score / self.num_sample
return score

def _dcg(self, preds, target):
    """Compute the discounted cumulative gain at ``top_k`` for each row.

    Args:
        preds: 2-D tensor of scores; within a row, higher scores rank earlier.
        target: relevance tensor with the same shape as ``preds``
            (binary, per the class docstring).

    Returns:
        1-D tensor holding one DCG value per row.
    """
    # Indices of the (at most) top_k highest-scoring labels in each row.
    sorted_top_k_idx = preds.argsort(descending=True)[:, : self.top_k]
    gains = target.take_along_dim(sorted_top_k_idx, 1)
    # When there are fewer labels than top_k the slice above yields fewer
    # columns than the discount has entries; trim the discount so the product
    # neither raises nor silently broadcasts (the single-label case).
    discount = self.discount[: gains.shape[1]]
    return (gains * discount).sum(1)

def _idcg(self, target):
    """Compute the ideal DCG at ``top_k`` for each row of a binary target.

    With binary relevance the ideal ranking places every positive label first,
    so the ideal DCG of a row with ``n`` positives is the sum of the first
    ``min(n, top_k)`` discounts, read from a cumulative-sum table.

    Args:
        target: 2-D binary relevance tensor (bool, integer or float dtype).

    Returns:
        1-D tensor holding one ideal DCG value per row. Rows without any
        positive label get the first table entry, so dividing a zero DCG by
        it yields 0 rather than NaN.
    """
    cum_discount = self.discount.cumsum(0)
    # Cast to long so float targets still produce a valid index tensor.
    num_positive = target.sum(1).long()
    # Clamp to [1, top_k]: the upper bound caps at the cutoff, the lower bound
    # avoids relying on negative-index wraparound for rows with no positives.
    idx = num_positive.clamp(min=1, max=self.top_k) - 1
    return cum_discount[idx]


class RPrecision(Metric):

0 comments on commit 198a58e

Please sign in to comment.