-
Notifications
You must be signed in to change notification settings - Fork 627
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #867 from guijiql/evaluator
FEA: add two new metrics
- Loading branch information
Showing
4 changed files
with
117 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ | |
# @email : [email protected] | ||
|
||
# UPDATE | ||
# @Time : 2020/08/12, 2021/6/25, 2020/9/16, 2021/7/2 | ||
# @Time : 2020/08/12, 2021/7/5, 2020/9/16, 2021/7/2 | ||
# @Author : Kaiyuan Li, Zhichao Feng, Xingyu Pan, Zihan Lin | ||
# @email : [email protected], [email protected], [email protected], [email protected] | ||
|
||
|
@@ -151,6 +151,7 @@ class NDCG(TopkMetric): | |
\mathrm {NDCG_u@K}=\frac{DCG_u@K}{IDCG_u@K}\\ | ||
\mathrm {NDCG@K}=\frac{\sum \nolimits_{u \in U^{te}NDCG_u@K}}{|U^{te}|} | ||
\end{gather} | ||
:math:`K` stands for recommending :math:`K` items. | ||
And the :math:`rel_i` is the relevance of the item in position :math:`i` in the recommendation list. | ||
:math:`{rel_i}` equals to 1 if the item is ground truth otherwise 0. | ||
|
@@ -387,6 +388,46 @@ def metric_info(self, preds, trues): | |
return loss / len(preds) | ||
|
||
|
||
class ItemCoverage(object): | ||
r"""It computes the coverage of recommended items over all items. | ||
For further details, please refer to the `paper <https://dl.acm.org/doi/10.1145/1864708.1864761>` and | ||
`paper <https://link.springer.com/article/10.1007/s13042-017-0762-9>`_ | ||
.. math:: | ||
\mathrm{Coverage}=\frac{\left| \bigcup_{u \in U} \hat{R}(u) \right|}{|I|} | ||
:math:`U` is total user set. | ||
:math:`R_{u}` is the recommended list of items for user u. | ||
:math:`I` is total item set. | ||
""" | ||
|
||
def __init__(self, config): | ||
self.topk = config['topk'] | ||
self.decimal_place = config['metric_decimal_place'] | ||
|
||
def used_info(self, dataobject): | ||
"""get the matrix of recommendation items and number of items in total item set""" | ||
item_matrix = dataobject.get('rec.items') | ||
num_items = dataobject.get('data.num_items') | ||
return item_matrix.numpy(), num_items | ||
|
||
def calculate_metric(self, dataobject): | ||
item_matrix, num_items = self.used_info(dataobject) | ||
metric_dict = {} | ||
for k in self.topk: | ||
key = '{}@{}'.format('itemcoverage', k) | ||
metric_dict[key] = round(self.get_coverage(item_matrix[:, :k], num_items), self.decimal_place) | ||
return metric_dict | ||
|
||
def get_coverage(self, item_matrix, num_items): | ||
"""get the coverage of recommended items over all items""" | ||
unique_count = np.unique(item_matrix).shape[0] | ||
return unique_count / num_items | ||
|
||
|
||
class AveragePopularity: | ||
r"""It computes the average popularity of recommended items. | ||
|
@@ -485,6 +526,50 @@ def get_entropy(self, item_matrix): | |
return result/len(item_count) | ||
|
||
|
||
class GiniIndex(object): | ||
r"""This metric present the diversity of the recommendation items. | ||
It is used to measure the inequality of a distribution. | ||
For further details, please refer to the `paper <https://dl.acm.org/doi/10.1145/3308560.3317303>` | ||
.. math:: | ||
\mathrm {GiniIndex}=\left(\frac{\sum_{i=1}^{n}(2 i-n-1) P_{(i)}}{n \sum_{i=1}^{n} P_{(i)}}\right) | ||
:math:`n` is the number of all items. | ||
:math:`P_{(i)}` is the number of each item in recommended list, | ||
which is indexed in non-decreasing order (P_{(i)} \leq P_{(i+1)}). | ||
""" | ||
|
||
def __init__(self, config): | ||
self.topk = config['topk'] | ||
self.decimal_place = config['metric_decimal_place'] | ||
|
||
def used_info(self, dataobject): | ||
"""get the matrix of recommendation items and number of items in total item set""" | ||
item_matrix = dataobject.get('rec.items') | ||
num_items = dataobject.get('data.num_items') | ||
return item_matrix.numpy(), num_items | ||
|
||
def calculate_metric(self, dataobject): | ||
item_matrix, num_items = self.used_info(dataobject) | ||
metric_dict = {} | ||
for k in self.topk: | ||
key = '{}@{}'.format('giniindex', k) | ||
metric_dict[key] = round(self.get_gini(item_matrix[:, :k], num_items), self.decimal_place) | ||
return metric_dict | ||
|
||
def get_gini(self, item_matrix, num_items): | ||
"""get gini index""" | ||
item_count = dict(Counter(item_matrix.flatten())) | ||
sorted_count = np.array(sorted(item_count.values())) | ||
num_recommended_items = sorted_count.shape[0] | ||
total_num = item_matrix.shape[0] * item_matrix.shape[1] | ||
idx = np.arange(num_items - num_recommended_items + 1, num_items + 1) | ||
gini_index = np.sum((2 * idx - num_items - 1) * sorted_count) / total_num | ||
gini_index /= num_items | ||
return gini_index | ||
|
||
|
||
metrics_dict = { | ||
'ndcg': NDCG, | ||
'hit': Hit, | ||
|
@@ -497,6 +582,8 @@ def get_entropy(self, item_matrix): | |
'logloss': LogLoss, | ||
'auc': AUC, | ||
'gauc': GAUC, | ||
'itemcoverage': ItemCoverage, | ||
'averagepopularity': AveragePopularity, | ||
'giniindex': GiniIndex, | ||
'shannonentropy': ShannonEntropy | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,11 @@ | |
# @Author : Zihan Lin | ||
# @Email : [email protected] | ||
|
||
# UPDATE | ||
# @Time : 2021/7/5 | ||
# @Author : Zhichao Feng | ||
# @email : [email protected] | ||
|
||
""" | ||
recbole.evaluator.register | ||
################################################ | ||
|
@@ -31,7 +36,7 @@ | |
'logloss': ['rec.score', 'data.label']} | ||
# These metrics are typical in top-k recommendations | ||
topk_metrics = {metric.lower(): metric for metric in ['Hit', 'Recall', 'MRR', 'Precision', 'NDCG', 'MAP', | ||
'AveragePopularity', 'ShannonEntropy']} | ||
'ItemCoverage', 'AveragePopularity', 'ShannonEntropy', 'GiniIndex']} | ||
# These metrics are typical in loss recommendations | ||
loss_metrics = {metric.lower(): metric for metric in ['AUC', 'RMSE', 'MAE', 'LOGLOSS']} | ||
# For GAUC | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,9 +4,9 @@ | |
# @email : [email protected] | ||
|
||
# UPDATE | ||
# @Time : 2021/7/2 | ||
# @Author : Zihan Lin | ||
# @email : [email protected] | ||
# @Time : 2021/7/2, 2021/7/5 | ||
# @Author : Zihan Lin, Zhichao Feng | ||
# @email : [email protected], [email protected] | ||
|
||
import os | ||
import sys | ||
|
@@ -39,6 +39,8 @@ | |
[5, 3, 7] | ||
]) | ||
|
||
num_items = 8 | ||
|
||
item_count = {1: 0, | ||
2: 1, | ||
3: 2, | ||
|
@@ -105,6 +107,13 @@ def test_precision(self): | |
np.array([[0, 0, 0], [1 / 1, 2 / 2, 3 / 3], [1 / 1, 1 / 2, 2 / 3], | ||
[0, 0, 1 / 3]]).tolist()) | ||
|
||
def test_itemcoverage(self): | ||
name = 'itemcoverage' | ||
Metric = metrics_dict[name](config) | ||
self.assertEqual( | ||
Metric.get_coverage(item_matrix, num_items), | ||
7 / 8) | ||
|
||
def test_averagepopularity(self): | ||
name = 'averagepopularity' | ||
Metric = metrics_dict[name](config) | ||
|
@@ -113,13 +122,21 @@ def test_averagepopularity(self): | |
np.array([[4/1, 4/2, 6/3], [3/1, 7/2, 8/3], [1/1, 3/2, 7/3], [0/1, 3/2, 8/3], | ||
[4/1, 6/2, 6/3]]).tolist()) | ||
|
||
def test_ShannonEntropy(self): | ||
def test_giniindex(self): | ||
name = 'giniindex' | ||
Metric = metrics_dict[name](config) | ||
self.assertEqual( | ||
Metric.get_gini(item_matrix, num_items), | ||
((-7) * 0 + (-5) * 1 + (-3) * 1 + (-1) * 2 + 1 * 2 + 3 * 2 + 5 * 3 + 7 * 4) | ||
/ (8 * (3 * 5))) | ||
|
||
def test_shannonentropy(self): | ||
name = 'shannonentropy' | ||
Metric = metrics_dict[name](config) | ||
self.assertEqual( | ||
Metric.get_entropy(item_matrix), | ||
-np.mean([1/15*np.log(1/15), 2/15*np.log(2/15), 3/15*np.log(3/15), 2/15*np.log(2/15), | ||
4/15*np.log(4/15), 1/15*np.log(1/15), 2/15*np.log(2/15)])) | ||
4/15*np.log(4/15), 1/15*np.log(1/15), 2/15*np.log(2/15)])) | ||
|
||
|
||
if __name__ == "__main__": | ||
|