diff --git a/docs/source/asset/recvae.png b/docs/source/asset/recvae.png new file mode 100644 index 000000000..d0642406d Binary files /dev/null and b/docs/source/asset/recvae.png differ diff --git a/docs/source/user_guide/model/general/recvae.rst b/docs/source/user_guide/model/general/recvae.rst new file mode 100644 index 000000000..dea6a3fce --- /dev/null +++ b/docs/source/user_guide/model/general/recvae.rst @@ -0,0 +1,79 @@ +RecVAE +=========== + +Introduction +--------------------- + +`[paper] `_ + +**Title:** RecVAE: A New Variational Autoencoder for Top-N Recommendations with Implicit Feedback + +**Authors:** Ilya Shenbin, Anton Alekseev, Elena Tutubalina, Valentin Malykh, Sergey I. Nikolenko + +**Abstract:** Recent research has shown the advantages of using autoencoders based on deep neural networks for collaborative filtering. In particular, the recently proposed Mult-VAE model, which used the multinomial likelihood variational autoencoders, has shown excellent results for top-N recommendations. In this work, we propose the Recommender VAE (RecVAE) model that originates from our research on regularization techniques for variational autoencoders. RecVAE introduces several novel ideas to improve Mult-VAE, including a novel composite prior distribution for the latent codes, a new approach to setting the β hyperparameter for the β-VAE framework, and a new approach to training based on alternating updates. In experimental evaluation, we show that RecVAE significantly outperforms previously proposed autoencoder-based models, including Mult-VAE and RaCT, across classical collaborative filtering datasets, and present a detailed ablation study to assess our new developments. Code and models are available at https://github.com/ilya-shenbin/RecVAE. + +.. 
image:: ../../../../../asset/recvae.png + :width: 400 + :align: center + +Running with RecBole +------------------------- + +**Model Hyper-Parameters:** + +- ``hidden_dimension (int)`` : The hidden dimension of auto-encoder. Defaults to ``600``. +- ``latent_dimension (int)`` : The latent dimension of auto-encoder. Defaults to ``200``. +- ``dropout_prob (float)`` : The drop out probability of input. Defaults to ``0.5``. +- ``beta (float)`` : The default hyperparameter of the weight of KL loss. Defaults to ``0.2``. +- ``gamma (float)`` : The hyperparameter shared across all users. Defaults to ``0.005``. +- ``mixture_weights (list)`` : The mixture weights of three composite priors. Defaults to ``[0.15, 0.75, 0.1]``. +- ``n_enc_epochs (int)`` : The training times of encoder per epoch. Defaults to ``3``. +- ``n_dec_epochs (int)`` : The training times of decoder per epoch. Defaults to ``1``. +- ``training_neg_sample_num (int)`` : The negative sample num for training. Defaults to ``0``. + + +**A Running Example:** + +Write the following code to a python file, such as `run.py` + +.. code:: python + + from recbole.quick_start import run_recbole + + run_recbole(model='RecVAE', dataset='ml-100k') + +And then: + +.. code:: bash + + python run.py + +**Note**: Because this model is a non-sampling model, you must set ``training_neg_sample_num=0`` when you run this model. + +Tuning Hyper Parameters +------------------------- + +If you want to use ``HyperTuning`` to tune hyper parameters of this model, you can copy the following settings and name it as ``hyper.test``. + +.. code:: bash + + learning_rate choice [0.01,0.005,0.001,0.0005,0.0001] + latent_dimension choice [64,100,128,150,200,256,300,400,512] + +Note that we just provide these hyper parameter ranges for reference only, and we can not guarantee that they are the optimal range of this model. + +Then, with the source code of RecBole (you can download it from GitHub), you can run ``run_hyper.py`` for tuning: + +.. 
code:: bash + + python run_hyper.py --model=[model_name] --dataset=[dataset_name] --config_files=[config_files_path] --params_file=hyper.test + +For more details about Parameter Tuning, refer to :doc:`../../../user_guide/usage/parameter_tuning`. + + +If you want to change parameters, dataset or evaluation settings, take a look at + +- :doc:`../../../user_guide/config_settings` +- :doc:`../../../user_guide/data_intro` +- :doc:`../../../user_guide/evaluation_support` +- :doc:`../../../user_guide/usage` \ No newline at end of file diff --git a/docs/source/user_guide/model_intro.rst b/docs/source/user_guide/model_intro.rst index aa4ab59bf..b5a2b3546 100644 --- a/docs/source/user_guide/model_intro.rst +++ b/docs/source/user_guide/model_intro.rst @@ -30,6 +30,7 @@ General Recommendation model/general/cdae model/general/enmf model/general/nncf + model/general/recvae model/general/ease model/general/slimelastic diff --git a/recbole/data/utils.py b/recbole/data/utils.py index 76a4c0cc4..0d4fe9335 100644 --- a/recbole/data/utils.py +++ b/recbole/data/utils.py @@ -192,7 +192,8 @@ def get_data_loader(name, config, neg_sample_args): "MultiVAE": _get_AE_data_loader, 'MacridVAE': _get_AE_data_loader, 'CDAE': _get_AE_data_loader, - 'ENMF': _get_AE_data_loader + 'ENMF': _get_AE_data_loader, + 'RecVAE': _get_AE_data_loader } if config['model'] in register_table: diff --git a/recbole/model/general_recommender/__init__.py b/recbole/model/general_recommender/__init__.py index 5cbb98ca3..e0ed178b7 100644 --- a/recbole/model/general_recommender/__init__.py +++ b/recbole/model/general_recommender/__init__.py @@ -3,6 +3,8 @@ from recbole.model.general_recommender.convncf import ConvNCF from recbole.model.general_recommender.dgcf import DGCF from recbole.model.general_recommender.dmf import DMF +from recbole.model.general_recommender.ease import EASE +from recbole.model.general_recommender.enmf import ENMF from recbole.model.general_recommender.fism import FISM from 
recbole.model.general_recommender.gcmc import GCMC from recbole.model.general_recommender.itemknn import ItemKNN @@ -15,8 +17,9 @@ from recbole.model.general_recommender.nais import NAIS from recbole.model.general_recommender.neumf import NeuMF from recbole.model.general_recommender.ngcf import NGCF +from recbole.model.general_recommender.nncf import NNCF from recbole.model.general_recommender.pop import Pop +from recbole.model.general_recommender.recvae import RecVAE from recbole.model.general_recommender.slimelastic import SLIMElastic from recbole.model.general_recommender.spectralcf import SpectralCF -from recbole.model.general_recommender.ease import EASE -from recbole.model.general_recommender.nncf import NNCF + diff --git a/recbole/model/general_recommender/recvae.py b/recbole/model/general_recommender/recvae.py new file mode 100644 index 000000000..81f1a1cbe --- /dev/null +++ b/recbole/model/general_recommender/recvae.py @@ -0,0 +1,204 @@ +# -*- coding: utf-8 -*- +# @Time : 2021/2/28 +# @Author : Lanling Xu +# @Email : xulanling_sherry@163.com + +r""" +RecVAE +################################################ +Reference: + Shenbin, Ilya, et al. "RecVAE: A new variational autoencoder for Top-N recommendations with implicit feedback." Proceedings of the 13th International Conference on Web Search and Data Mining. 2020. + +Reference code: + https://github.com/ilya-shenbin/RecVAE +""" + +import numpy as np +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from recbole.model.abstract_recommender import GeneralRecommender +from recbole.model.init import xavier_normal_initialization +from recbole.utils import InputType + + +def swish(x): + r"""Swish activation function: + + .. 
math:: + \text{Swish}(x) = \frac{x}{1 + \exp(-x)} + """ + return x.mul(torch.sigmoid(x)) + + +def log_norm_pdf(x, mu, logvar): + return -0.5 * (logvar + np.log(2 * np.pi) + (x - mu).pow(2) / logvar.exp()) + + +class CompositePrior(nn.Module): + def __init__(self, hidden_dim, latent_dim, input_dim, mixture_weights): + super(CompositePrior, self).__init__() + + self.mixture_weights = mixture_weights + + self.mu_prior = nn.Parameter(torch.Tensor(1, latent_dim), requires_grad=False) + self.mu_prior.data.fill_(0) + + self.logvar_prior = nn.Parameter(torch.Tensor(1, latent_dim), requires_grad=False) + self.logvar_prior.data.fill_(0) + + self.logvar_uniform_prior = nn.Parameter(torch.Tensor(1, latent_dim), requires_grad=False) + self.logvar_uniform_prior.data.fill_(10) + + self.encoder_old = Encoder(hidden_dim, latent_dim, input_dim) + self.encoder_old.requires_grad_(False) + + def forward(self, x, z): + post_mu, post_logvar = self.encoder_old(x, 0) + + stnd_prior = log_norm_pdf(z, self.mu_prior, self.logvar_prior) + post_prior = log_norm_pdf(z, post_mu, post_logvar) + unif_prior = log_norm_pdf(z, self.mu_prior, self.logvar_uniform_prior) + + gaussians = [stnd_prior, post_prior, unif_prior] + gaussians = [g.add(np.log(w)) for g, w in zip(gaussians, self.mixture_weights)] + + density_per_gaussian = torch.stack(gaussians, dim=-1) + + return torch.logsumexp(density_per_gaussian, dim=-1) + + +class Encoder(nn.Module): + def __init__(self, hidden_dim, latent_dim, input_dim, eps=1e-1): + super(Encoder, self).__init__() + + self.fc1 = nn.Linear(input_dim, hidden_dim) + self.ln1 = nn.LayerNorm(hidden_dim, eps=eps) + self.fc2 = nn.Linear(hidden_dim, hidden_dim) + self.ln2 = nn.LayerNorm(hidden_dim, eps=eps) + self.fc3 = nn.Linear(hidden_dim, hidden_dim) + self.ln3 = nn.LayerNorm(hidden_dim, eps=eps) + self.fc4 = nn.Linear(hidden_dim, hidden_dim) + self.ln4 = nn.LayerNorm(hidden_dim, eps=eps) + self.fc5 = nn.Linear(hidden_dim, hidden_dim) + self.ln5 = nn.LayerNorm(hidden_dim, 
eps=eps) + self.fc_mu = nn.Linear(hidden_dim, latent_dim) + self.fc_logvar = nn.Linear(hidden_dim, latent_dim) + + def forward(self, x, dropout_prob): + x = F.normalize(x) + x = F.dropout(x, dropout_prob, training=self.training) + + h1 = self.ln1(swish(self.fc1(x))) + h2 = self.ln2(swish(self.fc2(h1) + h1)) + h3 = self.ln3(swish(self.fc3(h2) + h1 + h2)) + h4 = self.ln4(swish(self.fc4(h3) + h1 + h2 + h3)) + h5 = self.ln5(swish(self.fc5(h4) + h1 + h2 + h3 + h4)) + return self.fc_mu(h5), self.fc_logvar(h5) + + +class RecVAE(GeneralRecommender): + r"""Collaborative Denoising Auto-Encoder (RecVAE) is a recommendation model + for top-N recommendation with implicit feedback. + + We implement the model following the original author + """ + input_type = InputType.PAIRWISE + + def __init__(self, config, dataset): + super(RecVAE, self).__init__(config, dataset) + + self.hidden_dim = config["hidden_dimension"] + self.latent_dim = config['latent_dimension'] + self.dropout_prob = config['dropout_prob'] + self.beta = config['beta'] + self.mixture_weights = config['mixture_weights'] + self.gamma = config['gamma'] + + self.history_item_id, self.history_item_value, _ = dataset.history_item_matrix() + self.history_item_id = self.history_item_id.to(self.device) + self.history_item_value = self.history_item_value.to(self.device) + + self.encoder = Encoder(self.hidden_dim, self.latent_dim, self.n_items) + self.prior = CompositePrior(self.hidden_dim, self.latent_dim, self.n_items, self.mixture_weights) + self.decoder = nn.Linear(self.latent_dim, self.n_items) + + # parameters initialization + self.apply(xavier_normal_initialization) + + def get_rating_matrix(self, user): + r"""Get a batch of user's feature with the user's id and history interaction matrix. 
+ + Args: + user (torch.LongTensor): The input tensor that contains user's id, shape: [batch_size, ] + + Returns: + torch.FloatTensor: The user's feature of a batch of user, shape: [batch_size, n_items] + """ + # Following lines construct tensor of shape [B,n_items] using the tensor of shape [B,H] + col_indices = self.history_item_id[user].flatten() + row_indices = torch.arange(user.shape[0]).to(self.device) \ + .repeat_interleave(self.history_item_id.shape[1], dim=0) + rating_matrix = torch.zeros(1).to(self.device).repeat(user.shape[0], self.n_items) + rating_matrix.index_put_((row_indices, col_indices), self.history_item_value[user].flatten()) + return rating_matrix + + def reparameterize(self, mu, logvar): + if self.training: + std = torch.exp(0.5 * logvar) + epsilon = torch.zeros_like(std).normal_(mean=0, std=0.01) + return mu + epsilon * std + else: + return mu + + def forward(self, rating_matrix, dropout_prob): + mu, logvar = self.encoder(rating_matrix, dropout_prob=dropout_prob) + z = self.reparameterize(mu, logvar) + x_pred = self.decoder(z) + return x_pred, mu, logvar, z + + def calculate_loss(self, interaction, encoder_flag): + user = interaction[self.USER_ID] + rating_matrix = self.get_rating_matrix(user) + if encoder_flag: + dropout_prob = self.dropout_prob + else: + dropout_prob = 0 + x_pred, mu, logvar, z = self.forward(rating_matrix, dropout_prob) + + if self.gamma: + norm = rating_matrix.sum(dim=-1) + kl_weight = self.gamma * norm + else: + kl_weight = self.beta + + mll = (F.log_softmax(x_pred, dim=-1) * rating_matrix).sum(dim=-1).mean() + kld = (log_norm_pdf(z, mu, logvar) - self.prior(rating_matrix, z)).sum(dim=-1).mul(kl_weight).mean() + negative_elbo = -(mll - kld) + + return negative_elbo + + def predict(self, interaction): + user = interaction[self.USER_ID] + item = interaction[self.ITEM_ID] + + rating_matrix = self.get_rating_matrix(user) + + scores, _, _, _ = self.forward(rating_matrix, self.dropout_prob) + + return scores[[user, item]] + + 
def full_sort_predict(self, interaction): + user = interaction[self.USER_ID] + + rating_matrix = self.get_rating_matrix(user) + + scores, _, _, _ = self.forward(rating_matrix, self.dropout_prob) + + return scores.view(-1) + + def update_prior(self): + self.prior.encoder_old.load_state_dict(deepcopy(self.encoder.state_dict())) diff --git a/recbole/properties/model/MultiVAE.yaml b/recbole/properties/model/MultiVAE.yaml index 63021d9b0..fb5e88847 100644 --- a/recbole/properties/model/MultiVAE.yaml +++ b/recbole/properties/model/MultiVAE.yaml @@ -2,4 +2,4 @@ mlp_hidden_size: [600] latent_dimension: 128 dropout_prob: 0.5 anneal_cap: 0.2 -total_anneal_steps: 200000 \ No newline at end of file +total_anneal_steps: 200000 diff --git a/recbole/properties/model/RecVAE.yaml b/recbole/properties/model/RecVAE.yaml new file mode 100644 index 000000000..5e33e9b6b --- /dev/null +++ b/recbole/properties/model/RecVAE.yaml @@ -0,0 +1,8 @@ +hidden_dimension: 600 +latent_dimension: 200 +dropout_prob: 0.5 +beta: 0.2 +mixture_weights: [0.15, 0.75, 0.1] +gamma: 0.005 +n_enc_epochs: 3 +n_dec_epochs: 1 diff --git a/recbole/trainer/trainer.py b/recbole/trainer/trainer.py index 9f9bfd9b5..fdd61f1f6 100644 --- a/recbole/trainer/trainer.py +++ b/recbole/trainer/trainer.py @@ -3,9 +3,9 @@ # @Email : slmu@ruc.edu.cn # UPDATE: -# @Time : 2020/8/7, 2020/9/26, 2020/9/26, 2020/10/01, 2020/9/16, 2020/10/8, 2020/10/15, 2020/11/20, 2021/2/20 -# @Author : Zihan Lin, Yupeng Hou, Yushuo Chen, Shanlei Mu, Xingyu Pan, Hui Wang, Xinyan Fan, Chen Yang, Yibo Li -# @Email : linzihan.super@foxmail.com, houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, slmu@ruc.edu.cn, panxy@ruc.edu.cn, hui.wang@ruc.edu.cn, xinyan.fan@ruc.edu.cn, 254170321@qq.com, 2018202152@ruc.edu.cn +# @Time : 2020/8/7, 2020/9/26, 2020/9/26, 2020/10/01, 2020/9/16, 2020/10/8, 2020/10/15, 2020/11/20, 2021/2/20, 2021/3/3 +# @Author : Zihan Lin, Yupeng Hou, Yushuo Chen, Shanlei Mu, Xingyu Pan, Hui Wang, Xinyan Fan, Chen Yang, Yibo Li, Lanling Xu +# 
@Email : linzihan.super@foxmail.com, houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, slmu@ruc.edu.cn, panxy@ruc.edu.cn, hui.wang@ruc.edu.cn, xinyan.fan@ruc.edu.cn, 254170321@qq.com, 2018202152@ruc.edu.cn, xulanling_sherry@163.com r""" recbole.trainer.trainer @@ -93,33 +93,33 @@ def __init__(self, config, model): self.best_valid_score = -1 self.best_valid_result = None self.train_loss_dict = dict() - self.optimizer = self._build_optimizer() + self.optimizer = self._build_optimizer(self.model.parameters()) self.eval_type = config['eval_type'] self.evaluator = ProxyEvaluator(config) self.item_tensor = None self.tot_item_num = None - def _build_optimizer(self): + def _build_optimizer(self, params): r"""Init the Optimizer Returns: torch.optim: the optimizer """ if self.learner.lower() == 'adam': - optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + optimizer = optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay) elif self.learner.lower() == 'sgd': - optimizer = optim.SGD(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + optimizer = optim.SGD(params, lr=self.learning_rate, weight_decay=self.weight_decay) elif self.learner.lower() == 'adagrad': - optimizer = optim.Adagrad(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + optimizer = optim.Adagrad(params, lr=self.learning_rate, weight_decay=self.weight_decay) elif self.learner.lower() == 'rmsprop': - optimizer = optim.RMSprop(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + optimizer = optim.RMSprop(params, lr=self.learning_rate, weight_decay=self.weight_decay) elif self.learner.lower() == 'sparse_adam': - optimizer = optim.SparseAdam(self.model.parameters(), lr=self.learning_rate) + optimizer = optim.SparseAdam(params, lr=self.learning_rate) if self.weight_decay > 0: self.logger.warning('Sparse Adam cannot argument received argument [{weight_decay}]') else: 
self.logger.warning('Received unrecognized optimizer, set default Adam optimizer') - optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate) + optimizer = optim.Adam(params, lr=self.learning_rate) return optimizer def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False): @@ -848,3 +848,114 @@ def evaluate(self, eval_data, load_best_model=True, model_file=None, show_progre batch_matrix_list = [[torch.stack((self.eval_true, self.eval_pred), 1)]] result = self.evaluator.evaluate(batch_matrix_list, eval_data) return result + + +class RecVAETrainer(Trainer): + r"""RecVAETrainer is designed for RecVAE, which is a general recommender. + + """ + + def __init__(self, config, model): + super(RecVAETrainer, self).__init__(config, model) + self.n_enc_epochs = config['n_enc_epochs'] + self.n_dec_epochs = config['n_dec_epochs'] + + def _train_epoch(self, train_data, epoch_idx, n_epochs, optimizer, encoder_flag, loss_func=None, show_progress=False): + self.model.train() + loss_func = loss_func or self.model.calculate_loss + total_loss = None + for epoch in range(n_epochs): + iter_data = ( + tqdm( + enumerate(train_data), + total=len(train_data), + desc=f"Train {epoch_idx:>5}", + ) if show_progress else enumerate(train_data) + ) + for batch_idx, interaction in iter_data: + interaction = interaction.to(self.device) + optimizer.zero_grad() + losses = loss_func(interaction, encoder_flag=encoder_flag) + if isinstance(losses, tuple): + loss = sum(losses) + loss_tuple = tuple(per_loss.item() for per_loss in losses) + total_loss = loss_tuple if total_loss is None else tuple(map(sum, zip(total_loss, loss_tuple))) + else: + loss = losses + total_loss = losses.item() if total_loss is None else total_loss + losses.item() + self._check_nan(loss) + loss.backward() + if self.clip_grad_norm: + clip_grad_norm_(self.model.parameters(), **self.clip_grad_norm) + optimizer.step() + + return total_loss + + def fit(self, train_data, valid_data=None, 
verbose=True, saved=True, show_progress=False, callback_fn=None): + if saved and self.start_epoch >= self.epochs: + self._save_checkpoint(-1) + + encoder_params = set(self.model.encoder.parameters()) + decoder_params = set(self.model.decoder.parameters()) + + optimizer_encoder = self._build_optimizer(encoder_params) + optimizer_decoder = self._build_optimizer(decoder_params) + + for epoch_idx in range(self.start_epoch, self.epochs): + # alternate training + training_start_time = time() + train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress, + n_epochs=self.n_enc_epochs, encoder_flag=True, optimizer=optimizer_encoder) + self.model.update_prior() + train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress, + n_epochs=self.n_dec_epochs, encoder_flag=False, optimizer=optimizer_decoder) + self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(train_loss, tuple) else train_loss + training_end_time = time() + train_loss_output = \ + self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss) + if verbose: + self.logger.info(train_loss_output) + + # eval + if self.eval_step <= 0 or not valid_data: + if saved: + self._save_checkpoint(epoch_idx) + update_output = 'Saving current: %s' % self.saved_model_file + if verbose: + self.logger.info(update_output) + continue + if (epoch_idx + 1) % self.eval_step == 0: + valid_start_time = time() + valid_score, valid_result = self._valid_epoch(valid_data, show_progress=show_progress) + self.best_valid_score, self.cur_step, stop_flag, update_flag = early_stopping( + valid_score, + self.best_valid_score, + self.cur_step, + max_step=self.stopping_step, + bigger=self.valid_metric_bigger + ) + valid_end_time = time() + valid_score_output = "epoch %d evaluating [time: %.2fs, valid_score: %f]" % \ + (epoch_idx, valid_end_time - valid_start_time, valid_score) + valid_result_output = 'valid result: \n' + dict2str(valid_result) + if verbose: + 
self.logger.info(valid_score_output) + self.logger.info(valid_result_output) + if update_flag: + if saved: + self._save_checkpoint(epoch_idx) + update_output = 'Saving current best: %s' % self.saved_model_file + if verbose: + self.logger.info(update_output) + self.best_valid_result = valid_result + + if callback_fn: + callback_fn(epoch_idx, valid_score) + + if stop_flag: + stop_output = 'Finished training, best eval result in epoch %d' % \ + (epoch_idx - self.cur_step * self.eval_step) + if verbose: + self.logger.info(stop_output) + break + return self.best_valid_score, self.best_valid_result diff --git a/tests/model/test_model_auto.py b/tests/model/test_model_auto.py index 54337bac5..b65d3a6fd 100644 --- a/tests/model/test_model_auto.py +++ b/tests/model/test_model_auto.py @@ -186,7 +186,13 @@ def test_NNCF(self): 'model': 'NNCF', } quick_test(config_dict) - + + def test_RecVAE(self): + config_dict = { + 'model': 'RecVAE', + 'training_neg_sample_num': 0 + } + quick_test(config_dict) def test_slimelastic(self): config_dict = {