Skip to content

Commit

Permalink
Merge pull request #22 from RUCAIBox/master
Browse files Browse the repository at this point in the history
Update
  • Loading branch information
hyp1231 authored Jul 23, 2020
2 parents d1067f2 + ea951b3 commit 957d180
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
# Map lowercase metric keys to their canonical display names.
metric_name = {metric.lower() : metric for metric in ['Hit', 'Recall', 'MRR', 'AUC', 'Precision', 'NDCG']}

# These metrics are typical in topk recommendations
# (evaluated against a top-k cutoff of the ranked list).
topk_metric = {'hit', 'recall', 'precision', 'ndcg', 'mrr'}
# Metrics evaluated over the full ranking, with no top-k cutoff.
other_metric = {'auc'}

class AbstractEvaluator(metaclass=abc.ABCMeta):
"""The abstract class of the evaluation module, its subclasses must implement their functions
Expand Down
7 changes: 4 additions & 3 deletions utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,17 @@ def hit(rank, label, k):
"""
return int(any(rank[label] <= k))

def mrr(rank, label, k):
    """The MRR (also known as mean reciprocal rank) is a statistic measure for evaluating any process that produces a list
    of possible responses to a sample of queries, ordered by probability of correctness.
    url:https://en.wikipedia.org/wiki/Mean_reciprocal_rank

    Args:
        rank: array of rank positions for all candidate items
            (presumably 1-based ranks — TODO confirm against the ranker).
        label: index array (or boolean mask) selecting the ground-truth
            items within ``rank``.
        k: top-k cutoff; only ground-truth items ranked within the
            top-k contribute to the score.

    Returns:
        The reciprocal of the best (smallest) ground-truth rank inside
        the top-k, or 0 when no ground-truth item appears in the top-k.
    """
    ground_truth_ranks = rank[label]
    # Keep only the ground-truth items that made it into the top-k.
    ground_truth_at_k = ground_truth_ranks[ground_truth_ranks <= k]
    if ground_truth_at_k.shape[0] > 0:
        # Reciprocal of the highest-placed (i.e. minimum) rank.
        return (1 / ground_truth_at_k.min())
    return 0

def recall(rank, label, k):
Expand Down

0 comments on commit 957d180

Please sign in to comment.