From 31c906cd314e713931ea8b778ffb60f16c1ed3ef Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Fri, 20 Aug 2021 10:48:46 +0800
Subject: [PATCH 1/4] REFACTOR: rename `Kg_Seq_Dataset` to `KGSeqDataset`.

---
 recbole/data/dataset/__init__.py           | 2 +-
 recbole/data/dataset/customized_dataset.py | 6 +++---
 recbole/data/dataset/kg_dataset.py         | 3 ---
 recbole/data/dataset/kg_seq_dataset.py     | 2 +-
 4 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/recbole/data/dataset/__init__.py b/recbole/data/dataset/__init__.py
index ca1ce3a92..e1f9d8c82 100644
--- a/recbole/data/dataset/__init__.py
+++ b/recbole/data/dataset/__init__.py
@@ -1,6 +1,6 @@
 from recbole.data.dataset.dataset import Dataset
 from recbole.data.dataset.sequential_dataset import SequentialDataset
 from recbole.data.dataset.kg_dataset import KnowledgeBasedDataset
-from recbole.data.dataset.kg_seq_dataset import Kg_Seq_Dataset
+from recbole.data.dataset.kg_seq_dataset import KGSeqDataset
 from recbole.data.dataset.decisiontree_dataset import DecisionTreeDataset
 from recbole.data.dataset.customized_dataset import *
diff --git a/recbole/data/dataset/customized_dataset.py b/recbole/data/dataset/customized_dataset.py
index 3dd7f0546..0b93e4306 100644
--- a/recbole/data/dataset/customized_dataset.py
+++ b/recbole/data/dataset/customized_dataset.py
@@ -19,19 +19,19 @@
 import numpy as np
 import torch
 
-from recbole.data.dataset import Kg_Seq_Dataset, SequentialDataset
+from recbole.data.dataset import KGSeqDataset, SequentialDataset
 from recbole.data.interaction import Interaction
 from recbole.sampler import SeqSampler
 from recbole.utils.enum_type import FeatureType
 
 
-class GRU4RecKGDataset(Kg_Seq_Dataset):
+class GRU4RecKGDataset(KGSeqDataset):
 
     def __init__(self, config):
         super().__init__(config)
 
 
-class KSRDataset(Kg_Seq_Dataset):
+class KSRDataset(KGSeqDataset):
 
     def __init__(self, config):
         super().__init__(config)
diff --git a/recbole/data/dataset/kg_dataset.py b/recbole/data/dataset/kg_dataset.py
index 38542c354..c43dfcd56 100644
--- a/recbole/data/dataset/kg_dataset.py
+++ b/recbole/data/dataset/kg_dataset.py
@@ -148,9 +148,6 @@ def _build_feat_name_list(self):
             feat_name_list.append('kg_feat')
         return feat_name_list
 
-    def save(self, filepath):
-        raise NotImplementedError()
-
     def _load_kg(self, token, dataset_path):
         self.logger.debug(set_color(f'Loading kg from [{dataset_path}].', 'green'))
         kg_path = os.path.join(dataset_path, f'{token}.kg')
diff --git a/recbole/data/dataset/kg_seq_dataset.py b/recbole/data/dataset/kg_seq_dataset.py
index ac4ecf108..288c95a03 100644
--- a/recbole/data/dataset/kg_seq_dataset.py
+++ b/recbole/data/dataset/kg_seq_dataset.py
@@ -10,7 +10,7 @@
 from recbole.data.dataset import SequentialDataset, KnowledgeBasedDataset
 
 
-class Kg_Seq_Dataset(SequentialDataset, KnowledgeBasedDataset):
+class KGSeqDataset(SequentialDataset, KnowledgeBasedDataset):
     """Containing both processing of Sequential Models and Knowledge-based Models.
 
     Inherit from :class:`~recbole.data.dataset.sequential_dataset.SequentialDataset` and

From fa6f94dd2779299b4d096e5240f5ed5b9a318554 Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Fri, 20 Aug 2021 22:38:30 +0800
Subject: [PATCH 2/4] FIX: enable command line args to accept `None`.
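
Previously, a command line value of `None` was eval()'d to the NoneType
singleton, which failed the isinstance() whitelist below and was then
silently converted back to the string 'None'. A minimal sketch of the
intended behavior (the parameter name is arbitrary; external config
dicts take the same path through `_convert_config_dict`):

    from recbole.config import Config

    config = Config(model='BPR', dataset='ml-100k',
                    config_dict={'some_param': 'None'})
    # before this fix: config['some_param'] == 'None'  (str)
    # after this fix:  config['some_param'] is None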
---
 recbole/config/configurator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/recbole/config/configurator.py b/recbole/config/configurator.py
index e05df6166..c9e30755c 100644
--- a/recbole/config/configurator.py
+++ b/recbole/config/configurator.py
@@ -113,7 +113,7 @@ def _convert_config_dict(self, config_dict):
                 continue
             try:
                 value = eval(param)
-                if not isinstance(value, (str, int, float, list, tuple, dict, bool, Enum)):
+                if value is not None and not isinstance(value, (str, int, float, list, tuple, dict, bool, Enum)):
                     value = param
             except (NameError, SyntaxError, TypeError):
                 if isinstance(param, str):

From 4cdd724961c95181a8aa35d0f8cd92d5fcf82b78 Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Sat, 21 Aug 2021 11:25:50 +0800
Subject: [PATCH 3/4] DOC: fix docs

---
 recbole/evaluator/metrics.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/recbole/evaluator/metrics.py b/recbole/evaluator/metrics.py
index 9dfca2678..bdc4d4cf6 100644
--- a/recbole/evaluator/metrics.py
+++ b/recbole/evaluator/metrics.py
@@ -8,7 +8,7 @@
 # @Author : Kaiyuan Li, Zhichao Feng, Xingyu Pan, Zihan Lin
 # @email  : tsotfsk@outlook.com, fzcbupt@gmail.com, panxy@ruc.edu.cn, zhlin@ruc.edu.cn
 
-"""
+r"""
 recbole.evaluator.metrics
 ############################
 
@@ -227,7 +227,7 @@ class GAUC(AbstractMetric):
     the area under the ROC curve grouped by user. We weighted the index of each user :math:`u` by the number of positive
     samples of users to get the final result.
 
-    For further details, please refer to the `paper `_
+    For further details, please refer to the `paper `__
 
     Note:
         It calculates the AUC score of each user, and finally obtains GAUC by weighting the user AUC.
@@ -421,8 +421,8 @@ class ItemCoverage(AbstractMetric):
     .. _ItemCoverage:
        https://en.wikipedia.org/wiki/Coverage_(information_systems)
 
-    For further details, please refer to the `paper `_
-    and `paper `_.
+    For further details, please refer to the `paper `__
+    and `paper `__.
 
     .. math::
        \mathrm{Coverage@K}=\frac{\left| \bigcup_{u \in U} \hat{R}(u) \right|}{|I|}
@@ -462,8 +462,8 @@ class AveragePopularity(AbstractMetric):
     r"""AveragePopularity computes the average popularity of recommended items.
 
-    For further details, please refer to the `paper `_
-    and `paper `_.
+    For further details, please refer to the `paper `__
+    and `paper `__.
 
     .. math::
        \mathrm{AveragePopularity@K}=\frac{1}{|U|} \sum_{u \in U } \frac{\sum_{i \in R_{u}} \phi(i)}{|R_{u}|}
@@ -530,8 +530,8 @@ class ShannonEntropy(AbstractMetric):
     .. _ShannonEntropy:
       https://en.wikipedia.org/wiki/Entropy_(information_theory)
 
-    For further details, please refer to the `paper `_
-    and `paper `_
+    For further details, please refer to the `paper `__
+    and `paper `__
 
     .. math::
        \mathrm {ShannonEntropy@K}=-\sum_{i=1}^{|I|} p(i) \log p(i)
@@ -582,7 +582,7 @@ class GiniIndex(AbstractMetric):
     .. _GiniIndex:
       https://en.wikipedia.org/wiki/Gini_coefficient
 
-    For further details, please refer to the `paper `_.
+    For further details, please refer to the `paper `__.
 
     .. math::
        \mathrm {GiniIndex@K}=\left(\frac{\sum_{i=1}^{|I|}(2 i-|I|-1) P{(i)}}{|I| \sum_{i=1}^{|I|} P{(i)}}\right)
@@ -633,7 +633,7 @@ class TailPercentage(AbstractMetric):
     .. _TailPercentage:
       https://en.wikipedia.org/wiki/Long_tail#Criticisms
 
-    For further details, please refer to the `paper `_.
+    For further details, please refer to the `paper `__.
 
     .. math::
        \mathrm {TailPercentage@K}=\frac{1}{|U|} \sum_{u \in U} \frac{\sum_{i \in R_{u}} {\delta(i \in T)}}{|R_{u}|}

From e3832fa4efcb70ac46051f51bbbc3145454990e6 Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Sun, 22 Aug 2021 17:45:17 +0800
Subject: [PATCH 4/4] FEA: add `load_data_and_model`, refactor case_study.py,
 add case_study.rst and add save_and_load_data_and_model.rst.

---
 .../developer_guide/customize_metrics.rst     |  6 +-
 docs/source/index.rst                         |  1 +
 docs/source/user_guide/config_settings.rst    |  6 ++
 docs/source/user_guide/usage.rst              |  4 +-
 docs/source/user_guide/usage/case_study.rst   | 81 +++++++++++++++++++
 .../usage/save_and_load_data_and_model.rst    | 67 +++++++++++++++
 recbole/data/dataset/dataset.py               | 12 +--
 recbole/properties/overall.yaml               |  2 +
 recbole/quick_start/__init__.py               |  2 +-
 recbole/quick_start/quick_start.py            | 61 +++++++++++++-
 recbole/utils/argument_list.py                |  4 +-
 recbole/utils/case_study.py                   | 38 +++++----
 run_example/case_study_example.py             | 31 ++-----
 run_example/save_and_load_example.py          | 77 +++++++-----------
 14 files changed, 288 insertions(+), 104 deletions(-)
 create mode 100644 docs/source/user_guide/usage/case_study.rst
 create mode 100644 docs/source/user_guide/usage/save_and_load_data_and_model.rst

diff --git a/docs/source/developer_guide/customize_metrics.rst b/docs/source/developer_guide/customize_metrics.rst
index b0c789dd7..1d78b28b8 100644
--- a/docs/source/developer_guide/customize_metrics.rst
+++ b/docs/source/developer_guide/customize_metrics.rst
@@ -7,7 +7,7 @@ Here, it only takes three steps to incorporate a new metric and we introduce th
 
 Sign in Your Metric in Register
------------------------------
+--------------------------------
 
 To begin with, we must add a new line in :obj:`~recbole.evaluator.register.metric_information`:
 All the metrics are registered by :obj:`metric_information` which is a dict. Keys are the name of
 metrics and should be lowercase. Value is a list which contain one or multiple string that corresponding
@@ -47,7 +47,7 @@ and the total item number, we can sign in the metric as follow.
 
 
 Create a New Metric Class
------------------------
+--------------------------
 
 Then, we create a new class in the file :file:`~recbole.evaluator.metrics` and define the parameter in ``__init__()``
 
@@ -59,7 +59,7 @@ Then, we create a new class in the file :file:`~recbole.evaluator.metrics` and d
 
 
 Implement calculate_metric(self, dataobject)
-------------------------------
+---------------------------------------------
 
 All the computational process is defined in this function. The args is a packaged data object that
 contains all the result above. We can treat it as a dict and get data from it by
 ``rec_items = dataobject.get('rec.items')`` . The returned value should be a dict with key of metric name
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 4a41dfe09..91953e455 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -32,6 +32,7 @@ RecBole v0.2.0
    developer_guide/customize_trainers
    developer_guide/customize_dataloaders
    developer_guide/customize_samplers
+   developer_guide/customize_metrics
 
 
 .. toctree::
diff --git a/docs/source/user_guide/config_settings.rst b/docs/source/user_guide/config_settings.rst
index 52d8a135e..6d268334c 100644
--- a/docs/source/user_guide/config_settings.rst
+++ b/docs/source/user_guide/config_settings.rst
@@ -34,6 +34,12 @@ model training and evaluation.
   Defaults to ``'saved/'``.
 - ``show_progress (bool)`` : Show the progress of training epoch and evaluate epoch.
   Defaults to ``True``.
+- ``save_dataset (bool)``: Whether or not to save the filtered dataset.
+  If ``True``, the filtered dataset will be saved; otherwise it will not.
+  Defaults to ``False``.
+- ``save_dataloaders (bool)``: Whether or not to save the split dataloaders.
+  If ``True``, the split dataloaders will be saved; otherwise they will not.
+  Defaults to ``False``.
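+
+For example, both options can be switched on in a config file (a minimal
+sketch; the default values above live in ``recbole/properties/overall.yaml``):
+
+.. code:: yaml
+
+    save_dataset: True
+    save_dataloaders: True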
 
 **Training Setting**
diff --git a/docs/source/user_guide/usage.rst b/docs/source/user_guide/usage.rst
index f98dd081a..4c0a0db1e 100644
--- a/docs/source/user_guide/usage.rst
+++ b/docs/source/user_guide/usage.rst
@@ -11,4 +11,6 @@ Here we introduce how to use RecBole.
    usage/running_new_dataset
    usage/running_different_models
    usage/qa
-   usage/load_pretrained_embedding
\ No newline at end of file
+   usage/load_pretrained_embedding
+   usage/save_and_load_data_and_model
+   usage/case_study
\ No newline at end of file
diff --git a/docs/source/user_guide/usage/case_study.rst b/docs/source/user_guide/usage/case_study.rst
new file mode 100644
index 000000000..e6ab618e0
--- /dev/null
+++ b/docs/source/user_guide/usage/case_study.rst
@@ -0,0 +1,81 @@
+Case study
+=============
+
+A case study is an in-depth analysis of the performance of a specific recommendation algorithm;
+it examines the recommendation results for individual users.
+In RecBole, we implemented :meth:`~recbole.utils.case_study.full_sort_scores`
+and :meth:`~recbole.utils.case_study.full_sort_topk` for case study purposes.
+In this section, we present a typical usage of these two functions.
+
+Reload model
+-------------
+
+First, we need to reload the recommendation model.
+We can use :meth:`~recbole.quick_start.quick_start.load_data_and_model` to load the saved data and model.
+
+.. code:: python3
+
+    config, model, dataset, train_data, valid_data, test_data = load_data_and_model(
+        model_file='../saved/BPR-Aug-20-2021_03-32-13.pth',
+    )  # Here you can replace it with your model path.
+
+Convert external user ids into internal user ids
+-------------------------------------------------
+
+Then, we need to use :meth:`~recbole.data.dataset.dataset.Dataset.token2id`
+to convert the external user ids that we want to study into internal user ids.
+
+.. code:: python3
+
+    uid_series = dataset.token2id(dataset.uid_field, ['196', '186'])
+
+Get scores of every user-item pair
+------------------------------------
+
+If we want to calculate the scores of every user-item pair for the given users,
+we can call the :meth:`~recbole.utils.case_study.full_sort_scores` function to get the score matrix.
+
+.. code:: python3
+
+    score = full_sort_scores(uid_series, model, test_data, device=config['device'])
+    print(score)  # score of all items
+    print(score[0, dataset.token2id(dataset.iid_field, ['242', '302'])])
+    # score of items ['242', '302'] for user '196'.
+
+The output will be like this:
+
+.. code:: none
+
+    tensor([[   -inf,    -inf,  0.1074,  ..., -0.0966, -0.1217, -0.0966],
+            [   -inf, -0.0013,    -inf,  ..., -0.1115, -0.1089, -0.1196]],
+           device='cuda:0')
+    tensor([  -inf, 0.1074], device='cuda:0')
+
+Note that the scores of ``[pad]`` and history items (for non-repeatable recommendation) are set to ``-inf``.
+
+Get the top ranked items for each user
+--------------------------------------
+
+If we want to get the top ranked items for the given users,
+we can call the :meth:`~recbole.utils.case_study.full_sort_topk` function to get the scores and internal ids of these items.
+
+.. code:: python3
+
+    topk_score, topk_iid_list = full_sort_topk(uid_series, model, test_data, k=10, device=config['device'])
+    print(topk_score)  # scores of top 10 items
+    print(topk_iid_list)  # internal ids of top 10 items
+    external_item_list = dataset.id2token(dataset.iid_field, topk_iid_list.cpu())
+    print(external_item_list)  # external tokens of top 10 items
+
+The output will be like this:
+
+.. code:: none
+
+    tensor([[0.1985, 0.1947, 0.1850, 0.1849, 0.1822, 0.1770, 0.1770, 0.1765, 0.1752,
+             0.1744],
+            [0.2487, 0.2379, 0.2351, 0.2311, 0.2293, 0.2239, 0.2215, 0.2156, 0.2137,
+             0.2114]], device='cuda:0')
+    tensor([[ 50,  32, 158, 210,  13, 100, 201,  61, 167, 312],
+            [102, 312, 358, 100,  32,  53, 167, 472, 162, 201]], device='cuda:0')
+    [['100' '98' '258' '7' '222' '496' '318' '288' '216' '176']
+     ['174' '176' '50' '496' '98' '181' '216' '28' '172' '318']]
diff --git a/docs/source/user_guide/usage/save_and_load_data_and_model.rst b/docs/source/user_guide/usage/save_and_load_data_and_model.rst
new file mode 100644
index 000000000..2ae7e0610
--- /dev/null
+++ b/docs/source/user_guide/usage/save_and_load_data_and_model.rst
@@ -0,0 +1,67 @@
+Save and load data and model
+==============================
+
+In this section, we will present how to save and load the data and the model.
+
+Save data and model
+--------------------
+
+When we use the :meth:`~recbole.quick_start.quick_start.run_recbole` function mentioned in :doc:`run_recbole`,
+it will save the best model parameters during training, together with their corresponding config settings.
+If you also want to save the filtered dataset and the split dataloaders,
+you can set the parameters :attr:`save_dataset` and :attr:`save_dataloaders` to ``True``.
+
+You can refer to :doc:`../config_settings` for more details about :attr:`save_dataset` and :attr:`save_dataloaders`.
+
+Here we present a typical output when the two parameters above are ``True``:
+
+.. code:: none
+
+    21 Aug 13:05    INFO  Saving filtered dataset into [saved/ml-100k-dataset.pth]
+    21 Aug 13:05    INFO  ml-100k
+    The number of users: 944
+    Average actions of users: 106.04453870625663
+    The number of items: 1683
+    Average actions of items: 59.45303210463734
+    The number of inters: 100000
+    The sparsity of the dataset: 93.70575143257098%
+    Remain Fields: ['user_id', 'item_id', 'rating', 'timestamp']
+    21 Aug 13:05    INFO  Saved split dataloaders: saved/ml-100k-for-BPR-dataloader.pth
+    21 Aug 13:06    INFO  BPR(
+      (user_embedding): Embedding(944, 64)
+      (item_embedding): Embedding(1683, 64)
+      (loss): BPRLoss()
+    )
+    Trainable parameters: 168128
+    Train     0: 100%|█████████████████████████| 40/40 [00:01<00:00, 32.52it/s, GPU RAM: 0.01 G/11.91 G]
+    21 Aug 13:06    INFO  epoch 0 training [time: 1.24s, train loss: 27.7228]
+    Evaluate   : 100%|███████████████████████| 472/472 [00:04<00:00, 94.53it/s, GPU RAM: 0.01 G/11.91 G]
+    21 Aug 13:06    INFO  epoch 0 evaluating [time: 5.00s, valid_score: 0.020500]
+    21 Aug 13:06    INFO  valid result:
+    recall@10 : 0.0067    mrr@10 : 0.0205    ndcg@10 : 0.0086    hit@10 : 0.0732    precision@10 : 0.0081
+    21 Aug 13:06    INFO  Saving current best: saved/BPR-Aug-21-2021_13-06-00.pth
+
+    ...
+
+As we can see, the filtered dataset is saved to ``saved/ml-100k-dataset.pth``,
+the split dataloaders are saved to ``saved/ml-100k-for-BPR-dataloader.pth``,
+and the model is saved to ``saved/BPR-Aug-21-2021_13-06-00.pth``.
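+
+An output like the one above can be produced with the following call
+(a minimal sketch; ``run_example/save_and_load_example.py`` does the same):
+
+.. code:: python3
+
+    from recbole.quick_start import run_recbole
+
+    run_recbole(model='BPR', dataset='ml-100k',
+                config_dict={'save_dataset': True, 'save_dataloaders': True})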
+
+Load data and model
+--------------------
+
+If you want to reload the data and model,
+you can use :meth:`~recbole.quick_start.quick_start.load_data_and_model` to get them.
+You can also pass :attr:`dataset_file` and :attr:`dataloader_file` to this function to reload the data from files,
+which can save the time of data filtering and splitting.
+
+Here we present a typical usage of :meth:`~recbole.quick_start.quick_start.load_data_and_model`:
+
+.. code:: python3
+
+    config, model, dataset, train_data, valid_data, test_data = load_data_and_model(
+        model_file='saved/BPR-Aug-21-2021_13-06-00.pth',
+    )
+    # Here you can replace it with your model path.
+    # You can also pass 'dataset_file' and 'dataloader_file' to this function.
diff --git a/recbole/data/dataset/dataset.py b/recbole/data/dataset/dataset.py
index 8098bf937..8d52a8004 100644
--- a/recbole/data/dataset/dataset.py
+++ b/recbole/data/dataset/dataset.py
@@ -1485,16 +1485,10 @@ def build(self):
 
         return datasets
 
-    def save(self, filepath):
-        """Saving this :class:`Dataset` object to local path.
-
-        Args:
-            filepath (str): path of saved dir.
+    def save(self):
+        """Saving this :class:`Dataset` object to :attr:`config['checkpoint_dir']`.
         """
-        if (filepath is None) or (not os.path.isdir(filepath)):
-            raise ValueError(f'Filepath [{filepath}] need to be a dir.')
-
-        file = os.path.join(filepath, f'{self.config["dataset"]}-dataset.pth')
+        file = os.path.join(self.config['checkpoint_dir'], f'{self.config["dataset"]}-dataset.pth')
         self.logger.info(set_color('Saving filtered dataset into ', 'pink') + f'[{file}]')
         with open(file, 'wb') as f:
             pickle.dump(self, f)
diff --git a/recbole/properties/overall.yaml b/recbole/properties/overall.yaml
index 7795b8741..5eed6492a 100644
--- a/recbole/properties/overall.yaml
+++ b/recbole/properties/overall.yaml
@@ -7,6 +7,8 @@ reproducibility: True
 data_path: 'dataset/'
 checkpoint_dir: 'saved'
 show_progress: True
+save_dataset: False
+save_dataloaders: False
 
 # training settings
 epochs: 300
diff --git a/recbole/quick_start/__init__.py b/recbole/quick_start/__init__.py
index b604cec32..38a49c318 100644
--- a/recbole/quick_start/__init__.py
+++ b/recbole/quick_start/__init__.py
@@ -1 +1 @@
-from recbole.quick_start.quick_start import run_recbole, objective_function
+from recbole.quick_start.quick_start import run_recbole, objective_function, load_data_and_model
diff --git a/recbole/quick_start/quick_start.py b/recbole/quick_start/quick_start.py
index cd0da6868..e57f3eda1 100644
--- a/recbole/quick_start/quick_start.py
+++ b/recbole/quick_start/quick_start.py
@@ -9,8 +9,11 @@
 import logging
 from logging import getLogger
 
+import torch
+import pickle
+
 from recbole.config import Config
-from recbole.data import create_dataset, data_preparation
+from recbole.data import create_dataset, data_preparation, save_split_dataloaders, load_split_dataloaders
 from recbole.utils import init_logger, get_model, get_trainer, init_seed, set_color
 
 
@@ -36,10 +39,14 @@ def run_recbole(model=None, dataset=None, config_file_list=None, config_dict=Non
 
     # dataset filtering
     dataset = create_dataset(config)
+    if config['save_dataset']:
+        dataset.save()
     logger.info(dataset)
 
     # dataset splitting
     train_data, valid_data, test_data = data_preparation(config, dataset)
+    if config['save_dataloaders']:
+        save_split_dataloaders(config, dataloaders=(train_data, valid_data, test_data))
 
     # model loading and initialization
     model = get_model(config['model'])(config, train_data.dataset).to(config['device'])
@@ -92,3 +99,55 @@ def objective_function(config_dict=None, config_file_list=None, saved=True):
         'best_valid_result': best_valid_result,
         'test_result': test_result
     }
+
+
+def load_data_and_model(model_file, dataset_file=None, dataloader_file=None):
+    r"""Load filtered dataset, split dataloaders and saved model.
+
+    Args:
+        model_file (str): The path of saved model file.
+        dataset_file (str): The path of filtered dataset. Defaults to ``None``.
+        dataloader_file (str): The path of split dataloaders. Defaults to ``None``.
+
+    Note:
+        The :attr:`dataset` will be loaded or created according to the following strategy:
+        If :attr:`dataset_file` is not ``None``, the :attr:`dataset` will be loaded from :attr:`dataset_file`.
+        If :attr:`dataset_file` is ``None`` and :attr:`dataloader_file` is ``None``,
+        the :attr:`dataset` will be created according to :attr:`config`.
+        If :attr:`dataset_file` is ``None`` and :attr:`dataloader_file` is not ``None``,
+        the :attr:`dataset` will be neither loaded nor created.
+
+        The :attr:`dataloader` will be loaded or created according to the following strategy:
+        If :attr:`dataloader_file` is not ``None``, the :attr:`dataloader` will be loaded from :attr:`dataloader_file`.
+        If :attr:`dataloader_file` is ``None``, the :attr:`dataloader` will be created according to :attr:`config`.
+
+    Returns:
+        tuple:
+            - config (Config): An instance object of Config, which records the parameter information in :attr:`model_file`.
+            - model (AbstractRecommender): The model loaded from :attr:`model_file`.
+            - dataset (Dataset): The filtered dataset.
+            - train_data (AbstractDataLoader): The dataloader for training.
+            - valid_data (AbstractDataLoader): The dataloader for validation.
+            - test_data (AbstractDataLoader): The dataloader for testing.
+    """
+    checkpoint = torch.load(model_file)
+    config = checkpoint['config']
+    init_logger(config)
+
+    dataset = None
+    if dataset_file:
+        with open(dataset_file, 'rb') as f:
+            dataset = pickle.load(f)
+
+    if dataloader_file:
+        train_data, valid_data, test_data = load_split_dataloaders(dataloader_file)
+    else:
+        if dataset is None:
+            dataset = create_dataset(config)
+        train_data, valid_data, test_data = data_preparation(config, dataset)
+
+    model = get_model(config['model'])(config, train_data.dataset).to(config['device'])
+    model.load_state_dict(checkpoint['state_dict'])
+    model.load_other_parameter(checkpoint.get('other_parameter'))
+
+    return config, model, dataset, train_data, valid_data, test_data
diff --git a/recbole/utils/argument_list.py b/recbole/utils/argument_list.py
index d3438a584..9b13d6c08 100644
--- a/recbole/utils/argument_list.py
+++ b/recbole/utils/argument_list.py
@@ -12,7 +12,9 @@
     'data_path',
     'benchmark_filename',
     'show_progress',
-    'config_file'
+    'config_file',
+    'save_dataset',
+    'save_dataloaders',
 ]
 
 training_arguments = [
diff --git a/recbole/utils/case_study.py b/recbole/utils/case_study.py
index d76f5b036..61e8e4cd0 100644
--- a/recbole/utils/case_study.py
+++ b/recbole/utils/case_study.py
@@ -17,20 +17,24 @@
 
 
 @torch.no_grad()
-def full_sort_scores(uid_series, model, test_data):
+def full_sort_scores(uid_series, model, test_data, device=None):
     """Calculate the scores of all items for each user in uid_series.
 
     Note:
         The score of [pad] and history items will be set into -inf.
 
     Args:
-        uid_series (numpy.ndarray): User id series
-        model (AbstractRecommender): Model to predict
-        test_data (AbstractDataLoader): The test_data of model
+        uid_series (numpy.ndarray or list): User id series.
+        model (AbstractRecommender): Model to predict.
+        test_data (FullSortEvalDataLoader): The test_data of model.
+        device (torch.device): The device on which the model runs. Defaults to ``None``.
+            Note: ``device=None`` is equivalent to ``device=torch.device('cpu')``.
 
     Returns:
         torch.Tensor: the scores of all items for each user in uid_series.
     """
+    device = device or torch.device('cpu')
+    uid_series = np.array(uid_series)
     uid_field = test_data.dataset.uid_field
     dataset = test_data.dataset
     model.eval()
 
     if not test_data.is_sequential:
         index = np.isin(test_data.user_df[uid_field].numpy(), uid_series)
         input_interaction = test_data.user_df[index]
-        history_item = test_data.uid2history_item[input_interaction[uid_field].numpy()]
+        history_item = test_data.uid2history_item[uid_series]
         history_row = torch.cat([torch.full_like(hist_iid, i) for i, hist_iid in enumerate(history_item)])
         history_col = torch.cat(list(history_item))
         history_index = history_row, history_col
     else:
-        index = np.isin(test_data.uid_list, uid_series)
-        input_interaction = test_data.augmentation(
-            test_data.item_list_index[index], test_data.target_index[index], test_data.item_list_length[index]
-        )
+        index = np.isin(dataset[uid_field].numpy(), uid_series)
+        input_interaction = dataset[index]
         history_index = None
 
     # Get scores of all items
+    input_interaction = input_interaction.to(device)
     try:
         scores = model.full_sort_predict(input_interaction)
     except NotImplementedError:
         input_interaction = input_interaction.repeat(dataset.item_num)
-        input_interaction.update(test_data.dataset.get_item_feature().repeat(len(uid_series)))
+        input_interaction.update(test_data.dataset.get_item_feature().to(device).repeat(len(uid_series)))
         scores = model.predict(input_interaction)
 
     scores = scores.view(-1, dataset.item_num)
@@ -65,19 +68,24 @@ def full_sort_scores(uid_series, model, test_data):
     return scores
 
 
-def full_sort_topk(uid_series, model, test_data, k):
+def full_sort_topk(uid_series, model, test_data, k, device=None):
     """Calculate the top-k items' scores and ids for each user in uid_series.
 
+    Note:
+        The score of [pad] and history items will be set to -inf.
+
     Args:
-        uid_series (numpy.ndarray): User id series
-        model (AbstractRecommender): Model to predict
-        test_data (AbstractDataLoader): The test_data of model
+        uid_series (numpy.ndarray): User id series.
+        model (AbstractRecommender): Model to predict.
+        test_data (FullSortEvalDataLoader): The test_data of model.
         k (int): The top-k items.
+        device (torch.device): The device on which the model runs. Defaults to ``None``.
+            Note: ``device=None`` is equivalent to ``device=torch.device('cpu')``.
 
     Returns:
         tuple:
             - topk_scores (torch.Tensor): The scores of topk items.
             - topk_index (torch.Tensor): The index of topk items, which is also the internal ids of items.
     """
-    scores = full_sort_scores(uid_series, model, test_data)
+    scores = full_sort_scores(uid_series, model, test_data, device)
     return torch.topk(scores, k)
diff --git a/run_example/case_study_example.py b/run_example/case_study_example.py
index 91c591929..c23ff9b09 100644
--- a/run_example/case_study_example.py
+++ b/run_example/case_study_example.py
@@ -11,41 +11,26 @@
 
 import torch
 
-from recbole.config import Config
-from recbole.data import create_dataset, data_preparation
-from recbole.utils import get_model, init_seed
 from recbole.utils.case_study import full_sort_topk, full_sort_scores
+from recbole.quick_start import load_data_and_model
 
 
 if __name__ == '__main__':
-    # this part is to load saved model.
-    config_dict = {
-        # here you can set some parameters such as `gpu_id` and so on.
-    }
-    config = Config(model='BPR', dataset='ml-100k', config_dict=config_dict)
-    init_seed(config['seed'], config['reproducibility'])
-    dataset = create_dataset(config)
-    train_data, valid_data, test_data = data_preparation(config, dataset)
-    # Here you can also use `load_split_dataloaders` to load data.
-    # The example code for `load_split_dataloaders` can be found in `save_and_load_example.py`.
-
-    model = get_model(config['model'])(config, train_data)
-    checkpoint = torch.load('RecBole/saved/BPR-Dec-08-2020_15-37-37.pth')  # Here you can replace it by your model path.
-    model.load_state_dict(checkpoint['state_dict'])
-    model.load_other_parameter(checkpoint.get('other_parameter'))
-    model.eval()
+    config, model, dataset, train_data, valid_data, test_data = load_data_and_model(
+        model_file='../saved/BPR-Aug-20-2021_03-32-13.pth',
+    )  # Here you can replace it with your model path.
 
     # uid_series = np.array([1, 2])  # internal user id series
     # or you can use dataset.token2id to transfer external user token to internal user id
-    uid_series = dataset.token2id(dataset.uid_field, ['200'])
+    uid_series = dataset.token2id(dataset.uid_field, ['196', '186'])
 
-    topk_score, topk_iid_list = full_sort_topk(uid_series, model, test_data, k=10)
+    topk_score, topk_iid_list = full_sort_topk(uid_series, model, test_data, k=10, device=config['device'])
     print(topk_score)  # scores of top 10 items
     print(topk_iid_list)  # internal id of top 10 items
-    external_item_list = dataset.id2token(dataset.iid_field, topk_iid_list)
+    external_item_list = dataset.id2token(dataset.iid_field, topk_iid_list.cpu())
     print(external_item_list)  # external tokens of top 10 items
     print()
 
-    score = full_sort_scores(uid_series, model, test_data)
+    score = full_sort_scores(uid_series, model, test_data, device=config['device'])
     print(score)  # score of all items
     print(score[0, dataset.token2id(dataset.iid_field, ['242', '302'])])
     # score of item ['242', '302'] for user '196'.
diff --git a/run_example/save_and_load_example.py b/run_example/save_and_load_example.py
index 02886ea84..ce4f2c5ac 100644
--- a/run_example/save_and_load_example.py
+++ b/run_example/save_and_load_example.py
@@ -10,70 +10,47 @@
 The path to saved data or model can be found in the output of RecBole.
""" -import pickle -from logging import getLogger -import torch - -from recbole.config import Config -from recbole.data import create_dataset, data_preparation, save_split_dataloaders, load_split_dataloaders -from recbole.utils import init_seed, init_logger, get_model, get_trainer +from recbole.quick_start import run_recbole, load_data_and_model def save_example(): # configurations initialization config_dict = { - 'checkpoint_dir': '../saved' + 'checkpoint_dir': '../saved', + 'save_dataset': True, + 'save_dataloaders': True, } - config = Config(model='BPR', dataset='ml-100k', config_dict=config_dict) - init_seed(config['seed'], config['reproducibility']) - init_logger(config) - - # dataset filtering - dataset = create_dataset(config) - dataset.save('../saved/') - - # dataset splitting - train_data, valid_data, test_data = data_preparation(config, dataset) - save_split_dataloaders(config, dataloaders=(train_data, valid_data, test_data)) - - model = get_model(config['model'])(config, train_data).to(config['device']) - - # trainer loading and initialization - trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model) - - # model training - # the best model will be saved in here - best_valid_score, best_valid_result = trainer.fit( - train_data, valid_data, saved=True, show_progress=config['show_progress'] - ) + run_recbole(model='BPR', dataset='ml-100k', config_dict=config_dict) def load_example(): - # configurations initialization - config_dict = { - 'checkpoint_dir': '../saved' - } - config = Config(model='BPR', dataset='ml-100k', config_dict=config_dict) - init_seed(config['seed'], config['reproducibility']) - init_logger(config) - logger = getLogger() + # Filtered dataset and split dataloaders are created according to 'config'. + config, model, dataset, train_data, valid_data, test_data = load_data_and_model( + model_file='../saved/BPR-Aug-20-2021_03-32-13.pth', + ) - with open('../saved/ml-100k-dataset.pth', 'rb') as f: # You can use your filtered data path here. - dataset = pickle.load(f) + # Filtered dataset is loaded from file, and split dataloaders are created according to 'config'. + config, model, dataset, train_data, valid_data, test_data = load_data_and_model( + model_file='../saved/BPR-Aug-20-2021_03-32-13.pth', + dataset_file='../saved/ml-100k-dataset.pth', + ) - train_data, valid_data, test_data = load_split_dataloaders('../saved/ml-100k-for-BPR-dataloader.pth') - # You can use your split data path here. + # Dataset is neither created nor loaded, and split dataloaders are loaded from file. + config, model, dataset, train_data, valid_data, test_data = load_data_and_model( + model_file='../saved/BPR-Aug-20-2021_03-32-13.pth', + dataloader_file='../saved/ml-100k-for-BPR-dataloader.pth', + ) + assert dataset is None - model = get_model(config['model'])(config, train_data).to(config['device']) - checkpoint = torch.load('../saved/BPR-Mar-20-2021_17-11-05.pth') # Here you can replace it by your model path. - model.load_state_dict(checkpoint['state_dict']) - model.load_other_parameter(checkpoint.get('other_parameter')) - logger.info(model) - logger.info(train_data.dataset) - logger.info(valid_data.dataset) - logger.info(test_data.dataset) + # Filtered dataset and split dataloaders are loaded from file. 
+    config, model, dataset, train_data, valid_data, test_data = load_data_and_model(
+        model_file='../saved/BPR-Aug-20-2021_03-32-13.pth',
+        dataset_file='../saved/ml-100k-dataset.pth',
+        dataloader_file='../saved/ml-100k-for-BPR-dataloader.pth',
+    )
 
 
 if __name__ == '__main__':
     save_example()
     # load_example()