FEA: add RaCT model #732

Closed · wants to merge 6 commits
9 changes: 5 additions & 4 deletions recbole/data/utils.py
@@ -3,9 +3,9 @@
# @Email : [email protected]

# UPDATE:
# @Time : 2020/10/19, 2020/9/17, 2020/8/31
# @Author : Yupeng Hou, Yushuo Chen, Kaiyuan Li
# @Email : [email protected], [email protected], [email protected]
# @Time : 2020/10/19, 2020/9/17, 2020/8/31, 2021/2/20
# @Author : Yupeng Hou, Yushuo Chen, Kaiyuan Li, Haoran Cheng
# @Email : [email protected], [email protected], [email protected], [email protected]

"""
recbole.data.utils
@@ -237,7 +237,8 @@ def get_data_loader(name, config, eval_setting):
"MultiVAE": _get_AE_data_loader,
'MacridVAE': _get_AE_data_loader,
'CDAE': _get_AE_data_loader,
'ENMF': _get_AE_data_loader
'ENMF': _get_AE_data_loader,
'RaCT': _get_AE_data_loader
}

if config['model'] in register_table:
251 changes: 251 additions & 0 deletions recbole/model/general_recommender/ract.py
@@ -0,0 +1,251 @@
# -*- coding: utf-8 -*-
# @Time : 2021/2/16
# @Author : Haoran Cheng
# @Email : [email protected]

r"""
RaCT
################################################
Reference:
Sam Lobel et al. "RaCT: Towards Amortized Ranking-Critical Training for Collaborative Filtering." in ICLR 2020.

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from recbole.model.abstract_recommender import GeneralRecommender
from recbole.model.init import xavier_normal_initialization
from recbole.utils import InputType


class RaCT(GeneralRecommender):
r"""RaCT is a item-based model collaborative filtering model that simultaneously rank all items for user .

We implement the RaCT model with only user dataloader.
"""
input_type = InputType.PAIRWISE

def __init__(self, config, dataset):
super(RaCT, self).__init__(config, dataset)

self.layers = config["mlp_hidden_size"]
self.lat_dim = config['latent_dimension']
self.drop_out = config['dropout_prob']
self.anneal_cap = config['anneal_cap']
self.total_anneal_steps = config["total_anneal_steps"]

self.history_item_id, self.history_item_value, _ = dataset.history_item_matrix()
self.history_item_id = self.history_item_id.to(self.device)
self.history_item_value = self.history_item_value.to(self.device)

self.update = 0

self.encode_layer_dims = [self.n_items] + self.layers + [self.lat_dim]
self.decode_layer_dims = [int(self.lat_dim / 2)] + self.encode_layer_dims[::-1][1:]

self.encoder = self.mlp_layers(self.encode_layer_dims)
self.decoder = self.mlp_layers(self.decode_layer_dims)

self.critic_layers_1 = config["critic_layers_1"]
self.critic_layers_2 = config["critic_layers_2"]
self.critic_layers_3 = config["critic_layers_3"]
self.metrics_k = config["metrics_k"]
self.number_of_seen_items = 0
self.number_of_unseen_items = 0

self.input_matrix = None
self.predict_matrix = None
self.true_matrix = None
self.critic_net = self.construct_critic_layers()

self.train_stage = config['train_stage']
self.pre_model_path = config['pre_model_path']

# parameters initialization
assert self.train_stage in ['actor_pretrain', 'critic_pretrain', 'finetune']
if self.train_stage == 'actor_pretrain':
self.apply(xavier_normal_initialization)
for p in self.critic_net.parameters():
p.requires_grad = False
elif self.train_stage == 'critic_pretrain':
# load the pre-trained actor for critic pre-training
pretrained = torch.load(self.pre_model_path)
self.logger.info('Load pretrained model from: ' + self.pre_model_path)
self.load_state_dict(pretrained['state_dict'])
for p in self.encoder.parameters():
p.requires_grad = False
for p in self.decoder.parameters():
p.requires_grad = False
else:
# load pretrained model for finetune
pretrained = torch.load(self.pre_model_path)
self.logger.info('Load pretrained model from: ' + self.pre_model_path)
self.load_state_dict(pretrained['state_dict'])
for p in self.critic_net.parameters():
p.requires_grad = False

def get_rating_matrix(self, user):
r"""Get a batch of user's feature with the user's id and history interaction matrix.

Args:
user (torch.LongTensor): The input tensor that contains user's id, shape: [batch_size, ]

Returns:
torch.FloatTensor: The user's feature of a batch of user, shape: [batch_size, n_items]
"""
# Following lines construct tensor of shape [B,n_items] using the tensor of shape [B,H]
col_indices = self.history_item_id[user].flatten()
row_indices = torch.arange(user.shape[0]).to(self.device) \
.repeat_interleave(self.history_item_id.shape[1], dim=0)
rating_matrix = torch.zeros(1).to(self.device).repeat(user.shape[0], self.n_items)
rating_matrix.index_put_((row_indices, col_indices), self.history_item_value[user].flatten())
return rating_matrix
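
# Illustrative example (reviewer note, not part of the PR): for a hypothetical batch of two users with
# padded histories [[1, 3, 0], [2, 0, 0]] and values [[1., 1., 0.], [1., 0., 0.]], index_put_ scatters
# the values into a zero [2, n_items] matrix, giving rating_matrix[0, 1] = rating_matrix[0, 3]
# = rating_matrix[1, 2] = 1.; padding entries only write 0. into column 0, leaving it zero.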

def mlp_layers(self, layer_dims):
mlp_modules = []
for i, (d_in, d_out) in enumerate(zip(layer_dims[:-1], layer_dims[1:])):
mlp_modules.append(nn.Linear(d_in, d_out))
if i != len(layer_dims[:-1]) - 1:
mlp_modules.append(nn.Tanh())
return nn.Sequential(*mlp_modules)
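
# Illustrative example (reviewer note, not part of the PR): with the default RaCT.yaml below
# (mlp_hidden_size=[600], latent_dimension=256) and, say, n_items=1000, encode_layer_dims is
# [1000, 600, 256] and decode_layer_dims is [128, 600, 1000]; the decoder starts from
# lat_dim / 2 = 128 because the encoder output is later split in half into mu and logvar.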

def reparameterize(self, mu, logvar):
if self.training:
std = torch.exp(0.5 * logvar)
epsilon = torch.zeros_like(std).normal_(mean=0, std=0.01)
return mu + epsilon * std
else:
return mu

def forward(self, rating_matrix):

t = F.normalize(rating_matrix)

h = F.dropout(t, self.drop_out, training=self.training) * (1 - self.drop_out)
self.input_matrix = h
self.number_of_seen_items = (h != 0).sum(dim=1) # network input

mask = (h > 0) * (t > 0)
self.true_matrix = t * ~mask
self.number_of_unseen_items = (self.true_matrix != 0).sum(dim=1) # remaining input

h = self.encoder(h)

mu = h[:, :int(self.lat_dim / 2)]
logvar = h[:, int(self.lat_dim / 2):]

z = self.reparameterize(mu, logvar)
z = self.decoder(z)
self.predict_matrix = z
return z, mu, logvar

def calculate_actor_loss(self, interaction):

user = interaction[self.USER_ID]
rating_matrix = self.get_rating_matrix(user)

self.update += 1
if self.total_anneal_steps > 0:
anneal = min(self.anneal_cap, 1. * self.update / self.total_anneal_steps)
else:
anneal = self.anneal_cap

z, mu, logvar = self.forward(rating_matrix)

# KL loss
kl_loss = -0.5 * (torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)) * anneal

# CE loss
ce_loss = -(F.log_softmax(z, 1) * rating_matrix).sum(1)

return ce_loss + kl_loss
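
# Illustrative note (reviewer note, not part of the PR): per user u this is the negated, KL-annealed
# ELBO used by MultiVAE-style models,
#     L_u = -sum_i x_ui * log softmax(z_u)_i + anneal * KL(q(z | x_u) || N(0, I)),
# where anneal is warmed up linearly from 0 towards anneal_cap over total_anneal_steps updates.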

def construct_critic_input(self, actor_loss):
critic_inputs = []
critic_inputs.append(self.number_of_seen_items)
critic_inputs.append(self.number_of_unseen_items)
critic_inputs.append(actor_loss)
return torch.stack(critic_inputs, dim=1)

def construct_critic_layers(self):
mlp_modules = []
mlp_modules.append(nn.BatchNorm1d(3))
mlp_modules.append(nn.Linear(3, self.critic_layers_1))
mlp_modules.append(nn.ReLU())
mlp_modules.append(nn.Linear(self.critic_layers_1, self.critic_layers_2))
mlp_modules.append(nn.ReLU())
mlp_modules.append(nn.Linear(self.critic_layers_2, self.critic_layers_3))
mlp_modules.append(nn.ReLU())
mlp_modules.append(nn.Linear(self.critic_layers_3, 1))
mlp_modules.append(nn.Sigmoid())
return nn.Sequential(*mlp_modules)

def calculate_ndcg(self, predict_matrix, true_matrix, input_matrix, k):
users_num = predict_matrix.shape[0]
predict_matrix[input_matrix.nonzero(as_tuple=True)] = -np.inf
_, idx_sorted = torch.sort(predict_matrix, dim=1, descending=True)

topk_result = true_matrix[np.arange(users_num)[:, np.newaxis], idx_sorted[:, :k]]

number_non_zero = ((true_matrix > 0) * 1).sum(dim=1)

tp = 1. / torch.log2(torch.arange(2, k + 2).type(torch.FloatTensor)).to(topk_result.device)
DCG = (topk_result * tp).sum(dim=1)
IDCG = torch.Tensor([(tp[:min(n, k)]).sum() for n in number_non_zero]).to(topk_result.device)
IDCG = torch.maximum(0.1 * torch.ones_like(IDCG).to(IDCG.device), IDCG)

return DCG / IDCG
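
# Illustrative example (reviewer note, not part of the PR), assuming implicit 0/1 feedback: seen items
# are masked to -inf so they cannot be re-recommended, and DCG@k is computed over the held-out
# (true_matrix) items only. With k=3 and a user whose two held-out items land at ranks 1 and 3,
# DCG = 1/log2(2) + 1/log2(4) = 1.5 and IDCG = 1/log2(2) + 1/log2(3) ~= 1.63, so NDCG ~= 0.92;
# the max(0.1, IDCG) guard only protects users without held-out items from division by zero.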

def critic_forward(self, actor_loss):
h = self.construct_critic_input(actor_loss)
y = self.critic_net(h)
y = torch.squeeze(y)
return y

def calculate_critic_loss(self, interaction):
actor_loss = self.calculate_actor_loss(interaction)
y = self.critic_forward(actor_loss)
score = self.calculate_ndcg(self.predict_matrix, self.true_matrix, self.input_matrix, self.metrics_k)

mse_loss = (y - score) ** 2
return mse_loss

def calculate_ac_loss(self, interaction):
actor_loss = self.calculate_actor_loss(interaction)
y = self.critic_forward(actor_loss)
return -1 * y

def calculate_loss(self, interaction):

# actor_pretrain
if self.train_stage == 'actor_pretrain':
return self.calculate_actor_loss(interaction).mean()
# critic_pretrain
elif self.train_stage == 'critic_pretrain':
return self.calculate_critic_loss(interaction).mean()
# finetune
else:
return self.calculate_ac_loss(interaction).mean()

def predict(self, interaction):

user = interaction[self.USER_ID]
item = interaction[self.ITEM_ID]

rating_matrix = self.get_rating_matrix(user)

scores, _, _ = self.forward(rating_matrix)

return scores[[torch.arange(len(item)).to(self.device), item]]

def full_sort_predict(self, interaction):
user = interaction[self.USER_ID]

rating_matrix = self.get_rating_matrix(user)

scores, _, _ = self.forward(rating_matrix)

return scores.view(-1)
12 changes: 12 additions & 0 deletions recbole/properties/model/RaCT.yaml
@@ -0,0 +1,12 @@
mlp_hidden_size: [600]
latent_dimension: 256
dropout_prob: 0.5
anneal_cap: 0.2
total_anneal_steps: 200000
critic_layers_1: 100
critic_layers_2: 100
critic_layers_3: 10
metrics_k: 100
train_stage: 'finetune'
pretrain_epochs: 50
pre_model_path: './saved/RaCT-lastfm-50.pth'
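
A minimal usage sketch (added here for illustration, not part of the PR): the three training stages are meant to be run in sequence, with pre_model_path pointing at the checkpoint saved by the previous stage; under the default config above, every stage writes to and reads from the same ./saved/RaCT-lastfm-50.pth file. The dataset name and paths below are placeholders.

from recbole.quick_start import run_recbole

# 1. pre-train the actor (RaCTTrainer.pretrain saves ./saved/RaCT-lastfm-50.pth after pretrain_epochs)
run_recbole(model='RaCT', dataset='lastfm', config_dict={'train_stage': 'actor_pretrain'})

# 2. pre-train the critic with the actor frozen, starting from the actor checkpoint
run_recbole(model='RaCT', dataset='lastfm',
            config_dict={'train_stage': 'critic_pretrain', 'pre_model_path': './saved/RaCT-lastfm-50.pth'})

# 3. fine-tune the actor against the frozen critic, starting from the critic-pretrain checkpoint
run_recbole(model='RaCT', dataset='lastfm',
            config_dict={'train_stage': 'finetune', 'pre_model_path': './saved/RaCT-lastfm-50.pth'})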
91 changes: 88 additions & 3 deletions recbole/trainer/trainer.py
@@ -3,9 +3,9 @@
# @Email : [email protected]

# UPDATE:
# @Time : 2020/8/7, 2020/9/26, 2020/9/26, 2020/10/01, 2020/9/16, 2020/10/8, 2020/10/15, 2020/11/20, 2021/2/20
# @Author : Zihan Lin, Yupeng Hou, Yushuo Chen, Shanlei Mu, Xingyu Pan, Hui Wang, Xinyan Fan, Chen Yang, Yibo Li
# @Email : [email protected], [email protected], [email protected], [email protected], [email protected], [email protected], [email protected], [email protected], [email protected]
# @Time : 2020/8/7, 2020/9/26, 2020/9/26, 2020/10/01, 2020/9/16, 2020/10/8, 2020/10/15, 2020/11/20, 2021/2/16, 2021/2/20
# @Author : Zihan Lin, Yupeng Hou, Yushuo Chen, Shanlei Mu, Xingyu Pan, Hui Wang, Xinyan Fan, Chen Yang, Haoran Cheng, Yibo Li
# @Email : [email protected], [email protected], [email protected], [email protected], [email protected], [email protected], [email protected], [email protected], [email protected], [email protected]

r"""
recbole.trainer.trainer
@@ -782,6 +782,91 @@ def evaluate(self, eval_data, load_best_model=True, model_file=None, show_progre
result = self.evaluator.evaluate(batch_matrix_list, eval_data)
return result

class RaCTTrainer(Trainer):
r"""RaCTTrainer is designed for RaCT, which is an actor-critic reinforcement learning based general recommenders.
It includes three training stages: actor pre-training, critic pre-training and actor-critic training.

"""

def __init__(self, config, model):
super(RaCTTrainer, self).__init__(config, model)
self.pretrain_epochs = self.config['pretrain_epochs']

def _build_optimizer(self):
r"""Init the Optimizer

Returns:
torch.optim: the optimizer
"""
params = filter(lambda p: p.requires_grad, self.model.parameters())

if self.learner.lower() == 'adam':
optimizer = optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay)
elif self.learner.lower() == 'sgd':
optimizer = optim.SGD(params, lr=self.learning_rate, weight_decay=self.weight_decay)
elif self.learner.lower() == 'adagrad':
optimizer = optim.Adagrad(params, lr=self.learning_rate, weight_decay=self.weight_decay)
elif self.learner.lower() == 'rmsprop':
optimizer = optim.RMSprop(params, lr=self.learning_rate, weight_decay=self.weight_decay)
elif self.learner.lower() == 'sparse_adam':
optimizer = optim.SparseAdam(params, lr=self.learning_rate)
if self.weight_decay > 0:
self.logger.warning('Sparse Adam does not support the argument [weight_decay], so it will be ignored')
else:
self.logger.warning('Received unrecognized optimizer, set default Adam optimizer')
optimizer = optim.Adam(params, lr=self.learning_rate)
return optimizer

def save_pretrained_model(self, epoch, saved_model_file):
r"""Store the model parameters information and training information.

Args:
epoch (int): the current epoch id
saved_model_file (str): file name for saved pretrained model

"""
state = {
'config': self.config,
'epoch': epoch,
'state_dict': self.model.state_dict(),
'optimizer': self.optimizer.state_dict(),
}
torch.save(state, saved_model_file)

def pretrain(self, train_data, verbose=True, show_progress=False):

for epoch_idx in range(self.start_epoch, self.pretrain_epochs):
# train
training_start_time = time()
train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress)
self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(train_loss, tuple) else train_loss
training_end_time = time()
train_loss_output = \
self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss)
if verbose:
self.logger.info(train_loss_output)

if (epoch_idx + 1) % self.pretrain_epochs == 0:
saved_model_file = os.path.join(
self.checkpoint_dir,
'{}-{}-{}.pth'.format(self.config['model'], self.config['dataset'], str(epoch_idx + 1))
)
self.save_pretrained_model(epoch_idx, saved_model_file)
update_output = 'Saving current: %s' % saved_model_file
if verbose:
self.logger.info(update_output)

return self.best_valid_score, self.best_valid_result

def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None):
if self.model.train_stage == 'actor_pretrain':
return self.pretrain(train_data, verbose, show_progress)
elif self.model.train_stage == "critic_pretrain":
return self.pretrain(train_data, verbose, show_progress)
elif self.model.train_stage == 'finetune':
return super().fit(train_data, valid_data, verbose, saved, show_progress, callback_fn)
else:
raise ValueError("Please make sure that the 'train_stage' is 'actor_pretrain', 'critic_pretrain' or 'finetune'")

class lightgbmTrainer(DecisionTreeTrainer):
"""lightgbmTrainer is designed for lightgbm.