diff --git a/README.md b/README.md
index f5419bf26b..782b50d283 100644
--- a/README.md
+++ b/README.md
@@ -110,7 +110,7 @@ We provide a [benchmark notebook](examples/06_benchmarks/movielens.ipynb) to ill
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| [ALS](examples/00_quick_start/als_movielens.ipynb) | 0.004732 | 0.044239 | 0.048462 | 0.017796 | 0.965038 | 0.753001 | 0.255647 | 0.251648 |
| [SVD](examples/02_model_collaborative_filtering/surprise_svd_deep_dive.ipynb) | 0.012873 | 0.095930 | 0.091198 | 0.032783 | 0.938681 | 0.742690 | 0.291967 | 0.291971 |
-| [SAR](examples/00_quick_start/sar_movielens.ipynb) | 0.113028 | 0.388321 | 0.333828 | 0.183179 | N/A | N/A | N/A | N/A |
+| [SAR](examples/00_quick_start/sar_movielens.ipynb) | 0.110591 | 0.382461 | 0.330753 | 0.176385 | 1.253805 | 1.048484 | -0.569363 | 0.030474 |
| [NCF](examples/02_model_hybrid/ncf_deep_dive.ipynb) | 0.107720 | 0.396118 | 0.347296 | 0.180775 | N/A | N/A | N/A | N/A |
| [BPR](examples/02_model_collaborative_filtering/cornac_bpr_deep_dive.ipynb) | 0.105365 | 0.389948 | 0.349841 | 0.181807 | N/A | N/A | N/A | N/A |
| [FastAI](examples/00_quick_start/fastai_movielens.ipynb) | 0.025503 | 0.147866 | 0.130329 | 0.053824 | 0.943084 | 0.744337 | 0.285308 | 0.287671 |
diff --git a/examples/00_quick_start/sar_movielens.ipynb b/examples/00_quick_start/sar_movielens.ipynb
index 14b5305fed..580cd99950 100644
--- a/examples/00_quick_start/sar_movielens.ipynb
+++ b/examples/00_quick_start/sar_movielens.ipynb
@@ -41,20 +41,23 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "System version: 3.7.3 | packaged by conda-forge | (default, Jul 1 2019, 21:52:21) \n",
- "[GCC 7.3.0]\n",
- "Pandas version: 0.23.4\n"
+ "System version: 3.6.10 |Anaconda, Inc.| (default, May 7 2020, 23:06:31) \n",
+ "[GCC 4.2.1 Compatible Clang 4.0.1 (tags/RELEASE_401/final)]\n",
+ "Pandas version: 0.25.3\n"
]
}
],
"source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
"# set the environment path to find Recommenders\n",
"import sys\n",
"sys.path.append(\"../../\")\n",
@@ -63,11 +66,23 @@
"import numpy as np\n",
"import pandas as pd\n",
"import papermill as pm\n",
+ "from sklearn.preprocessing import minmax_scale\n",
"\n",
+ "from reco_utils.common.python_utils import binarize\n",
"from reco_utils.common.timer import Timer\n",
"from reco_utils.dataset import movielens\n",
"from reco_utils.dataset.python_splitters import python_stratified_split\n",
- "from reco_utils.evaluation.python_evaluation import map_at_k, ndcg_at_k, precision_at_k, recall_at_k\n",
+ "from reco_utils.evaluation.python_evaluation import (\n",
+ " map_at_k,\n",
+ " ndcg_at_k,\n",
+ " precision_at_k,\n",
+ " recall_at_k,\n",
+ " rmse,\n",
+ " mae,\n",
+ " logloss,\n",
+ " rsquared,\n",
+ " exp_var\n",
+ ")\n",
"from reco_utils.recommender.sar import SAR\n",
"\n",
"print(\"System version: {}\".format(sys.version))\n",
@@ -90,7 +105,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 2,
"metadata": {
"tags": [
"parameters"
@@ -114,93 +129,22 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "100%|██████████| 4.81k/4.81k [00:02<00:00, 1.90kKB/s]\n"
+ "100%|██████████| 4.81k/4.81k [00:02<00:00, 1.67kKB/s]\n"
]
},
{
"data": {
- "text/html": [
- "
\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " userID | \n",
- " itemID | \n",
- " rating | \n",
- " timestamp | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " 196 | \n",
- " 242 | \n",
- " 3.0 | \n",
- " 881250949 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " 186 | \n",
- " 302 | \n",
- " 3.0 | \n",
- " 891717742 | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " 22 | \n",
- " 377 | \n",
- " 1.0 | \n",
- " 878887116 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 244 | \n",
- " 51 | \n",
- " 2.0 | \n",
- " 880606923 | \n",
- "
\n",
- " \n",
- " 4 | \n",
- " 166 | \n",
- " 346 | \n",
- " 1.0 | \n",
- " 886397596 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " userID itemID rating timestamp\n",
- "0 196 242 3.0 881250949\n",
- "1 186 302 3.0 891717742\n",
- "2 22 377 1.0 878887116\n",
- "3 244 51 2.0 880606923\n",
- "4 166 346 1.0 886397596"
- ]
+ "text/plain": " userID itemID rating timestamp\n0 196 242 3.0 881250949\n1 186 302 3.0 891717742\n2 22 377 1.0 878887116\n3 244 51 2.0 880606923\n4 166 346 1.0 886397596",
+ "text/html": "\n\n
\n \n \n | \n userID | \n itemID | \n rating | \n timestamp | \n
\n \n \n \n 0 | \n 196 | \n 242 | \n 3.0 | \n 881250949 | \n
\n \n 1 | \n 186 | \n 302 | \n 3.0 | \n 891717742 | \n
\n \n 2 | \n 22 | \n 377 | \n 1.0 | \n 878887116 | \n
\n \n 3 | \n 244 | \n 51 | \n 2.0 | \n 880606923 | \n
\n \n 4 | \n 166 | \n 346 | \n 1.0 | \n 886397596 | \n
\n \n
\n
"
},
- "execution_count": 4,
+ "execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@@ -227,7 +171,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -236,7 +180,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -298,7 +242,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -312,7 +256,8 @@
" col_timestamp=\"timestamp\",\n",
" similarity_type=\"jaccard\", \n",
" time_decay_coefficient=30, \n",
- " timedecay_formula=True\n",
+ " timedecay_formula=True,\n",
+ " normalize=True\n",
")"
]
},
@@ -333,14 +278,29 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 7,
"metadata": {},
"outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2020-08-18 11:31:24,687 INFO Collecting user affinity matrix\n",
+ "2020-08-18 11:31:24,690 INFO Calculating time-decayed affinities\n",
+ "2020-08-18 11:31:24,722 INFO Creating index columns\n",
+ "2020-08-18 11:31:24,822 INFO Calculating normalization factors\n",
+ "2020-08-18 11:31:24,862 INFO Building user affinity sparse matrix\n",
+ "2020-08-18 11:31:24,868 INFO Calculating item co-occurrence\n",
+ "2020-08-18 11:31:25,063 INFO Calculating item similarity\n",
+ "2020-08-18 11:31:25,063 INFO Using jaccard based similarity\n",
+ "2020-08-18 11:31:25,158 INFO Done training\n"
+ ]
+ },
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Took 0.3302565817721188 seconds for training.\n"
+ "Took 0.4742812399927061 seconds for training.\n"
]
}
],
@@ -353,14 +313,22 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 8,
"metadata": {},
"outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2020-08-18 11:31:25,202 INFO Calculating recommendation scores\n",
+ "2020-08-18 11:31:25,389 INFO Removing seen items\n"
+ ]
+ },
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Took 0.21034361701458693 seconds for prediction.\n"
+ "Took 0.22830537598929368 seconds for prediction.\n"
]
}
],
@@ -373,82 +341,17 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 9,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " userID | \n",
- " itemID | \n",
- " prediction | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " 1 | \n",
- " 204 | \n",
- " 3.313306 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " 1 | \n",
- " 89 | \n",
- " 3.280465 | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " 1 | \n",
- " 11 | \n",
- " 3.233867 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 1 | \n",
- " 367 | \n",
- " 3.192575 | \n",
- "
\n",
- " \n",
- " 4 | \n",
- " 1 | \n",
- " 423 | \n",
- " 3.131517 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " userID itemID prediction\n",
- "0 1 204 3.313306\n",
- "1 1 89 3.280465\n",
- "2 1 11 3.233867\n",
- "3 1 367 3.192575\n",
- "4 1 423 3.131517"
- ]
+ "text/plain": " userID itemID prediction\n0 1 204 3.231405\n1 1 89 3.199445\n2 1 11 3.154097\n3 1 367 3.113913\n4 1 423 3.054493",
+ "text/html": "\n\n
\n \n \n | \n userID | \n itemID | \n prediction | \n
\n \n \n \n 0 | \n 1 | \n 204 | \n 3.231405 | \n
\n \n 1 | \n 1 | \n 89 | \n 3.199445 | \n
\n \n 2 | \n 1 | \n 11 | \n 3.154097 | \n
\n \n 3 | \n 1 | \n 367 | \n 3.113913 | \n
\n \n 4 | \n 1 | \n 423 | \n 3.054493 | \n
\n \n
\n
"
},
- "execution_count": 10,
+ "execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -468,7 +371,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -477,7 +380,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -486,7 +389,7 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -495,17 +398,95 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"eval_recall = recall_at_k(test, top_k, col_user='userID', col_item='itemID', col_rating='rating', k=TOP_K)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "eval_rmse = rmse(test, top_k, col_user='userID', col_item='itemID', col_rating='rating')"
+ ]
+ },
{
"cell_type": "code",
"execution_count": 15,
- "metadata": {},
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "eval_mae = mae(test, top_k, col_user='userID', col_item='itemID', col_rating='rating')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "eval_rsquared = rsquared(test, top_k, col_user='userID', col_item='itemID', col_rating='rating')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "eval_exp_var = exp_var(test, top_k, col_user='userID', col_item='itemID', col_rating='rating')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "positivity_threshold = 2\n",
+ "test_bin = test.copy()\n",
+ "test_bin['rating'] = binarize(test_bin['rating'], positivity_threshold)\n",
+ "\n",
+ "top_k_prob = top_k.copy()\n",
+ "top_k_prob['prediction'] = minmax_scale(\n",
+ " top_k_prob['prediction'].astype(float)\n",
+ ")\n",
+ "\n",
+ "eval_logloss = logloss(test_bin, top_k_prob, col_user='userID', col_item='itemID', col_rating='rating')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
"outputs": [
{
"name": "stdout",
@@ -516,7 +497,12 @@
"MAP:\t0.110591\n",
"NDCG:\t0.382461\n",
"Precision@K:\t0.330753\n",
- "Recall@K:\t0.176385\n"
+ "Recall@K:\t0.176385\n",
+ "RMSE:\t1.253805\n",
+ "MAE:\t1.048484\n",
+ "R2:\t-0.569363\n",
+ "Exp var:\t0.030474\n",
+ "Logloss:\t0.542861\n"
]
}
],
@@ -526,97 +512,38 @@
" \"MAP:\\t%f\" % eval_map,\n",
" \"NDCG:\\t%f\" % eval_ndcg,\n",
" \"Precision@K:\\t%f\" % eval_precision,\n",
- " \"Recall@K:\\t%f\" % eval_recall, sep='\\n')"
+ " \"Recall@K:\\t%f\" % eval_recall,\n",
+ " \"RMSE:\\t%f\" % eval_rmse,\n",
+ " \"MAE:\\t%f\" % eval_mae,\n",
+ " \"R2:\\t%f\" % eval_rsquared,\n",
+ " \"Exp var:\\t%f\" % eval_exp_var,\n",
+ " \"Logloss:\\t%f\" % eval_logloss,\n",
+ " sep='\\n')"
]
},
{
"cell_type": "code",
- "execution_count": 16,
- "metadata": {},
+ "execution_count": 23,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
"outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2020-08-18 11:33:07,631 INFO Calculating recommendation scores\n",
+ "2020-08-18 11:33:07,643 INFO Removing seen items\n"
+ ]
+ },
{
"data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " userID | \n",
- " itemID | \n",
- " rating | \n",
- " timestamp | \n",
- " prediction | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " 876 | \n",
- " 523 | \n",
- " 5.0 | \n",
- " 879428378 | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " 876 | \n",
- " 529 | \n",
- " 4.0 | \n",
- " 879428451 | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " 876 | \n",
- " 174 | \n",
- " 4.0 | \n",
- " 879428378 | \n",
- " 0.353567 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 876 | \n",
- " 276 | \n",
- " 4.0 | \n",
- " 879428354 | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 4 | \n",
- " 876 | \n",
- " 288 | \n",
- " 3.0 | \n",
- " 879428101 | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " userID itemID rating timestamp prediction\n",
- "0 876 523 5.0 879428378 NaN\n",
- "1 876 529 4.0 879428451 NaN\n",
- "2 876 174 4.0 879428378 0.353567\n",
- "3 876 276 4.0 879428354 NaN\n",
- "4 876 288 3.0 879428101 NaN"
- ]
+ "text/plain": " userID itemID rating timestamp prediction\n0 876 523 5.0 879428378 NaN\n1 876 529 4.0 879428451 NaN\n2 876 174 4.0 879428378 3.702239\n3 876 276 4.0 879428354 NaN\n4 876 288 3.0 879428101 NaN",
+ "text/html": "\n\n
\n \n \n | \n userID | \n itemID | \n rating | \n timestamp | \n prediction | \n
\n \n \n \n 0 | \n 876 | \n 523 | \n 5.0 | \n 879428378 | \n NaN | \n
\n \n 1 | \n 876 | \n 529 | \n 4.0 | \n 879428451 | \n NaN | \n
\n \n 2 | \n 876 | \n 174 | \n 4.0 | \n 879428378 | \n 3.702239 | \n
\n \n 3 | \n 876 | \n 276 | \n 4.0 | \n 879428354 | \n NaN | \n
\n \n 4 | \n 876 | \n 288 | \n 3.0 | \n 879428101 | \n NaN | \n
\n \n
\n
"
},
- "execution_count": 16,
+ "execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
@@ -639,64 +566,13 @@
},
{
"cell_type": "code",
- "execution_count": 86,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "application/papermill.record+json": {
- "map": 0.11059057578638949
- }
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/papermill.record+json": {
- "ndcg": 0.3824612290501957
- }
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/papermill.record+json": {
- "precision": 0.33075291622481445
- }
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/papermill.record+json": {
- "recall": 0.1763854474342893
- }
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/papermill.record+json": {
- "train_time": 0.284792423248291
- }
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/papermill.record+json": {
- "test_time": 0.1463017463684082
- }
- },
- "metadata": {},
- "output_type": "display_data"
+ "execution_count": null,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
}
- ],
+ },
+ "outputs": [],
"source": [
"# Record results with papermill for tests - ignore this cell\n",
"pm.record(\"map\", eval_map)\n",
@@ -711,9 +587,9 @@
"metadata": {
"celltoolbar": "Tags",
"kernelspec": {
- "display_name": "Python 3",
+ "name": "python3",
"language": "python",
- "name": "python3"
+ "display_name": "Python 3"
},
"language_info": {
"codemirror_mode": {
@@ -725,9 +601,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.7.3"
+ "version": "3.6.10"
}
},
"nbformat": 4,
- "nbformat_minor": 2
-}
+ "nbformat_minor": 4
+}
\ No newline at end of file
diff --git a/reco_utils/__init__.py b/reco_utils/__init__.py
index df052a7fad..3103749f97 100644
--- a/reco_utils/__init__.py
+++ b/reco_utils/__init__.py
@@ -2,7 +2,7 @@
# Licensed under the MIT License.
__title__ = "Microsoft Recommenders"
-__version__ = "2019.9"
+__version__ = "2020.8"
__author__ = "RecoDev Team at Microsoft"
__license__ = "MIT"
__copyright__ = "Copyright 2018-present Microsoft Corporation"
diff --git a/reco_utils/common/python_utils.py b/reco_utils/common/python_utils.py
index a365618ea3..3f10d0d409 100644
--- a/reco_utils/common/python_utils.py
+++ b/reco_utils/common/python_utils.py
@@ -116,3 +116,26 @@ def binarize(a, threshold):
0.0
)
+
+def rescale(data, new_min=0, new_max=1, data_min=None, data_max=None):
+ """
+ Rescale/normalize the data to be within the range [new_min, new_max].
+ If data_min and data_max are explicitly provided, they will be used
+ as the old min/max values instead of being taken from the data.
+
+ Note: this is similar to sklearn.preprocessing.MinMaxScaler, except that we can override
+ the min/max of the old scale.
+
+ Args:
+ data (np.array): 1d scores vector or 2d score matrix (users x items).
+ new_min (int|float): The minimum of the newly scaled data.
+ new_max (int|float): The maximum of the newly scaled data.
+ data_min (None|number): The minimum of the passed data [if omitted it will be inferred].
+ data_max (None|number): The maximum of the passed data [if omitted it will be inferred].
+
+ Returns:
+ np.array: The newly scaled/normalized data.
+ """
+ data_min = data.min() if data_min is None else data_min
+ data_max = data.max() if data_max is None else data_max
+ return (data - data_min) / (data_max - data_min) * (new_max - new_min) + new_min
diff --git a/reco_utils/recommender/sar/sar_singlenode.py b/reco_utils/recommender/sar/sar_singlenode.py
index 46685a3d3f..298c990eda 100644
--- a/reco_utils/recommender/sar/sar_singlenode.py
+++ b/reco_utils/recommender/sar/sar_singlenode.py
@@ -12,6 +12,7 @@
lift,
exponential_decay,
get_top_k_scored_items,
+ rescale,
)
from reco_utils.common import constants
@@ -89,7 +90,6 @@ def __init__(
# set flag to capture unity-rating user-affinity matrix for scaling scores
self.normalize = normalize
self.col_unity_rating = "_unity_rating"
- self.unity_user_affinity = None
# column for mapping user / item ids to internal indices
self.col_item_id = "_indexed_items"
@@ -99,6 +99,10 @@ def __init__(
self.n_users = None
self.n_items = None
+ # The min and max of the rating scale, obtained from the training data.
+ self.rating_min = None
+ self.rating_max = None
+
# mapping for item to matrix element
self.user2index = None
self.item2index = None
@@ -239,6 +243,8 @@ def fit(self, df):
)
if self.normalize:
+ self.rating_min = temp_df[self.col_rating].min()
+ self.rating_max = temp_df[self.col_rating].max()
logger.info("Calculating normalization factors")
temp_df[self.col_unity_rating] = 1.0
if self.time_decay_flag:
@@ -286,13 +292,12 @@ def fit(self, df):
logger.info("Done training")
- def score(self, test, remove_seen=False, normalize=False):
+ def score(self, test, remove_seen=False):
"""Score all items for test users.
Args:
test (pd.DataFrame): user to test
remove_seen (bool): flag to remove items seen in training from recommendation
- normalize (bool): flag to normalize scores to be in the same scale as the original ratings
Returns:
np.ndarray: Value of interest of all items for the users.
@@ -316,25 +321,29 @@ def score(self, test, remove_seen=False, normalize=False):
if isinstance(test_scores, sparse.spmatrix):
test_scores = test_scores.toarray()
+ if self.normalize:
+ counts = self.unity_user_affinity[user_ids, :].dot(self.item_similarity)
+ user_min_scores = (
+ np.tile(counts.min(axis=1)[:, np.newaxis], test_scores.shape[1])
+ * self.rating_min
+ )
+ user_max_scores = (
+ np.tile(counts.max(axis=1)[:, np.newaxis], test_scores.shape[1])
+ * self.rating_max
+ )
+ test_scores = rescale(
+ test_scores,
+ self.rating_min,
+ self.rating_max,
+ user_min_scores,
+ user_max_scores,
+ )
+
# remove items in the train set so recommended items are always novel
if remove_seen:
logger.info("Removing seen items")
test_scores += self.user_affinity[user_ids, :] * -np.inf
- if normalize:
- if self.unity_user_affinity is None:
- raise ValueError(
- "Cannot use normalize flag during scoring if it was not set at model instantiation"
- )
- else:
- test_scores = np.array(
- np.divide(
- test_scores,
- self.unity_user_affinity[user_ids, :].dot(self.item_similarity),
- )
- )
- test_scores = np.where(np.isnan(test_scores), -np.inf, test_scores)
-
return test_scores
def get_popularity_based_topk(self, top_k=10, sort_top_k=True):
@@ -438,9 +447,7 @@ def get_item_based_topk(self, items, top_k=10, sort_top_k=True):
# drop invalid items
return df.replace(-np.inf, np.nan).dropna()
- def recommend_k_items(
- self, test, top_k=10, sort_top_k=True, remove_seen=False, normalize=False
- ):
+ def recommend_k_items(self, test, top_k=10, sort_top_k=True, remove_seen=False):
"""Recommend top K items for all users which are in the test set
Args:
@@ -453,7 +460,7 @@ def recommend_k_items(
pd.DataFrame: top k recommendation items for each user
"""
- test_scores = self.score(test, remove_seen=remove_seen, normalize=normalize)
+ test_scores = self.score(test, remove_seen=remove_seen)
top_items, top_scores = get_top_k_scored_items(
scores=test_scores, top_k=top_k, sort_top_k=sort_top_k
diff --git a/tests/unit/test_python_utils.py b/tests/unit/test_python_utils.py
index c1686c6102..74552d34a9 100644
--- a/tests/unit/test_python_utils.py
+++ b/tests/unit/test_python_utils.py
@@ -10,7 +10,8 @@
jaccard,
lift,
get_top_k_scored_items,
- binarize
+ binarize,
+ rescale,
)
TOL = 0.0001
@@ -45,9 +46,13 @@ def target_matrices(scope="module"):
@pytest.fixture(scope="module")
-def python_data():
- cooccurrence1 = np.array([[1.0, 0.0, 1.0], [0.0, 2.0, 1.0], [1.0, 1.0, 2.0]])
- cooccurrence2 = np.array(
+def cooccurrence1():
+ return np.array([[1.0, 0.0, 1.0], [0.0, 2.0, 1.0], [1.0, 1.0, 2.0]])
+
+
+@pytest.fixture(scope="module")
+def cooccurrence2():
+ return np.array(
[
[2.0, 0.0, 0.0, 1.0],
[0.0, 3.0, 0.0, 0.0],
@@ -55,11 +60,14 @@ def python_data():
[1.0, 0.0, 2.0, 4.0],
]
)
- return cooccurrence1, cooccurrence2
-def test_python_jaccard(python_data, target_matrices):
- cooccurrence1, cooccurrence2 = python_data
+@pytest.fixture(scope="module")
+def scores():
+ return np.array([[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 5, 3, 4, 2]])
+
+
+def test_python_jaccard(cooccurrence1, cooccurrence2, target_matrices):
J1 = jaccard(cooccurrence1)
assert type(J1) == np.ndarray
assert J1 == target_matrices["jaccard1"]
@@ -69,8 +77,7 @@ def test_python_jaccard(python_data, target_matrices):
assert J2 == target_matrices["jaccard2"]
-def test_python_lift(python_data, target_matrices):
- cooccurrence1, cooccurrence2 = python_data
+def test_python_lift(cooccurrence1, cooccurrence2, target_matrices):
L1 = lift(cooccurrence1)
assert type(L1) == np.ndarray
assert L1 == target_matrices["lift1"]
@@ -87,8 +94,7 @@ def test_exponential_decay():
assert np.allclose(actual, expected, atol=TOL)
-def test_get_top_k_scored_items():
- scores = np.array([[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 5, 3, 4, 2]])
+def test_get_top_k_scored_items(scores):
top_items, top_scores = get_top_k_scored_items(
scores=scores, top_k=3, sort_top_k=True
)
@@ -110,3 +116,24 @@ def test_binarize():
[1, 1, 1]]
)
assert np.array_equal(binarize(data, threshold), expected)
+
+
+def test_rescale(scores):
+ expected = np.array(
+ [[0, 0.25, 0.5, 0.75, 1], [1, 0.75, 0.5, 0.25, 0], [0, 1, 0.5, 0.75, 0.25]]
+ )
+ assert np.allclose(expected, rescale(scores, 0, 1))
+
+ expected = np.array([[3, 5, 7, 9, 11.0], [11, 9, 7, 5, 3], [3, 11, 7, 9, 5]])
+ assert np.allclose(expected, rescale(scores, 1, 11, 0, 5))
+
+ expected = np.array(
+ [
+ [0, 0.2, 0.4, 0.6, 0.8],
+ [0.625, 0.5, 0.375, 0.25, 0.125],
+ [0, 1, 0.5, 0.75, 0.25],
+ ]
+ )
+ data_min = np.tile(np.array([1, 0, 1])[:, np.newaxis], scores.shape[1])
+ data_max = np.tile(np.array([6, 8, 5])[:, np.newaxis], scores.shape[1])
+ assert np.allclose(expected, rescale(scores, 0, 1, data_min, data_max))
diff --git a/tests/unit/test_sar_singlenode.py b/tests/unit/test_sar_singlenode.py
index f08c1ea096..44d22fa303 100644
--- a/tests/unit/test_sar_singlenode.py
+++ b/tests/unit/test_sar_singlenode.py
@@ -292,25 +292,41 @@ def test_get_normalized_scores(header):
model = SARSingleNode(**header, timedecay_formula=True, normalize=True)
model.fit(train)
- actual = model.score(test, remove_seen=True, normalize=True)
+ actual = model.score(test, remove_seen=True)
expected = np.array(
[
- [-np.inf, -np.inf, -np.inf, -np.inf, 3.0, 3.0, 3.0],
- [-np.inf, 3.0, 3.0, 3.0, -np.inf, -np.inf, -np.inf],
+ [-np.inf, -np.inf, -np.inf, -np.inf, 1.23512374, 1.23512374, 1.23512374],
+ [-np.inf, 1.23512374, 1.23512374, 1.23512374, -np.inf, -np.inf, -np.inf],
]
)
assert actual.shape == (2, 7)
assert isinstance(actual, np.ndarray)
- assert np.isclose(expected, actual).all()
+ assert np.isclose(expected, np.asarray(actual)).all()
- actual = model.score(test, normalize=True)
+ actual = model.score(test)
expected = np.array(
[
- [3.80000633, 4.14285448, 4.14285448, 4.14285448, 3.0, 3.0, 3.0],
- [2.8000859, 3.0, 3.0, 3.0, 2.71441353, 2.71441353, 2.71441353],
+ [
+ 3.11754872,
+ 4.29408577,
+ 4.29408577,
+ 4.29408577,
+ 1.23512374,
+ 1.23512374,
+ 1.23512374,
+ ],
+ [
+ 2.5293308,
+ 1.23511758,
+ 1.23511758,
+ 1.23511758,
+ 3.11767458,
+ 3.11767458,
+ 3.11767458,
+ ],
]
)
assert actual.shape == (2, 7)
assert isinstance(actual, np.ndarray)
- assert np.isclose(expected, actual).all()
+ assert np.isclose(expected, np.asarray(actual)).all()