diff --git a/README.md b/README.md index f5419bf26b..782b50d283 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,7 @@ We provide a [benchmark notebook](examples/06_benchmarks/movielens.ipynb) to ill | --- | --- | --- | --- | --- | --- | --- | --- | --- | | [ALS](examples/00_quick_start/als_movielens.ipynb) | 0.004732 | 0.044239 | 0.048462 | 0.017796 | 0.965038 | 0.753001 | 0.255647 | 0.251648 | | [SVD](examples/02_model_collaborative_filtering/surprise_svd_deep_dive.ipynb) | 0.012873 | 0.095930 | 0.091198 | 0.032783 | 0.938681 | 0.742690 | 0.291967 | 0.291971 | -| [SAR](examples/00_quick_start/sar_movielens.ipynb) | 0.113028 | 0.388321 | 0.333828 | 0.183179 | N/A | N/A | N/A | N/A | +| [SAR](examples/00_quick_start/sar_movielens.ipynb) | 0.110591 | 0.382461 | 0.330753 | 0.176385 | 1.253805 | 1.048484 | -0.569363 | 0.030474 | | [NCF](examples/02_model_hybrid/ncf_deep_dive.ipynb) | 0.107720 | 0.396118 | 0.347296 | 0.180775 | N/A | N/A | N/A | N/A | | [BPR](examples/02_model_collaborative_filtering/cornac_bpr_deep_dive.ipynb) | 0.105365 | 0.389948 | 0.349841 | 0.181807 | N/A | N/A | N/A | N/A | | [FastAI](examples/00_quick_start/fastai_movielens.ipynb) | 0.025503 | 0.147866 | 0.130329 | 0.053824 | 0.943084 | 0.744337 | 0.285308 | 0.287671 | diff --git a/examples/00_quick_start/sar_movielens.ipynb b/examples/00_quick_start/sar_movielens.ipynb index 14b5305fed..580cd99950 100644 --- a/examples/00_quick_start/sar_movielens.ipynb +++ b/examples/00_quick_start/sar_movielens.ipynb @@ -41,20 +41,23 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "System version: 3.7.3 | packaged by conda-forge | (default, Jul 1 2019, 21:52:21) \n", - "[GCC 7.3.0]\n", - "Pandas version: 0.23.4\n" + "System version: 3.6.10 |Anaconda, Inc.| (default, May 7 2020, 23:06:31) \n", + "[GCC 4.2.1 Compatible Clang 4.0.1 (tags/RELEASE_401/final)]\n", + "Pandas version: 0.25.3\n" ] } 
], "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", "# set the environment path to find Recommenders\n", "import sys\n", "sys.path.append(\"../../\")\n", @@ -63,11 +66,23 @@ "import numpy as np\n", "import pandas as pd\n", "import papermill as pm\n", + "from sklearn.preprocessing import minmax_scale\n", "\n", + "from reco_utils.common.python_utils import binarize\n", "from reco_utils.common.timer import Timer\n", "from reco_utils.dataset import movielens\n", "from reco_utils.dataset.python_splitters import python_stratified_split\n", - "from reco_utils.evaluation.python_evaluation import map_at_k, ndcg_at_k, precision_at_k, recall_at_k\n", + "from reco_utils.evaluation.python_evaluation import (\n", + " map_at_k,\n", + " ndcg_at_k,\n", + " precision_at_k,\n", + " recall_at_k,\n", + " rmse,\n", + " mae,\n", + " logloss,\n", + " rsquared,\n", + " exp_var\n", + ")\n", "from reco_utils.recommender.sar import SAR\n", "\n", "print(\"System version: {}\".format(sys.version))\n", @@ -90,7 +105,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": { "tags": [ "parameters" @@ -114,93 +129,22 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 4.81k/4.81k [00:02<00:00, 1.90kKB/s]\n" + "100%|██████████| 4.81k/4.81k [00:02<00:00, 1.67kKB/s]\n" ] }, { "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
userIDitemIDratingtimestamp
01962423.0881250949
11863023.0891717742
2223771.0878887116
3244512.0880606923
41663461.0886397596
\n", - "
" - ], - "text/plain": [ - " userID itemID rating timestamp\n", - "0 196 242 3.0 881250949\n", - "1 186 302 3.0 891717742\n", - "2 22 377 1.0 878887116\n", - "3 244 51 2.0 880606923\n", - "4 166 346 1.0 886397596" - ] + "text/plain": " userID itemID rating timestamp\n0 196 242 3.0 881250949\n1 186 302 3.0 891717742\n2 22 377 1.0 878887116\n3 244 51 2.0 880606923\n4 166 346 1.0 886397596", + "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
userIDitemIDratingtimestamp
01962423.0881250949
11863023.0891717742
2223771.0878887116
3244512.0880606923
41663461.0886397596
\n
" }, - "execution_count": 4, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -227,7 +171,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -236,7 +180,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -298,7 +242,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -312,7 +256,8 @@ " col_timestamp=\"timestamp\",\n", " similarity_type=\"jaccard\", \n", " time_decay_coefficient=30, \n", - " timedecay_formula=True\n", + " timedecay_formula=True,\n", + " normalize=True\n", ")" ] }, @@ -333,14 +278,29 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-08-18 11:31:24,687 INFO Collecting user affinity matrix\n", + "2020-08-18 11:31:24,690 INFO Calculating time-decayed affinities\n", + "2020-08-18 11:31:24,722 INFO Creating index columns\n", + "2020-08-18 11:31:24,822 INFO Calculating normalization factors\n", + "2020-08-18 11:31:24,862 INFO Building user affinity sparse matrix\n", + "2020-08-18 11:31:24,868 INFO Calculating item co-occurrence\n", + "2020-08-18 11:31:25,063 INFO Calculating item similarity\n", + "2020-08-18 11:31:25,063 INFO Using jaccard based similarity\n", + "2020-08-18 11:31:25,158 INFO Done training\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Took 0.3302565817721188 seconds for training.\n" + "Took 0.4742812399927061 seconds for training.\n" ] } ], @@ -353,14 +313,22 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-08-18 11:31:25,202 INFO Calculating recommendation scores\n", + "2020-08-18 11:31:25,389 INFO Removing seen items\n" + ] + }, { 
"name": "stdout", "output_type": "stream", "text": [ - "Took 0.21034361701458693 seconds for prediction.\n" + "Took 0.22830537598929368 seconds for prediction.\n" ] } ], @@ -373,82 +341,17 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": { "scrolled": true }, "outputs": [ { "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
userIDitemIDprediction
012043.313306
11893.280465
21113.233867
313673.192575
414233.131517
\n", - "
" - ], - "text/plain": [ - " userID itemID prediction\n", - "0 1 204 3.313306\n", - "1 1 89 3.280465\n", - "2 1 11 3.233867\n", - "3 1 367 3.192575\n", - "4 1 423 3.131517" - ] + "text/plain": " userID itemID prediction\n0 1 204 3.231405\n1 1 89 3.199445\n2 1 11 3.154097\n3 1 367 3.113913\n4 1 423 3.054493", + "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
userIDitemIDprediction
012043.231405
11893.199445
21113.154097
313673.113913
414233.054493
\n
" }, - "execution_count": 10, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -468,7 +371,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -477,7 +380,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -486,7 +389,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -495,17 +398,95 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "eval_recall = recall_at_k(test, top_k, col_user='userID', col_item='itemID', col_rating='rating', k=TOP_K)" ] }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "eval_rmse = rmse(test, top_k, col_user='userID', col_item='itemID', col_rating='rating')" + ] + }, { "cell_type": "code", "execution_count": 15, - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "eval_mae = mae(test, top_k, col_user='userID', col_item='itemID', col_rating='rating')" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "eval_rsquared = rsquared(test, top_k, col_user='userID', col_item='itemID', col_rating='rating')" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "eval_exp_var = exp_var(test, top_k, col_user='userID', col_item='itemID', col_rating='rating')" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "positivity_threshold = 2\n", + "test_bin = test.copy()\n", + "test_bin['rating'] = 
binarize(test_bin['rating'], positivity_threshold)\n", + "\n", + "top_k_prob = top_k.copy()\n", + "top_k_prob['prediction'] = minmax_scale(\n", + " top_k_prob['prediction'].astype(float)\n", + ")\n", + "\n", + "eval_logloss = logloss(test_bin, top_k_prob, col_user='userID', col_item='itemID', col_rating='rating')" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [ { "name": "stdout", @@ -516,7 +497,12 @@ "MAP:\t0.110591\n", "NDCG:\t0.382461\n", "Precision@K:\t0.330753\n", - "Recall@K:\t0.176385\n" + "Recall@K:\t0.176385\n", + "RMSE:\t1.253805\n", + "MAE:\t1.048484\n", + "R2:\t-0.569363\n", + "Exp var:\t0.030474\n", + "Logloss:\t0.542861\n" ] } ], @@ -526,97 +512,38 @@ " \"MAP:\\t%f\" % eval_map,\n", " \"NDCG:\\t%f\" % eval_ndcg,\n", " \"Precision@K:\\t%f\" % eval_precision,\n", - " \"Recall@K:\\t%f\" % eval_recall, sep='\\n')" + " \"Recall@K:\\t%f\" % eval_recall,\n", + " \"RMSE:\\t%f\" % eval_rmse,\n", + " \"MAE:\\t%f\" % eval_mae,\n", + " \"R2:\\t%f\" % eval_rsquared,\n", + " \"Exp var:\\t%f\" % eval_exp_var,\n", + " \"Logloss:\\t%f\" % eval_logloss,\n", + " sep='\\n')" ] }, { "cell_type": "code", - "execution_count": 16, - "metadata": {}, + "execution_count": 23, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-08-18 11:33:07,631 INFO Calculating recommendation scores\n", + "2020-08-18 11:33:07,643 INFO Removing seen items\n" + ] + }, { "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
userIDitemIDratingtimestampprediction
08765235.0879428378NaN
18765294.0879428451NaN
28761744.08794283780.353567
38762764.0879428354NaN
48762883.0879428101NaN
\n", - "
" - ], - "text/plain": [ - " userID itemID rating timestamp prediction\n", - "0 876 523 5.0 879428378 NaN\n", - "1 876 529 4.0 879428451 NaN\n", - "2 876 174 4.0 879428378 0.353567\n", - "3 876 276 4.0 879428354 NaN\n", - "4 876 288 3.0 879428101 NaN" - ] + "text/plain": " userID itemID rating timestamp prediction\n0 876 523 5.0 879428378 NaN\n1 876 529 4.0 879428451 NaN\n2 876 174 4.0 879428378 3.702239\n3 876 276 4.0 879428354 NaN\n4 876 288 3.0 879428101 NaN", + "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
userIDitemIDratingtimestampprediction
08765235.0879428378NaN
18765294.0879428451NaN
28761744.08794283783.702239
38762764.0879428354NaN
48762883.0879428101NaN
\n
" }, - "execution_count": 16, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -639,64 +566,13 @@ }, { "cell_type": "code", - "execution_count": 86, - "metadata": {}, - "outputs": [ - { - "data": { - "application/papermill.record+json": { - "map": 0.11059057578638949 - } - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/papermill.record+json": { - "ndcg": 0.3824612290501957 - } - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/papermill.record+json": { - "precision": 0.33075291622481445 - } - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/papermill.record+json": { - "recall": 0.1763854474342893 - } - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/papermill.record+json": { - "train_time": 0.284792423248291 - } - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/papermill.record+json": { - "test_time": 0.1463017463684082 - } - }, - "metadata": {}, - "output_type": "display_data" + "execution_count": null, + "metadata": { + "pycharm": { + "name": "#%%\n" } - ], + }, + "outputs": [], "source": [ "# Record results with papermill for tests - ignore this cell\n", "pm.record(\"map\", eval_map)\n", @@ -711,9 +587,9 @@ "metadata": { "celltoolbar": "Tags", "kernelspec": { - "display_name": "Python 3", + "name": "python3", "language": "python", - "name": "python3" + "display_name": "Python 3" }, "language_info": { "codemirror_mode": { @@ -725,9 +601,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.3" + "version": "3.6.10" } }, "nbformat": 4, - "nbformat_minor": 2 -} + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/reco_utils/__init__.py b/reco_utils/__init__.py index df052a7fad..3103749f97 100644 --- a/reco_utils/__init__.py +++ b/reco_utils/__init__.py @@ -2,7 +2,7 @@ # 
Licensed under the MIT License. __title__ = "Microsoft Recommenders" -__version__ = "2019.9" +__version__ = "2020.8" __author__ = "RecoDev Team at Microsoft" __license__ = "MIT" __copyright__ = "Copyright 2018-present Microsoft Corporation" diff --git a/reco_utils/common/python_utils.py b/reco_utils/common/python_utils.py index a365618ea3..3f10d0d409 100644 --- a/reco_utils/common/python_utils.py +++ b/reco_utils/common/python_utils.py @@ -116,3 +116,26 @@ def binarize(a, threshold): 0.0 ) + +def rescale(data, new_min=0, new_max=1, data_min=None, data_max=None): + """ + Rescale/normalize the data to be within the range [new_min, new_max]. + If data_min and data_max are explicitly provided, they will be used + as the old min/max values instead of taken from the data. + + Note: this applies the same min-max formula as sklearn.preprocessing.MinMaxScaler, with the exception that we can override + the min/max of the old scale. + + Args: + data (np.array): 1d scores vector or 2d score matrix (users x items). + new_min (int|float): The minimum of the newly scaled data. + new_max (int|float): The maximum of the newly scaled data. + data_min (None|number): The minimum of the passed data [if omitted it will be inferred]. + data_max (None|number): The maximum of the passed data [if omitted it will be inferred]. + + Returns: + np.array: The newly scaled/normalized data. 
+ """ + data_min = data.min() if data_min is None else data_min + data_max = data.max() if data_max is None else data_max + return (data - data_min) / (data_max - data_min) * (new_max - new_min) + new_min diff --git a/reco_utils/recommender/sar/sar_singlenode.py b/reco_utils/recommender/sar/sar_singlenode.py index 46685a3d3f..298c990eda 100644 --- a/reco_utils/recommender/sar/sar_singlenode.py +++ b/reco_utils/recommender/sar/sar_singlenode.py @@ -12,6 +12,7 @@ lift, exponential_decay, get_top_k_scored_items, + rescale, ) from reco_utils.common import constants @@ -89,7 +90,6 @@ def __init__( # set flag to capture unity-rating user-affinity matrix for scaling scores self.normalize = normalize self.col_unity_rating = "_unity_rating" - self.unity_user_affinity = None # column for mapping user / item ids to internal indices self.col_item_id = "_indexed_items" @@ -99,6 +99,10 @@ def __init__( self.n_users = None self.n_items = None + # The min and max of the rating scale, obtained from the training data. + self.rating_min = None + self.rating_max = None + # mapping for item to matrix element self.user2index = None self.item2index = None @@ -239,6 +243,8 @@ def fit(self, df): ) if self.normalize: + self.rating_min = temp_df[self.col_rating].min() + self.rating_max = temp_df[self.col_rating].max() logger.info("Calculating normalization factors") temp_df[self.col_unity_rating] = 1.0 if self.time_decay_flag: @@ -286,13 +292,12 @@ def fit(self, df): logger.info("Done training") - def score(self, test, remove_seen=False, normalize=False): + def score(self, test, remove_seen=False): """Score all items for test users. Args: test (pd.DataFrame): user to test remove_seen (bool): flag to remove items seen in training from recommendation - normalize (bool): flag to normalize scores to be in the same scale as the original ratings Returns: np.ndarray: Value of interest of all items for the users. 
@@ -316,25 +321,29 @@ def score(self, test, remove_seen=False, normalize=False): if isinstance(test_scores, sparse.spmatrix): test_scores = test_scores.toarray() + if self.normalize: + counts = self.unity_user_affinity[user_ids, :].dot(self.item_similarity) + user_min_scores = ( + np.tile(counts.min(axis=1)[:, np.newaxis], test_scores.shape[1]) + * self.rating_min + ) + user_max_scores = ( + np.tile(counts.max(axis=1)[:, np.newaxis], test_scores.shape[1]) + * self.rating_max + ) + test_scores = rescale( + test_scores, + self.rating_min, + self.rating_max, + user_min_scores, + user_max_scores, + ) + # remove items in the train set so recommended items are always novel if remove_seen: logger.info("Removing seen items") test_scores += self.user_affinity[user_ids, :] * -np.inf - if normalize: - if self.unity_user_affinity is None: - raise ValueError( - "Cannot use normalize flag during scoring if it was not set at model instantiation" - ) - else: - test_scores = np.array( - np.divide( - test_scores, - self.unity_user_affinity[user_ids, :].dot(self.item_similarity), - ) - ) - test_scores = np.where(np.isnan(test_scores), -np.inf, test_scores) - return test_scores def get_popularity_based_topk(self, top_k=10, sort_top_k=True): @@ -438,9 +447,7 @@ def get_item_based_topk(self, items, top_k=10, sort_top_k=True): # drop invalid items return df.replace(-np.inf, np.nan).dropna() - def recommend_k_items( - self, test, top_k=10, sort_top_k=True, remove_seen=False, normalize=False - ): + def recommend_k_items(self, test, top_k=10, sort_top_k=True, remove_seen=False): """Recommend top K items for all users which are in the test set Args: @@ -453,7 +460,7 @@ def recommend_k_items( pd.DataFrame: top k recommendation items for each user """ - test_scores = self.score(test, remove_seen=remove_seen, normalize=normalize) + test_scores = self.score(test, remove_seen=remove_seen) top_items, top_scores = get_top_k_scored_items( scores=test_scores, top_k=top_k, sort_top_k=sort_top_k diff 
--git a/tests/unit/test_python_utils.py b/tests/unit/test_python_utils.py index c1686c6102..74552d34a9 100644 --- a/tests/unit/test_python_utils.py +++ b/tests/unit/test_python_utils.py @@ -10,7 +10,8 @@ jaccard, lift, get_top_k_scored_items, - binarize + binarize, + rescale, ) TOL = 0.0001 @@ -45,9 +46,13 @@ def target_matrices(scope="module"): @pytest.fixture(scope="module") -def python_data(): - cooccurrence1 = np.array([[1.0, 0.0, 1.0], [0.0, 2.0, 1.0], [1.0, 1.0, 2.0]]) - cooccurrence2 = np.array( +def cooccurrence1(): + return np.array([[1.0, 0.0, 1.0], [0.0, 2.0, 1.0], [1.0, 1.0, 2.0]]) + + +@pytest.fixture(scope="module") +def cooccurrence2(): + return np.array( [ [2.0, 0.0, 0.0, 1.0], [0.0, 3.0, 0.0, 0.0], @@ -55,11 +60,14 @@ def python_data(): [1.0, 0.0, 2.0, 4.0], ] ) - return cooccurrence1, cooccurrence2 -def test_python_jaccard(python_data, target_matrices): - cooccurrence1, cooccurrence2 = python_data +@pytest.fixture(scope="module") +def scores(): + return np.array([[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 5, 3, 4, 2]]) + + +def test_python_jaccard(cooccurrence1, cooccurrence2, target_matrices): J1 = jaccard(cooccurrence1) assert type(J1) == np.ndarray assert J1 == target_matrices["jaccard1"] @@ -69,8 +77,7 @@ def test_python_jaccard(python_data, target_matrices): assert J2 == target_matrices["jaccard2"] -def test_python_lift(python_data, target_matrices): - cooccurrence1, cooccurrence2 = python_data +def test_python_lift(cooccurrence1, cooccurrence2, target_matrices): L1 = lift(cooccurrence1) assert type(L1) == np.ndarray assert L1 == target_matrices["lift1"] @@ -87,8 +94,7 @@ def test_exponential_decay(): assert np.allclose(actual, expected, atol=TOL) -def test_get_top_k_scored_items(): - scores = np.array([[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 5, 3, 4, 2]]) +def test_get_top_k_scored_items(scores): top_items, top_scores = get_top_k_scored_items( scores=scores, top_k=3, sort_top_k=True ) @@ -110,3 +116,24 @@ def test_binarize(): [1, 1, 1]] ) assert 
np.array_equal(binarize(data, threshold), expected) + + +def test_rescale(scores): + expected = np.array( + [[0, 0.25, 0.5, 0.75, 1], [1, 0.75, 0.5, 0.25, 0], [0, 1, 0.5, 0.75, 0.25]] + ) + assert np.allclose(expected, rescale(scores, 0, 1)) + + expected = np.array([[3, 5, 7, 9, 11.0], [11, 9, 7, 5, 3], [3, 11, 7, 9, 5]]) + assert np.allclose(expected, rescale(scores, 1, 11, 0, 5)) + + expected = np.array( + [ + [0, 0.2, 0.4, 0.6, 0.8], + [0.625, 0.5, 0.375, 0.25, 0.125], + [0, 1, 0.5, 0.75, 0.25], + ] + ) + data_min = np.tile(np.array([1, 0, 1])[:, np.newaxis], scores.shape[1]) + data_max = np.tile(np.array([6, 8, 5])[:, np.newaxis], scores.shape[1]) + assert np.allclose(expected, rescale(scores, 0, 1, data_min, data_max)) diff --git a/tests/unit/test_sar_singlenode.py b/tests/unit/test_sar_singlenode.py index f08c1ea096..44d22fa303 100644 --- a/tests/unit/test_sar_singlenode.py +++ b/tests/unit/test_sar_singlenode.py @@ -292,25 +292,41 @@ def test_get_normalized_scores(header): model = SARSingleNode(**header, timedecay_formula=True, normalize=True) model.fit(train) - actual = model.score(test, remove_seen=True, normalize=True) + actual = model.score(test, remove_seen=True) expected = np.array( [ - [-np.inf, -np.inf, -np.inf, -np.inf, 3.0, 3.0, 3.0], - [-np.inf, 3.0, 3.0, 3.0, -np.inf, -np.inf, -np.inf], + [-np.inf, -np.inf, -np.inf, -np.inf, 1.23512374, 1.23512374, 1.23512374], + [-np.inf, 1.23512374, 1.23512374, 1.23512374, -np.inf, -np.inf, -np.inf], ] ) assert actual.shape == (2, 7) assert isinstance(actual, np.ndarray) - assert np.isclose(expected, actual).all() + assert np.isclose(expected, np.asarray(actual)).all() - actual = model.score(test, normalize=True) + actual = model.score(test) expected = np.array( [ - [3.80000633, 4.14285448, 4.14285448, 4.14285448, 3.0, 3.0, 3.0], - [2.8000859, 3.0, 3.0, 3.0, 2.71441353, 2.71441353, 2.71441353], + [ + 3.11754872, + 4.29408577, + 4.29408577, + 4.29408577, + 1.23512374, + 1.23512374, + 1.23512374, + ], + [ + 
2.5293308, + 1.23511758, + 1.23511758, + 1.23511758, + 3.11767458, + 3.11767458, + 3.11767458, + ], ] ) assert actual.shape == (2, 7) assert isinstance(actual, np.ndarray) - assert np.isclose(expected, actual).all() + assert np.isclose(expected, np.asarray(actual)).all()