FEA: add RecVAE model #727

Merged
merged 88 commits into from
Mar 5, 2021
d894e41
add space to README.md
guoyihonggyh Jan 15, 2021
c79d782
Merge pull request #1 from RUCAIBox/master
Sherry-XLL Jan 17, 2021
af33242
Merge pull request #2 from RUCAIBox/master
Sherry-XLL Jan 29, 2021
187c6d1
ADD RecVAE model
Yibo-Li-1 Feb 2, 2021
94606d7
ADD RecVAE
Yibo-Li-1 Feb 4, 2021
7b0543e
Merge pull request #3 from RUCAIBox/master
Sherry-XLL Feb 4, 2021
0adc68c
FIX: modify RecBole model
Yibo-Li-1 Feb 10, 2021
fb0779d
Merge branch 'master' of https://github.com/Sherry-XLL/RecBole
Yibo-Li-1 Feb 10, 2021
b4897ad
Merge pull request #4 from RUCAIBox/master
Sherry-XLL Feb 10, 2021
9ce4faf
FEA: Add CI model test of RecVAE
Sherry-XLL Feb 20, 2021
4e4b76e
ADD: add RecVAE model
Sherry-XLL Feb 20, 2021
af6ff7e
add RecVAE data loader
Sherry-XLL Feb 20, 2021
4c07e59
ADD: add RecVAE data loader
Sherry-XLL Feb 20, 2021
0d30675
Update README.md
Sherry-XLL Feb 20, 2021
95d4613
Update: update recvae.py
Sherry-XLL Feb 20, 2021
affe848
Update: update RecVAE.yaml
Sherry-XLL Feb 20, 2021
bdffdd3
Update: update recvae.py
Sherry-XLL Feb 21, 2021
d63c022
Update: update RecVAE.yaml
Sherry-XLL Feb 21, 2021
e4a71cb
FIX: Update recvae.py
Sherry-XLL Feb 26, 2021
349fa7b
FIX: update hyper parameter of RecVAE
Sherry-XLL Feb 21, 2021
4c950f1
FIX: resolve conflicts
Sherry-XLL Feb 27, 2021
b7e1ef7
FIX: resolve conflicts
Sherry-XLL Feb 27, 2021
108352e
FEA: ENMF.yaml
linzihan-backforward Jan 1, 2021
8525fa2
FIX: rename in ENMF
linzihan-backforward Jan 1, 2021
f6a707f
FIX: fix test_model_auto.py
Sherry-XLL Feb 27, 2021
ab3174c
FIX: auto_test enmf
linzihan-backforward Jan 1, 2021
3b1b7c8
FIX: auto_test
linzihan-backforward Jan 1, 2021
5c5652a
FIX: auto_test and trainer
Sherry-XLL Feb 27, 2021
4dda540
FIX: change self.group_by into self.group_field (self.group_by is a m…
chenyushuo Feb 19, 2021
b1fc6d3
FIX: fix runtime error when ContextFullDataLoader is empty.
chenyushuo Feb 19, 2021
a1943e0
add parameter check
Yibo-Li-1 Feb 16, 2021
40ebfd9
add parameter check
Guan-JW Feb 16, 2021
c3bc871
add parameter check
Guan-JW Feb 16, 2021
e77a28c
add parameter check
Guan-JW Feb 16, 2021
953f63f
add parameter check
Guan-JW Feb 16, 2021
d453c63
add parameter check
Guan-JW Feb 17, 2021
05daa85
FIX: revert the configurator
2017pxy Feb 18, 2021
b9124c4
FIX: set scipy version to 1.6.0
2017pxy Feb 19, 2021
3b2c066
FIX: remove print
2017pxy Feb 19, 2021
61a6df2
FEA: add test dataset
2017pxy Nov 17, 2020
8b4545d
FEA: rebuild test dataset and test workflow, and give up unnessary tests
2017pxy Nov 17, 2020
8f0bfff
FIX: resolve conflicts in test_model_auto.py
Sherry-XLL Feb 27, 2021
0ed4a7c
FIX: typos in doc for evaluator
guijiql Feb 21, 2021
5f738bb
FIX: code format
guijiql Feb 21, 2021
bda44fa
Update python-package.yml
guijiql Feb 21, 2021
fd48dba
FEA: Add docs/ to RecBole
chenyushuo Feb 25, 2021
f8e551a
FIX: resolve conflicts between EASE and RecVAE
Sherry-XLL Feb 27, 2021
8e14274
Reduce memory usage
deklanw Dec 25, 2020
e7d1a11
Use csr instead of coo to allow slicing
deklanw Dec 25, 2020
5606e0a
Remove extraneous function
deklanw Dec 25, 2020
b0f2037
add other parameters output
Guan-JW Feb 20, 2021
ea6c7d5
update readme.md
Yibo-Li-1 Jan 16, 2021
37fa682
update readme
Yibo-Li-1 Jan 16, 2021
7d33e48
update readme.md
Yibo-Li-1 Jan 16, 2021
1bb4d38
update readme.md
Yibo-Li-1 Jan 16, 2021
99797dd
FEA:Draw graphs of train loss
Yibo-Li-1 Feb 20, 2021
098e2c6
FEA:Draw graphs of train loss
Yibo-Li-1 Feb 20, 2021
33f0565
FEA:Draw graphs of train loss
Yibo-Li-1 Feb 20, 2021
9fe57f8
FEA:Draw graphs of train loss
Yibo-Li-1 Feb 20, 2021
e328221
FEA:Draw graphs of train loss
Yibo-Li-1 Feb 20, 2021
357d239
delete pdf
Yibo-Li-1 Feb 21, 2021
e049f6e
FEA: add new parameter draw_pic
2017pxy Feb 21, 2021
6573f5d
FIX: resolve conflicts between NNCF and RecVAE
Sherry-XLL Feb 27, 2021
40c3a0c
add nncf
cyLi-Tiger Feb 15, 2021
0b40431
add nncf
cyLi-Tiger Feb 15, 2021
f9947b5
FIX: resolve conflict in run_test_example.py
Sherry-XLL Feb 27, 2021
489c679
FIX: resolve conflict in test_model_auto.py
Sherry-XLL Feb 27, 2021
b3ff667
add packages nncf need
cyLi-Tiger Feb 15, 2021
35cd124
add nncf
cyLi-Tiger Feb 15, 2021
4ac55e8
FIX
cyLi-Tiger Feb 23, 2021
298c586
FIX: update nncf
cyLi-Tiger Feb 23, 2021
aea941a
FIX: update nncf
cyLi-Tiger Feb 23, 2021
3297e6f
FIX: update nncf
cyLi-Tiger Feb 23, 2021
7e6d85c
FIX: resolve conflicts in RecVAE
Sherry-XLL Feb 27, 2021
6a14ba4
Merge branch 'RUCAIBox-0.2.x'
Sherry-XLL Feb 27, 2021
2aa5796
Update MultiVAE.yaml
Sherry-XLL Feb 27, 2021
5dfab27
Update trainer.py
Sherry-XLL Feb 27, 2021
9d195f7
Update trainer.py
Sherry-XLL Feb 27, 2021
f40bef2
Update trainer.py
Sherry-XLL Feb 27, 2021
c7d63e8
Update trainer.py
Sherry-XLL Feb 27, 2021
a8bce8f
FIX: Update RecVAETrainer
Sherry-XLL Feb 27, 2021
8a57c0b
Merge branch 'master' of https://github.com/Sherry-XLL/RecBole
Sherry-XLL Feb 27, 2021
5ca9127
FIX: Update RecVAETrainer
Sherry-XLL Feb 27, 2021
7ab404c
FEA: Add doc of RecVAE
Sherry-XLL Feb 28, 2021
246d621
Merge branch 'master' of git://github.com/Sherry-XLL/RecBole
Sherry-XLL Feb 28, 2021
5472f85
FIX: Update RecVAE model
Mar 2, 2021
23a1ab1
FIX: Update trainer.py
Sherry-XLL Mar 3, 2021
89b214a
Merge branch '0.2.x' into master
2017pxy Mar 5, 2021
Binary file added docs/source/asset/recvae.png
79 changes: 79 additions & 0 deletions docs/source/user_guide/model/general/recvae.rst
@@ -0,0 +1,79 @@
RecVAE
===========

Introduction
---------------------

`[paper] <https://dl.acm.org/doi/10.1145/3336191.3371831>`_

**Title:** RecVAE: A New Variational Autoencoder for Top-N Recommendations with Implicit Feedback

**Authors:** Ilya Shenbin, Anton Alekseev, Elena Tutubalina, Valentin Malykh, Sergey I. Nikolenko

**Abstract:** Recent research has shown the advantages of using autoencoders based on deep neural networks for collaborative filtering. In particular, the recently proposed Mult-VAE model, which used the multinomial likelihood variational autoencoders, has shown excellent results for top-N recommendations. In this work, we propose the Recommender VAE (RecVAE) model that originates from our research on regularization techniques for variational autoencoders. RecVAE introduces several novel ideas to improve Mult-VAE, including a novel composite prior distribution for the latent codes, a new approach to setting the β hyperparameter for the β-VAE framework, and a new approach to training based on alternating updates. In experimental evaluation, we show that RecVAE significantly outperforms previously proposed autoencoder-based models, including Mult-VAE and RaCT, across classical collaborative filtering datasets, and present a detailed ablation study to assess our new developments. Code and models are available at https://github.com/ilya-shenbin/RecVAE.

.. image:: ../../../../../asset/recvae.png
    :width: 400
    :align: center

Running with RecBole
-------------------------

**Model Hyper-Parameters:**

- ``hidden_dimension (int)`` : The hidden dimension of the auto-encoder. Defaults to ``600``.
- ``latent_dimension (int)`` : The latent dimension of the auto-encoder. Defaults to ``200``.
- ``dropout_prob (float)`` : The dropout probability applied to the input. Defaults to ``0.5``.
- ``beta (float)`` : The constant weight of the KL loss, used only when ``gamma`` is ``0``. Defaults to ``0.2``.
- ``gamma (float)`` : The coefficient, shared across all users, that scales the KL loss by the sum of each user's interaction vector. Defaults to ``0.005``.
- ``mixture_weights (list)`` : The mixture weights of the three components of the composite prior. Defaults to ``[0.15, 0.75, 0.1]``.
- ``n_enc_epochs (int)`` : The number of encoder training passes per epoch. Defaults to ``3``.
- ``n_dec_epochs (int)`` : The number of decoder training passes per epoch. Defaults to ``1``.
- ``training_neg_sample (int)`` : The number of negative samples for training. Defaults to ``0``.
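
The interplay between ``beta`` and ``gamma`` can be illustrated with a minimal sketch in plain Python. It mirrors the branch in ``calculate_loss`` from this pull request, but the function name ``kl_weight`` and the toy interaction row are assumptions made for illustration and are not part of RecBole.

```python
def kl_weight(interaction_row, beta=0.2, gamma=0.005):
    """Return the KL weight for one user, mirroring the gamma/beta switch.

    interaction_row is a hypothetical 0/1 list over items for a single user.
    """
    if gamma:
        # Non-zero gamma: scale the KL term by the sum of the user's row,
        # which is the interaction count for binary implicit feedback.
        return gamma * sum(interaction_row)
    # gamma == 0 (falsy): fall back to the constant beta.
    return beta


row = [1, 0, 1, 1, 0, 0, 1, 0]   # a toy user with 4 interactions
print(kl_weight(row))            # gamma * 4
print(kl_weight(row, gamma=0))   # beta
```

With the default ``gamma: 0.005`` the ``beta`` value is therefore not used; setting ``gamma`` to ``0`` switches to the constant weight.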


**A Running Example:**

Write the following code to a Python file, such as ``run.py``:

.. code:: python

   from recbole.quick_start import run_recbole

   run_recbole(model='RecVAE', dataset='ml-100k')

And then:

.. code:: bash

   python run.py

**Note**: Because this model is a non-sampling model, you must set ``training_neg_sample=0`` when you run it.

Tuning Hyper Parameters
-------------------------

If you want to use ``HyperTuning`` to tune the hyper-parameters of this model, you can copy the following settings and save them as ``hyper.test``.

.. code:: bash

   learning_rate choice [0.01,0.005,0.001,0.0005,0.0001]
   latent_dimension choice [64,100,128,150,200,256,300,400,512]

Note that these hyper-parameter ranges are provided for reference only; we cannot guarantee that they are the optimal ranges for this model.

Then, with the source code of RecBole (you can download it from GitHub), you can run ``run_hyper.py`` to tune the model:

.. code:: bash

   python run_hyper.py --model=[model_name] --dataset=[dataset_name] --config_files=[config_files_path] --params_file=hyper.test

For more details about parameter tuning, refer to :doc:`../../../user_guide/usage/parameter_tuning`.
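
To get a feel for the size of the search space in ``hyper.test``, the two ``choice`` lists can be expanded into an explicit grid. The sketch below uses only the Python standard library; the helper ``enumerate_grid`` is hypothetical and is not how ``HyperTuning`` is implemented, it only shows how many configurations an exhaustive search over these ranges would visit.

```python
from itertools import product

# Search space copied from the hyper.test example.
search_space = {
    "learning_rate": [0.01, 0.005, 0.001, 0.0005, 0.0001],
    "latent_dimension": [64, 100, 128, 150, 200, 256, 300, 400, 512],
}


def enumerate_grid(space):
    """Yield one dict per combination of the listed choices."""
    names = list(space)
    for values in product(*(space[name] for name in names)):
        yield dict(zip(names, values))


grid = list(enumerate_grid(search_space))
print(len(grid))   # 5 learning rates x 9 latent dimensions
print(grid[0])     # first combination in iteration order
```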


If you want to change parameters, dataset or evaluation settings, take a look at

- :doc:`../../../user_guide/config_settings`
- :doc:`../../../user_guide/data_intro`
- :doc:`../../../user_guide/evaluation_support`
- :doc:`../../../user_guide/usage`
1 change: 1 addition & 0 deletions docs/source/user_guide/model_intro.rst
@@ -30,6 +30,7 @@ General Recommendation
model/general/cdae
model/general/enmf
model/general/nncf
model/general/recvae
model/general/ease
model/general/slimelastic

3 changes: 2 additions & 1 deletion recbole/data/utils.py
@@ -192,7 +192,8 @@ def get_data_loader(name, config, neg_sample_args):
         "MultiVAE": _get_AE_data_loader,
         'MacridVAE': _get_AE_data_loader,
         'CDAE': _get_AE_data_loader,
-        'ENMF': _get_AE_data_loader
+        'ENMF': _get_AE_data_loader,
+        'RecVAE': _get_AE_data_loader
     }

     if config['model'] in register_table:
7 changes: 5 additions & 2 deletions recbole/model/general_recommender/__init__.py
@@ -3,6 +3,8 @@
 from recbole.model.general_recommender.convncf import ConvNCF
 from recbole.model.general_recommender.dgcf import DGCF
 from recbole.model.general_recommender.dmf import DMF
+from recbole.model.general_recommender.ease import EASE
+from recbole.model.general_recommender.enmf import ENMF
 from recbole.model.general_recommender.fism import FISM
 from recbole.model.general_recommender.gcmc import GCMC
 from recbole.model.general_recommender.itemknn import ItemKNN
@@ -15,8 +17,9 @@
 from recbole.model.general_recommender.nais import NAIS
 from recbole.model.general_recommender.neumf import NeuMF
 from recbole.model.general_recommender.ngcf import NGCF
+from recbole.model.general_recommender.nncf import NNCF
 from recbole.model.general_recommender.pop import Pop
+from recbole.model.general_recommender.recvae import RecVAE
 from recbole.model.general_recommender.slimelastic import SLIMElastic
 from recbole.model.general_recommender.spectralcf import SpectralCF
-from recbole.model.general_recommender.ease import EASE
-from recbole.model.general_recommender.nncf import NNCF

204 changes: 204 additions & 0 deletions recbole/model/general_recommender/recvae.py
@@ -0,0 +1,204 @@
# -*- coding: utf-8 -*-
# @Time : 2021/2/28
# @Author : Lanling Xu
# @Email : [email protected]

r"""
RecVAE
################################################
Reference:
Shenbin, Ilya, et al. "RecVAE: A new variational autoencoder for Top-N recommendations with implicit feedback." Proceedings of the 13th International Conference on Web Search and Data Mining. 2020.

Reference code:
https://github.com/ilya-shenbin/RecVAE
"""

import numpy as np
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

from recbole.model.abstract_recommender import GeneralRecommender
from recbole.model.init import xavier_normal_initialization
from recbole.utils import InputType


def swish(x):
    r"""Swish activation function:

    .. math::
        \text{Swish}(x) = \frac{x}{1 + \exp(-x)}
    """
    return x.mul(torch.sigmoid(x))


def log_norm_pdf(x, mu, logvar):
    return -0.5 * (logvar + np.log(2 * np.pi) + (x - mu).pow(2) / logvar.exp())


class CompositePrior(nn.Module):
    def __init__(self, hidden_dim, latent_dim, input_dim, mixture_weights):
        super(CompositePrior, self).__init__()

        self.mixture_weights = mixture_weights

        self.mu_prior = nn.Parameter(torch.Tensor(1, latent_dim), requires_grad=False)
        self.mu_prior.data.fill_(0)

        self.logvar_prior = nn.Parameter(torch.Tensor(1, latent_dim), requires_grad=False)
        self.logvar_prior.data.fill_(0)

        self.logvar_uniform_prior = nn.Parameter(torch.Tensor(1, latent_dim), requires_grad=False)
        self.logvar_uniform_prior.data.fill_(10)

        self.encoder_old = Encoder(hidden_dim, latent_dim, input_dim)
        self.encoder_old.requires_grad_(False)

    def forward(self, x, z):
        post_mu, post_logvar = self.encoder_old(x, 0)

        stnd_prior = log_norm_pdf(z, self.mu_prior, self.logvar_prior)
        post_prior = log_norm_pdf(z, post_mu, post_logvar)
        unif_prior = log_norm_pdf(z, self.mu_prior, self.logvar_uniform_prior)

        gaussians = [stnd_prior, post_prior, unif_prior]
        gaussians = [g.add(np.log(w)) for g, w in zip(gaussians, self.mixture_weights)]

        density_per_gaussian = torch.stack(gaussians, dim=-1)

        return torch.logsumexp(density_per_gaussian, dim=-1)


class Encoder(nn.Module):
    def __init__(self, hidden_dim, latent_dim, input_dim, eps=1e-1):
        super(Encoder, self).__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.ln1 = nn.LayerNorm(hidden_dim, eps=eps)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim, eps=eps)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.ln3 = nn.LayerNorm(hidden_dim, eps=eps)
        self.fc4 = nn.Linear(hidden_dim, hidden_dim)
        self.ln4 = nn.LayerNorm(hidden_dim, eps=eps)
        self.fc5 = nn.Linear(hidden_dim, hidden_dim)
        self.ln5 = nn.LayerNorm(hidden_dim, eps=eps)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x, dropout_prob):
        x = F.normalize(x)
        x = F.dropout(x, dropout_prob, training=self.training)

        h1 = self.ln1(swish(self.fc1(x)))
        h2 = self.ln2(swish(self.fc2(h1) + h1))
        h3 = self.ln3(swish(self.fc3(h2) + h1 + h2))
        h4 = self.ln4(swish(self.fc4(h3) + h1 + h2 + h3))
        h5 = self.ln5(swish(self.fc5(h4) + h1 + h2 + h3 + h4))
        return self.fc_mu(h5), self.fc_logvar(h5)


class RecVAE(GeneralRecommender):
    r"""RecVAE is a variational autoencoder model for top-N recommendation with implicit feedback.

    We implement the model following the original author
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(RecVAE, self).__init__(config, dataset)

        self.hidden_dim = config["hidden_dimension"]
        self.latent_dim = config['latent_dimension']
        self.dropout_prob = config['dropout_prob']
        self.beta = config['beta']
        self.mixture_weights = config['mixture_weights']
        self.gamma = config['gamma']

        self.history_item_id, self.history_item_value, _ = dataset.history_item_matrix()
        self.history_item_id = self.history_item_id.to(self.device)
        self.history_item_value = self.history_item_value.to(self.device)

        self.encoder = Encoder(self.hidden_dim, self.latent_dim, self.n_items)
        self.prior = CompositePrior(self.hidden_dim, self.latent_dim, self.n_items, self.mixture_weights)
        self.decoder = nn.Linear(self.latent_dim, self.n_items)

        # parameters initialization
        self.apply(xavier_normal_initialization)

    def get_rating_matrix(self, user):
        r"""Get a batch of user's feature with the user's id and history interaction matrix.

        Args:
            user (torch.LongTensor): The input tensor that contains user's id, shape: [batch_size, ]

        Returns:
            torch.FloatTensor: The user's feature of a batch of user, shape: [batch_size, n_items]
        """
        # Following lines construct tensor of shape [B,n_items] using the tensor of shape [B,H]
        col_indices = self.history_item_id[user].flatten()
        row_indices = torch.arange(user.shape[0]).to(self.device) \
            .repeat_interleave(self.history_item_id.shape[1], dim=0)
        rating_matrix = torch.zeros(1).to(self.device).repeat(user.shape[0], self.n_items)
        rating_matrix.index_put_((row_indices, col_indices), self.history_item_value[user].flatten())
        return rating_matrix

    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)
            epsilon = torch.zeros_like(std).normal_(mean=0, std=0.01)
            return mu + epsilon * std
        else:
            return mu

    def forward(self, rating_matrix, dropout_prob):
        mu, logvar = self.encoder(rating_matrix, dropout_prob=dropout_prob)
        z = self.reparameterize(mu, logvar)
        x_pred = self.decoder(z)
        return x_pred, mu, logvar, z

    def calculate_loss(self, interaction, encoder_flag):
        user = interaction[self.USER_ID]
        rating_matrix = self.get_rating_matrix(user)
        if encoder_flag:
            dropout_prob = self.dropout_prob
        else:
            dropout_prob = 0
        x_pred, mu, logvar, z = self.forward(rating_matrix, dropout_prob)

        if self.gamma:
            norm = rating_matrix.sum(dim=-1)
            kl_weight = self.gamma * norm
        else:
            kl_weight = self.beta

        mll = (F.log_softmax(x_pred, dim=-1) * rating_matrix).sum(dim=-1).mean()
        kld = (log_norm_pdf(z, mu, logvar) - self.prior(rating_matrix, z)).sum(dim=-1).mul(kl_weight).mean()
        negative_elbo = -(mll - kld)

        return negative_elbo

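    def predict(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        rating_matrix = self.get_rating_matrix(user)

        scores, _, _, _ = self.forward(rating_matrix, self.dropout_prob)

        # scores has shape [batch_size, n_items], so rows are indexed by batch position, not by user id
        return scores[[torch.arange(len(item)).to(self.device), item]]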

    def full_sort_predict(self, interaction):
        user = interaction[self.USER_ID]

        rating_matrix = self.get_rating_matrix(user)

        scores, _, _, _ = self.forward(rating_matrix, self.dropout_prob)

        return scores.view(-1)

    def update_prior(self):
        self.prior.encoder_old.load_state_dict(deepcopy(self.encoder.state_dict()))
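
The mixture computed by `CompositePrior.forward` can be checked by hand in one dimension. The following is a minimal sketch in plain Python (no PyTorch), assuming scalar inputs; the names `logsumexp` and `composite_log_prior` and the sample values for `z`, `post_mu`, and `post_logvar` are made up for illustration. It combines a standard normal, a Gaussian from the old encoder's posterior, and a wide Gaussian with log-variance 10, weighted by `mixture_weights`, in log space.

```python
import math


def log_norm_pdf(x, mu, logvar):
    """Log-density of a 1-D Gaussian N(mu, exp(logvar)) evaluated at x."""
    return -0.5 * (logvar + math.log(2 * math.pi) + (x - mu) ** 2 / math.exp(logvar))


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) for a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def composite_log_prior(z, post_mu, post_logvar, weights=(0.15, 0.75, 0.1)):
    """Log-density of the three-component mixture prior at a scalar z."""
    components = [
        log_norm_pdf(z, 0.0, 0.0),              # standard normal
        log_norm_pdf(z, post_mu, post_logvar),  # posterior from the old encoder
        log_norm_pdf(z, 0.0, 10.0),             # wide Gaussian, logvar = 10
    ]
    # Adding log(w) in log space is the same as multiplying each density by w.
    return logsumexp([c + math.log(w) for c, w in zip(components, weights)])


value = composite_log_prior(z=0.3, post_mu=0.25, post_logvar=-1.0)
print(value)
```

Working in log space with a log-sum-exp avoids underflow when one component assigns a very small density to `z`, which is presumably why the model stacks the weighted log-densities and calls `torch.logsumexp`.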
2 changes: 1 addition & 1 deletion recbole/properties/model/MultiVAE.yaml
@@ -2,4 +2,4 @@ mlp_hidden_size: [600]
 latent_dimension: 128
 dropout_prob: 0.5
 anneal_cap: 0.2
-total_anneal_steps: 200000
+total_anneal_steps: 200000
8 changes: 8 additions & 0 deletions recbole/properties/model/RecVAE.yaml
@@ -0,0 +1,8 @@
hidden_dimension: 600
latent_dimension: 200
dropout_prob: 0.5
beta: 0.2
mixture_weights: [0.15, 0.75, 0.1]
gamma: 0.005
n_enc_epochs: 3
n_dec_epochs: 1
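
The `n_enc_epochs` and `n_dec_epochs` settings control the alternating updates described in the paper's abstract. The trainer changes are not shown in this diff, so the following is only a sketch of one plausible schedule: it assumes the encoder passes run first, then `update_prior()` copies the encoder weights into the prior, then the decoder passes run. The generator `alternating_schedule` is hypothetical and only lists the order of steps.

```python
def alternating_schedule(n_epochs, n_enc_epochs=3, n_dec_epochs=1):
    """Yield (epoch, step) labels for an assumed alternating training loop."""
    for epoch in range(n_epochs):
        for _ in range(n_enc_epochs):
            # Encoder pass: calculate_loss(interaction, encoder_flag=True)
            yield epoch, "encoder"
        # Refresh the prior with a copy of the current encoder weights.
        yield epoch, "update_prior"
        for _ in range(n_dec_epochs):
            # Decoder pass: calculate_loss(interaction, encoder_flag=False)
            yield epoch, "decoder"


steps = list(alternating_schedule(n_epochs=2))
for epoch, step in steps:
    print(epoch, step)
```

With the defaults above, each epoch consists of three encoder passes, one prior update, and one decoder pass.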