diff --git a/recommenders/evaluation/python_evaluation.py b/recommenders/evaluation/python_evaluation.py
index a762fa10bd..3e972bc7e1 100644
--- a/recommenders/evaluation/python_evaluation.py
+++ b/recommenders/evaluation/python_evaluation.py
@@ -420,11 +420,11 @@ def precision_at_k(
     rating_pred,
     col_user=DEFAULT_USER_COL,
     col_item=DEFAULT_ITEM_COL,
-    col_rating=DEFAULT_RATING_COL,
     col_prediction=DEFAULT_PREDICTION_COL,
     relevancy_method="top_k",
     k=DEFAULT_K,
     threshold=DEFAULT_THRESHOLD,
+    **kwargs
 ):
     """Precision at K.
 
@@ -450,7 +450,7 @@ def precision_at_k(
     Returns:
         float: precision at k (min=0, max=1)
     """
-
+    col_rating = _get_rating_column(relevancy_method, **kwargs)
     df_hit, df_hit_count, n_users = merge_ranking_true_pred(
         rating_true=rating_true,
         rating_pred=rating_pred,
@@ -474,11 +474,11 @@ def recall_at_k(
     rating_pred,
     col_user=DEFAULT_USER_COL,
     col_item=DEFAULT_ITEM_COL,
-    col_rating=DEFAULT_RATING_COL,
     col_prediction=DEFAULT_PREDICTION_COL,
     relevancy_method="top_k",
     k=DEFAULT_K,
     threshold=DEFAULT_THRESHOLD,
+    **kwargs
 ):
     """Recall at K.
 
@@ -498,7 +498,7 @@ def recall_at_k(
         float: recall at k (min=0, max=1). The maximum value is 1
             even when fewer than k items exist for a user in rating_true.
     """
-
+    col_rating = _get_rating_column(relevancy_method, **kwargs)
     df_hit, df_hit_count, n_users = merge_ranking_true_pred(
         rating_true=rating_true,
         rating_pred=rating_pred,
@@ -522,13 +522,13 @@ def ndcg_at_k(
     rating_pred,
     col_user=DEFAULT_USER_COL,
     col_item=DEFAULT_ITEM_COL,
-    col_rating=DEFAULT_RATING_COL,
     col_prediction=DEFAULT_PREDICTION_COL,
     relevancy_method="top_k",
     k=DEFAULT_K,
     threshold=DEFAULT_THRESHOLD,
     score_type="binary",
     discfun_type="loge",
+    **kwargs
 ):
     """Normalized Discounted Cumulative Gain (nDCG).
 
@@ -553,7 +553,7 @@ def ndcg_at_k(
     Returns:
         float: nDCG at k (min=0, max=1).
     """
-
+    col_rating = _get_rating_column(relevancy_method, **kwargs)
     df_hit, _, _ = merge_ranking_true_pred(
         rating_true=rating_true,
         rating_pred=rating_pred,
@@ -621,11 +621,11 @@ def map_at_k(
     rating_pred,
     col_user=DEFAULT_USER_COL,
     col_item=DEFAULT_ITEM_COL,
-    col_rating=DEFAULT_RATING_COL,
     col_prediction=DEFAULT_PREDICTION_COL,
     relevancy_method="top_k",
     k=DEFAULT_K,
     threshold=DEFAULT_THRESHOLD,
+    **kwargs
 ):
     """Mean Average Precision at k
 
@@ -657,7 +657,7 @@ def map_at_k(
     Returns:
         float: MAP at k (min=0, max=1).
     """
-
+    col_rating = _get_rating_column(relevancy_method, **kwargs)
     df_hit, df_hit_count, n_users = merge_ranking_true_pred(
         rating_true=rating_true,
         rating_pred=rating_pred,
@@ -736,6 +736,26 @@ def get_top_k_items(
     }
 
 
+def _get_rating_column(relevancy_method: str, **kwargs) -> str:
+    r"""Helper utility to simplify the arguments of eval metrics.
+    Attempts to address https://github.com/microsoft/recommenders/issues/1737.
+
+    Args:
+        relevancy_method (str): method for determining relevancy ['top_k', 'by_threshold', None]. None means that the
+            top k items are directly provided, so there is no need to compute the relevancy operation.
+
+    Returns:
+        str: rating column name.
+    """
+    if relevancy_method != "top_k":
+        if "col_rating" not in kwargs:
+            raise ValueError("Expected an argument `col_rating` but it was not found.")
+        col_rating = kwargs.get("col_rating")
+    else:
+        col_rating = kwargs.get("col_rating", DEFAULT_RATING_COL)
+    return col_rating
+
+
 # diversity metrics
 def _check_column_dtypes_diversity_serendipity(func):
     """Checks columns of DataFrame inputs
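
To make the new contract concrete, here is a minimal, self-contained sketch of how `_get_rating_column` resolves the rating column. The helper body mirrors the diff above; `DEFAULT_RATING_COL` is stubbed to `"rating"` for illustration (in the library it comes from the constants module), so treat the literal value as an assumption:

```python
# Self-contained sketch of the helper's contract; DEFAULT_RATING_COL is
# stubbed here rather than imported from the library.
DEFAULT_RATING_COL = "rating"


def _get_rating_column(relevancy_method: str, **kwargs) -> str:
    # For any relevancy method other than "top_k", the caller must supply
    # col_rating explicitly; there is no silent default.
    if relevancy_method != "top_k":
        if "col_rating" not in kwargs:
            raise ValueError("Expected an argument `col_rating` but it was not found.")
        return kwargs["col_rating"]
    # For "top_k", fall back to the library default when not provided.
    return kwargs.get("col_rating", DEFAULT_RATING_COL)


assert _get_rating_column("top_k") == "rating"                    # default fallback
assert _get_rating_column("top_k", col_rating="stars") == "stars"  # override wins
assert _get_rating_column("by_threshold", col_rating="stars") == "stars"
try:
    _get_rating_column("by_threshold")                             # col_rating omitted
except ValueError as err:
    print(err)  # Expected an argument `col_rating` but it was not found.
```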
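At the call sites, the practical effect is that `col_rating` moves from an explicit keyword parameter with a default to a pass-through kwarg. A hedged usage sketch, assuming this diff is applied and assuming the library's default column names (`userID`, `itemID`, `rating`, `prediction`) plus toy DataFrames invented for illustration:

```python
import pandas as pd

from recommenders.evaluation.python_evaluation import precision_at_k

# Toy data; column names assume the library defaults.
rating_true = pd.DataFrame(
    {"userID": [1, 1, 2], "itemID": [10, 11, 10], "rating": [5.0, 4.0, 3.0]}
)
rating_pred = pd.DataFrame(
    {"userID": [1, 1, 2], "itemID": [10, 12, 10], "prediction": [0.9, 0.7, 0.8]}
)

# With relevancy_method="top_k" (the default), col_rating may be omitted;
# _get_rating_column falls back to DEFAULT_RATING_COL.
print(precision_at_k(rating_true, rating_pred, k=2))

# With any other relevancy method, col_rating must now be passed explicitly,
# otherwise the helper raises ValueError.
print(
    precision_at_k(
        rating_true,
        rating_pred,
        relevancy_method="by_threshold",
        col_rating="rating",
        threshold=0.5,
        k=2,
    )
)
```

One caveat worth noting on the design: callers that pass arguments by keyword are unaffected, but any caller that passed `col_rating` positionally would silently bind it to `col_prediction` now that the parameter is gone from the signature.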