From b6cb9be7dead27077fe3e567b8df14dd69ed3600 Mon Sep 17 00:00:00 2001
From: guoyihonggyh <793048519@qq.com>
Date: Fri, 15 Jan 2021 16:16:38 +0800
Subject: [PATCH 1/5] Add a blank line after the "韩愈《马说》" quote in README_CN.md
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 README_CN.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README_CN.md b/README_CN.md
index 024656bdb..778186a6c 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -6,6 +6,7 @@
 
 *“世有伯乐,然后有千里马。千里马常有,而伯乐不常有。”——韩愈《马说》*
 
+
 [![PyPi Latest Release](https://img.shields.io/pypi/v/recbole)](https://pypi.org/project/recbole/)
 [![Conda Latest Release](https://anaconda.org/aibox/recbole/badges/version.svg)](https://anaconda.org/aibox/recbole)
 [![License](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE)

From 67c94933802462f6a43d5ffcf13eae65a8847409 Mon Sep 17 00:00:00 2001
From: Yibo-Li-1 <3289438186@qq.com>
Date: Fri, 19 Feb 2021 17:37:34 +0800
Subject: [PATCH 2/5] implement RaCT model

---
 recbole/model/general_recommender/ract.py | 251 ++++++++++++++++++++++
 recbole/properties/model/RaCT.yaml         |  12 ++
 recbole/trainer/trainer.py                 |  92 +++++++-
 test.yaml                                  |  29 +++
 4 files changed, 381 insertions(+), 3 deletions(-)
 create mode 100644 recbole/model/general_recommender/ract.py
 create mode 100644 recbole/properties/model/RaCT.yaml
 create mode 100644 test.yaml

diff --git a/recbole/model/general_recommender/ract.py b/recbole/model/general_recommender/ract.py
new file mode 100644
index 000000000..92b2587c8
--- /dev/null
+++ b/recbole/model/general_recommender/ract.py
@@ -0,0 +1,251 @@
+# -*- coding: utf-8 -*-
+# @Time : 2021/2/16
+# @Author : Haoran Cheng
+# @Email : 1871530482@qq.com
+
+r"""
+RaCT
+################################################
+Reference:
+
+
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from recbole.model.abstract_recommender import GeneralRecommender
+from recbole.model.init import xavier_normal_initialization
+from recbole.utils import InputType
+
+
+class RaCT(GeneralRecommender):
+    r"""RaCT is an item-based collaborative filtering model that simultaneously ranks all items for each user.
+
+    We implement the RaCT model with only a user dataloader.
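+
+    The training process is split into three stages, selected by ``train_stage``:
+    actor pre-training, critic pre-training, and actor-critic fine-tuning.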
+ """ + input_type = InputType.PAIRWISE + + def __init__(self, config, dataset): + super(RaCT, self).__init__(config, dataset) + + self.layers = config["mlp_hidden_size"] + self.lat_dim = config['latent_dimension'] + self.drop_out = config['dropout_prob'] + self.anneal_cap = config['anneal_cap'] + self.total_anneal_steps = config["total_anneal_steps"] + + self.history_item_id, self.history_item_value, _ = dataset.history_item_matrix() + self.history_item_id = self.history_item_id.to(self.device) + self.history_item_value = self.history_item_value.to(self.device) + + self.update = 0 + + self.encode_layer_dims = [self.n_items] + self.layers + [self.lat_dim] + self.decode_layer_dims = [int(self.lat_dim / 2)] + self.encode_layer_dims[::-1][1:] + + self.encoder = self.mlp_layers(self.encode_layer_dims) + self.decoder = self.mlp_layers(self.decode_layer_dims) + + self.critic_layers_1 = config["critic_layers_1"] + self.critic_layers_2 = config["critic_layers_2"] + self.critic_layers_3 = config["critic_layers_3"] + self.metrics_k = config["metrics_k"] + self.number_of_seen_items = 0 + self.number_of_unseen_items = 0 + + self.input_matrix = None + self.predict_matrix = None + self.true_matrix = None + self.critic_net = self.construct_critic_layers() + + self.train_stage = config['train_stage'] + self.pre_model_path = config['pre_model_path'] + + # parameters initialization + assert self.train_stage in ['actor_pretrain', 'critic_pretrain', 'finetune'] + if self.train_stage == 'actor_pretrain': + self.apply(xavier_normal_initialization) + for p in self.critic_net.parameters(): + p.requires_grad = False + elif self.train_stage == 'critic_pretrain': + # load pretrained model for finetune + pretrained = torch.load(self.pre_model_path) + self.logger.info('Load pretrained model from', self.pre_model_path) + self.load_state_dict(pretrained['state_dict']) + for p in self.encoder.parameters(): + p.requires_grad = False + for p in self.decoder.parameters(): + p.requires_grad = False + else: + # load pretrained model for finetune + pretrained = torch.load(self.pre_model_path) + self.logger.info('Load pretrained model from', self.pre_model_path) + self.load_state_dict(pretrained['state_dict']) + for p in self.critic_net.parameters(): + p.requires_grad = False + + def get_rating_matrix(self, user): + r"""Get a batch of user's feature with the user's id and history interaction matrix. 
+
+        Args:
+            user (torch.LongTensor): The input tensor that contains users' ids, shape: [batch_size, ]
+
+        Returns:
+            torch.FloatTensor: The features of a batch of users, shape: [batch_size, n_items]
+        """
+        # Following lines construct tensor of shape [B,n_items] using the tensor of shape [B,H]
+        col_indices = self.history_item_id[user].flatten()
+        row_indices = torch.arange(user.shape[0]).to(self.device) \
+            .repeat_interleave(self.history_item_id.shape[1], dim=0)
+        rating_matrix = torch.zeros(1).to(self.device).repeat(user.shape[0], self.n_items)
+        rating_matrix.index_put_((row_indices, col_indices), self.history_item_value[user].flatten())
+        return rating_matrix
+
+    def mlp_layers(self, layer_dims):
+        mlp_modules = []
+        for i, (d_in, d_out) in enumerate(zip(layer_dims[:-1], layer_dims[1:])):
+            mlp_modules.append(nn.Linear(d_in, d_out))
+            if i != len(layer_dims[:-1]) - 1:
+                mlp_modules.append(nn.Tanh())
+        return nn.Sequential(*mlp_modules)
+
+    def reparameterize(self, mu, logvar):
+        if self.training:
+            std = torch.exp(0.5 * logvar)
+            epsilon = torch.zeros_like(std).normal_(mean=0, std=0.01)
+            return mu + epsilon * std
+        else:
+            return mu
+
+    def forward(self, rating_matrix):
+
+        t = F.normalize(rating_matrix)
+
+        h = F.dropout(t, self.drop_out, training=self.training) * (1 - self.drop_out)
+        self.input_matrix = h
+        self.number_of_seen_items = (h != 0).sum(dim=1)  # network input
+
+        mask = (h > 0) * (t > 0)
+        self.true_matrix = t * ~mask
+        self.number_of_unseen_items = (self.true_matrix != 0).sum(dim=1)  # remaining input
+
+        h = self.encoder(h)
+
+        mu = h[:, :int(self.lat_dim / 2)]
+        logvar = h[:, int(self.lat_dim / 2):]
+
+        z = self.reparameterize(mu, logvar)
+        z = self.decoder(z)
+        self.predict_matrix = z
+        return z, mu, logvar
+
+    def calculate_actor_loss(self, interaction):
+
+        user = interaction[self.USER_ID]
+        rating_matrix = self.get_rating_matrix(user)
+
+        self.update += 1
+        if self.total_anneal_steps > 0:
+            anneal = min(self.anneal_cap, 1. * self.update / self.total_anneal_steps)
+        else:
+            anneal = self.anneal_cap
+
+        z, mu, logvar = self.forward(rating_matrix)
+
+        # KL loss
+        kl_loss = -0.5 * (torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)) * anneal
+
+        # CE loss
+        ce_loss = -(F.log_softmax(z, 1) * rating_matrix).sum(1)
+
+        return ce_loss + kl_loss
+
+    def construct_critic_input(self, actor_loss):
+        critic_inputs = []
+        critic_inputs.append(self.number_of_seen_items)
+        critic_inputs.append(self.number_of_unseen_items)
+        critic_inputs.append(actor_loss)
+        return torch.stack(critic_inputs, dim=1)
+
+    def construct_critic_layers(self):
+        mlp_modules = []
+        mlp_modules.append(nn.BatchNorm1d(3))
+        mlp_modules.append(nn.Linear(3, self.critic_layers_1))
+        mlp_modules.append(nn.ReLU())
+        mlp_modules.append(nn.Linear(self.critic_layers_1, self.critic_layers_2))
+        mlp_modules.append(nn.ReLU())
+        mlp_modules.append(nn.Linear(self.critic_layers_2, self.critic_layers_3))
+        mlp_modules.append(nn.ReLU())
+        mlp_modules.append(nn.Linear(self.critic_layers_3, 1))
+        mlp_modules.append(nn.Sigmoid())
+        return nn.Sequential(*mlp_modules)
+
+    def calculate_ndcg(self, predict_matrix, true_matrix, input_matrix, k):
+        users_num = predict_matrix.shape[0]
+        predict_matrix[input_matrix.nonzero(as_tuple=True)] = -np.inf
+        _, idx_sorted = torch.sort(predict_matrix, dim=1, descending=True)
+
+        topk_result = true_matrix[np.arange(users_num)[:, np.newaxis], idx_sorted[:, :k]]
+
+        number_non_zero = ((true_matrix > 0) * 1).sum(dim=1)
+
+        tp = 1. / torch.log2(torch.arange(2, k + 2).type(torch.FloatTensor)).to(topk_result.device)
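+        # tp[r] = 1 / log2(r + 2) is the positional discount at (0-based) rank r;
+        # DCG sums the discounts of the relevant items that appear in the top-k,
+        # while IDCG is the best DCG attainable given the number of held-out items.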
+        DCG = (topk_result * tp).sum(dim=1)
+        IDCG = torch.Tensor([(tp[:min(n, k)]).sum() for n in number_non_zero]).to(topk_result.device)
+        IDCG = torch.maximum(0.1 * torch.ones_like(IDCG).to(IDCG.device), IDCG)
+
+        return DCG / IDCG
+
+    def critic_forward(self, actor_loss):
+        h = self.construct_critic_input(actor_loss)
+        y = self.critic_net(h)
+        y = torch.squeeze(y)
+        return y
+
+    def calculate_critic_loss(self, interaction):
+        actor_loss = self.calculate_actor_loss(interaction)
+        y = self.critic_forward(actor_loss)
+        score = self.calculate_ndcg(self.predict_matrix, self.true_matrix, self.input_matrix, self.metrics_k)
+
+        mse_loss = (y - score) ** 2
+        return mse_loss
+
+    def calculate_ac_loss(self, interaction):
+        actor_loss = self.calculate_actor_loss(interaction)
+        y = self.critic_forward(actor_loss)
+        return -1 * y
+
+    def calculate_loss(self, interaction):
+
+        # actor_pretrain
+        if self.train_stage == 'actor_pretrain':
+            return self.calculate_actor_loss(interaction).mean()
+        # critic_pretrain
+        elif self.train_stage == 'critic_pretrain':
+            return self.calculate_critic_loss(interaction).mean()
+        # finetune
+        else:
+            return self.calculate_ac_loss(interaction).mean()
+
+    def predict(self, interaction):
+
+        user = interaction[self.USER_ID]
+        item = interaction[self.ITEM_ID]
+
+        rating_matrix = self.get_rating_matrix(user)
+
+        scores, _, _ = self.forward(rating_matrix)
+
+        # pick out, for each row of the batch, the score of its target item
+        return scores[[torch.arange(len(item)).to(self.device), item]]
+
+    def full_sort_predict(self, interaction):
+        user = interaction[self.USER_ID]
+
+        rating_matrix = self.get_rating_matrix(user)
+
+        scores, _, _ = self.forward(rating_matrix)
+
+        return scores.view(-1)
diff --git a/recbole/properties/model/RaCT.yaml b/recbole/properties/model/RaCT.yaml
new file mode 100644
index 000000000..e4b5da8d8
--- /dev/null
+++ b/recbole/properties/model/RaCT.yaml
@@ -0,0 +1,12 @@
+mlp_hidden_size: [600]
+latent_dimension: 200
+dropout_prob: 0.5
+anneal_cap: 0.2
+total_anneal_steps: 200000
+critic_layers_1: 100
+critic_layers_2: 100
+critic_layers_3: 10
+metrics_k: 100
+train_stage: 'finetune'
+save_step: 100
+pre_model_path: './saved/RaCT-ml-100k-300.pth'
\ No newline at end of file
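A note on `pre_model_path` above: the 'critic_pretrain' and 'finetune' stages load a checkpoint written by an earlier pre-training stage. The pretrain loop added to the trainer below names its checkpoints with the format string shown next, so the shipped default corresponds to epoch 300 on ml-100k; the concrete path is up to the user:

    '{}-{}-{}.pth'.format(config['model'], config['dataset'], str(epoch_idx + 1))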
diff --git a/recbole/trainer/trainer.py b/recbole/trainer/trainer.py
index 9d2c3a8ae..d266defa3 100644
--- a/recbole/trainer/trainer.py
+++ b/recbole/trainer/trainer.py
@@ -3,9 +3,9 @@
 # @Email : slmu@ruc.edu.cn
 
 # UPDATE:
-# @Time : 2020/8/7, 2020/9/26, 2020/9/26, 2020/10/01, 2020/9/16, 2020/10/8, 2020/10/15, 2020/11/20
-# @Author : Zihan Lin, Yupeng Hou, Yushuo Chen, Shanlei Mu, Xingyu Pan, Hui Wang, Xinyan Fan, Chen Yang
-# @Email : linzihan.super@foxmail.com, houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, slmu@ruc.edu.cn, panxy@ruc.edu.cn, hui.wang@ruc.edu.cn, xinyan.fan@ruc.edu.cn, 254170321@qq.com
+# @Time : 2020/8/7, 2020/9/26, 2020/9/26, 2020/10/01, 2020/9/16, 2020/10/8, 2020/10/15, 2020/11/20, 2021/2/16
+# @Author : Zihan Lin, Yupeng Hou, Yushuo Chen, Shanlei Mu, Xingyu Pan, Hui Wang, Xinyan Fan, Chen Yang, Haoran Cheng
+# @Email : linzihan.super@foxmail.com, houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, slmu@ruc.edu.cn, panxy@ruc.edu.cn, hui.wang@ruc.edu.cn, xinyan.fan@ruc.edu.cn, 254170321@qq.com, 1871530482@qq.com
 
 r"""
 recbole.trainer.trainer
@@ -767,3 +767,89 @@ def evaluate(self, eval_data, load_best_model=True, model_file=None, show_progre
         batch_matrix_list = [[torch.stack((self.eval_true, self.eval_pred), 1)]]
         result = self.evaluator.evaluate(batch_matrix_list, eval_data)
         return result
+
+
+class RaCTTrainer(Trainer):
+    r"""RaCTTrainer is designed for RaCT, an actor-critic reinforcement learning based general recommender.
+    It includes three training stages: actor pre-training, critic pre-training and actor-critic training.
+
+    """
+
+    def __init__(self, config, model):
+        super(RaCTTrainer, self).__init__(config, model)
+
+    def _build_optimizer(self):
+        r"""Init the Optimizer
+
+        Returns:
+            torch.optim: the optimizer
+        """
+        params = filter(lambda p: p.requires_grad, self.model.parameters())
+
+        if self.learner.lower() == 'adam':
+            optimizer = optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay)
+        elif self.learner.lower() == 'sgd':
+            optimizer = optim.SGD(params, lr=self.learning_rate, weight_decay=self.weight_decay)
+        elif self.learner.lower() == 'adagrad':
+            optimizer = optim.Adagrad(params, lr=self.learning_rate, weight_decay=self.weight_decay)
+        elif self.learner.lower() == 'rmsprop':
+            optimizer = optim.RMSprop(params, lr=self.learning_rate, weight_decay=self.weight_decay)
+        elif self.learner.lower() == 'sparse_adam':
+            optimizer = optim.SparseAdam(params, lr=self.learning_rate)
+            if self.weight_decay > 0:
+                self.logger.warning('Sparse Adam does not support weight_decay; the argument will be ignored')
+        else:
+            self.logger.warning('Received unrecognized optimizer, set default Adam optimizer')
+            optimizer = optim.Adam(params, lr=self.learning_rate)
+        return optimizer
+
+    def save_pretrained_model(self, epoch, saved_model_file):
+        r"""Store the model parameters information and training information.
+
+        Args:
+            epoch (int): the current epoch id
+            saved_model_file (str): file name for saved pretrained model
+
+        """
+        state = {
+            'config': self.config,
+            'epoch': epoch,
+            'state_dict': self.model.state_dict(),
+            'optimizer': self.optimizer.state_dict(),
+        }
+        torch.save(state, saved_model_file)
+
+    def pretrain(self, train_data, verbose=True, show_progress=False):
+
+        for epoch_idx in range(self.start_epoch, self.epochs):
+            # train
+            training_start_time = time()
+            train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress)
+            self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(train_loss, tuple) else train_loss
+            training_end_time = time()
+            train_loss_output = \
+                self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss)
+            if verbose:
+                self.logger.info(train_loss_output)
+
+            if (epoch_idx + 1) % self.config['save_step'] == 0:
+                saved_model_file = os.path.join(
+                    self.checkpoint_dir,
+                    '{}-{}-{}.pth'.format(self.config['model'], self.config['dataset'], str(epoch_idx + 1))
+                )
+                self.save_pretrained_model(epoch_idx, saved_model_file)
+                update_output = 'Saving current: %s' % saved_model_file
+                if verbose:
+                    self.logger.info(update_output)
+
+        return self.best_valid_score, self.best_valid_result
+
+    def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None):
+        if self.model.train_stage == 'actor_pretrain':
+            return self.pretrain(train_data, verbose, show_progress)
+        elif self.model.train_stage == "critic_pretrain":
+            return self.pretrain(train_data, verbose, show_progress)
+        elif self.model.train_stage == 'finetune':
+            return super().fit(train_data, valid_data, verbose, saved, show_progress, callback_fn)
+        else:
+            raise ValueError("Please make sure that 'train_stage' is 'actor_pretrain', 'critic_pretrain' or 'finetune'")
\ No newline at end of file
diff --git a/test.yaml b/test.yaml
new file mode 100644
index 000000000..098b3e5e6
--- /dev/null
+++ b/test.yaml
@@ -0,0 +1,29 @@
+# model config
+embedding_size: 32
+# dataset config
+field_separator: "\t" # separator between fields in the dataset files
+seq_separator: " " # separator inside token_seq or float_seq fields
+USER_ID_FIELD: user_id # user id field
+ITEM_ID_FIELD: item_id # item id field
+RATING_FIELD: rating # rating field
+TIME_FIELD: timestamp # timestamp field
+NEG_PREFIX: neg_ # prefix for negatively sampled fields
+# which columns to load from which file: here, read the four columns user_id, item_id, rating, timestamp from ml-1m.inter
+load_col:
+    inter: [user_id, item_id, rating, timestamp]
+# training settings
+epochs: 500 # maximum number of training epochs
+train_batch_size: 2048 # training batch size
+learner: adam # built-in PyTorch optimizer to use
+learning_rate: 0.001 # learning rate
+training_neg_sample_num: 0 # number of negative samples
+eval_step: 1 # evaluate once after every training epoch
+stopping_step: 10 # early stopping: stop if the chosen metric has not improved within this many steps
+# evaluation settings
+eval_setting: RO_RS,full # randomly shuffle the data, split it by ratio, and use full ranking
+group_by_user: True # whether to group one user's records together; must be True when eval_setting uses RO_RS
+split_ratio: [0.8,0.1,0.1] # split ratio
+metrics: ["Recall", "MRR","NDCG","Hit","Precision"] # evaluation metrics
+topk: [10] # top k for the metrics; 10 yields ["Recall@10", "MRR@10", "NDCG@10", "Hit@10", "Precision@10"]
+valid_metric: MRR@10 # metric used to decide early stopping
+eval_batch_size: 4096 # evaluation batch size
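The test.yaml above doubles as a worked example of RecBole's data and evaluation settings. A minimal sketch of how such a file would be used, assuming RecBole's standard quick-start entry point (the dataset name is illustrative):

    from recbole.quick_start import run_recbole

    run_recbole(model='RaCT', dataset='ml-100k', config_file_list=['test.yaml'])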
From 30e9d320ff2e9938633771c61d86debe1bc2cd8f Mon Sep 17 00:00:00 2001
From: Yibo-Li-1 <3289438186@qq.com>
Date: Mon, 22 Feb 2021 10:27:47 +0800
Subject: [PATCH 3/5] FEA: add RaCT model

---
 recbole/data/utils.py                     | 9 +++++----
 recbole/model/general_recommender/ract.py | 4 ++--
 recbole/properties/model/RaCT.yaml        | 6 +++---
 recbole/trainer/trainer.py                | 7 ++++---
 4 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/recbole/data/utils.py b/recbole/data/utils.py
index 3f44d9a7d..ba6eb3eb8 100644
--- a/recbole/data/utils.py
+++ b/recbole/data/utils.py
@@ -3,9 +3,9 @@
 # @Email : houyupeng@ruc.edu.cn
 
 # UPDATE:
-# @Time : 2020/10/19, 2020/9/17, 2020/8/31
-# @Author : Yupeng Hou, Yushuo Chen, Kaiyuan Li
-# @Email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, tsotfsk@outlook.com
+# @Time : 2020/10/19, 2020/9/17, 2020/8/31, 2021/2/20
+# @Author : Yupeng Hou, Yushuo Chen, Kaiyuan Li, Haoran Cheng
+# @Email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, tsotfsk@outlook.com, chenghaoran29@foxmail.com
 
 """
 recbole.data.utils
@@ -237,7 +237,8 @@ def get_data_loader(name, config, eval_setting):
         "MultiVAE": _get_AE_data_loader,
         'MacridVAE': _get_AE_data_loader,
         'CDAE': _get_AE_data_loader,
-        'ENMF': _get_AE_data_loader
+        'ENMF': _get_AE_data_loader,
+        'RaCT': _get_AE_data_loader
     }
 
     if config['model'] in register_table:
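Registering RaCT in this table routes it to the autoencoder-style dataloader, which batches users rather than user-item pairs (hence `training_neg_sample_num: 0` in the example config). A hypothetical new VAE-style model would follow the same one-line pattern (the model name here is illustrative):

    'MyNewVAE': _get_AE_data_loader,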
""" diff --git a/recbole/properties/model/RaCT.yaml b/recbole/properties/model/RaCT.yaml index e4b5da8d8..290ecd10d 100644 --- a/recbole/properties/model/RaCT.yaml +++ b/recbole/properties/model/RaCT.yaml @@ -1,5 +1,5 @@ mlp_hidden_size: [600] -latent_dimension: 200 +latent_dimension: 256 dropout_prob: 0.5 anneal_cap: 0.2 total_anneal_steps: 200000 @@ -8,5 +8,5 @@ critic_layers_2: 100 critic_layers_3: 10 metrics_k: 100 train_stage: 'finetune' -save_step: 100 -pre_model_path: './saved/RaCT-ml-100k-300.pth' \ No newline at end of file +pretrain_epochs: 50 +pre_model_path: './saved/RaCT-lastfm-50.pth' \ No newline at end of file diff --git a/recbole/trainer/trainer.py b/recbole/trainer/trainer.py index d266defa3..9af3c63fa 100644 --- a/recbole/trainer/trainer.py +++ b/recbole/trainer/trainer.py @@ -5,7 +5,7 @@ # UPDATE: # @Time : 2020/8/7, 2020/9/26, 2020/9/26, 2020/10/01, 2020/9/16, 2020/10/8, 2020/10/15, 2020/11/20, 2021/2/16 # @Author : Zihan Lin, Yupeng Hou, Yushuo Chen, Shanlei Mu, Xingyu Pan, Hui Wang, Xinyan Fan, Chen Yang, Haoran Cheng -# @Email : linzihan.super@foxmail.com, houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, slmu@ruc.edu.cn, panxy@ruc.edu.cn, hui.wang@ruc.edu.cn, xinyan.fan@ruc.edu.cn, 254170321@qq.com, 1871530482@qq.com +# @Email : linzihan.super@foxmail.com, houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, slmu@ruc.edu.cn, panxy@ruc.edu.cn, hui.wang@ruc.edu.cn, xinyan.fan@ruc.edu.cn, 254170321@qq.com, chenghaoran29@foxmail.com r""" recbole.trainer.trainer @@ -777,6 +777,7 @@ class RaCTTrainer(Trainer): def __init__(self, config, model): super(RaCTTrainer, self).__init__(config, model) + self.pretrain_epochs = self.config['pretrain_epochs'] def _build_optimizer(self): r"""Init the Optimizer @@ -821,7 +822,7 @@ def save_pretrained_model(self, epoch, saved_model_file): def pretrain(self, train_data, verbose=True, show_progress=False): - for epoch_idx in range(self.start_epoch, self.epochs): + for epoch_idx in range(self.start_epoch, self.pretrain_epochs): # train training_start_time = time() train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress) @@ -832,7 +833,7 @@ def pretrain(self, train_data, verbose=True, show_progress=False): if verbose: self.logger.info(train_loss_output) - if (epoch_idx + 1) % self.config['save_step'] == 0: + if (epoch_idx + 1) % self.pretrain_epochs == 0: saved_model_file = os.path.join( self.checkpoint_dir, '{}-{}-{}.pth'.format(self.config['model'], self.config['dataset'], str(epoch_idx + 1)) From 483dc824d3997ed690b41a70ea0c2765146cb30d Mon Sep 17 00:00:00 2001 From: Zephyr-29 <47294710+Zephyr-29@users.noreply.github.com> Date: Mon, 22 Feb 2021 13:26:52 +0800 Subject: [PATCH 4/5] Update README_CN.md --- README_CN.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README_CN.md b/README_CN.md index 778186a6c..024656bdb 100644 --- a/README_CN.md +++ b/README_CN.md @@ -6,7 +6,6 @@ *“世有伯乐,然后有千里马。千里马常有,而伯乐不常有。”——韩愈《马说》* - [![PyPi Latest Release](https://img.shields.io/pypi/v/recbole)](https://pypi.org/project/recbole/) [![Conda Latest Release](https://anaconda.org/aibox/recbole/badges/version.svg)](https://anaconda.org/aibox/recbole) [![License](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE) From aa22eea15ecc1202b1bb0b4646a522da1ba10be9 Mon Sep 17 00:00:00 2001 From: Zephyr-29 <47294710+Zephyr-29@users.noreply.github.com> Date: Mon, 22 Feb 2021 13:27:16 +0800 Subject: [PATCH 5/5] Delete test.yaml --- test.yaml | 29 ----------------------------- 1 file changed, 29 deletions(-) delete mode 100644 test.yaml diff 
diff --git a/test.yaml b/test.yaml
deleted file mode 100644
index 098b3e5e6..000000000
--- a/test.yaml
+++ /dev/null
@@ -1,29 +0,0 @@
-# model config
-embedding_size: 32
-# dataset config
-field_separator: "\t" # separator between fields in the dataset files
-seq_separator: " " # separator inside token_seq or float_seq fields
-USER_ID_FIELD: user_id # user id field
-ITEM_ID_FIELD: item_id # item id field
-RATING_FIELD: rating # rating field
-TIME_FIELD: timestamp # timestamp field
-NEG_PREFIX: neg_ # prefix for negatively sampled fields
-# which columns to load from which file: here, read the four columns user_id, item_id, rating, timestamp from ml-1m.inter
-load_col:
-    inter: [user_id, item_id, rating, timestamp]
-# training settings
-epochs: 500 # maximum number of training epochs
-train_batch_size: 2048 # training batch size
-learner: adam # built-in PyTorch optimizer to use
-learning_rate: 0.001 # learning rate
-training_neg_sample_num: 0 # number of negative samples
-eval_step: 1 # evaluate once after every training epoch
-stopping_step: 10 # early stopping: stop if the chosen metric has not improved within this many steps
-# evaluation settings
-eval_setting: RO_RS,full # randomly shuffle the data, split it by ratio, and use full ranking
-group_by_user: True # whether to group one user's records together; must be True when eval_setting uses RO_RS
-split_ratio: [0.8,0.1,0.1] # split ratio
-metrics: ["Recall", "MRR","NDCG","Hit","Precision"] # evaluation metrics
-topk: [10] # top k for the metrics; 10 yields ["Recall@10", "MRR@10", "NDCG@10", "Hit@10", "Precision@10"]
-valid_metric: MRR@10 # metric used to decide early stopping
-eval_batch_size: 4096 # evaluation batch size
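Taken together, the series leaves RaCT registered with the AE dataloader and trained through RaCTTrainer. A minimal sketch of the intended three-stage run, assuming RecBole's standard quick-start API; the dataset, epoch count and checkpoint paths are illustrative (the pretrain loop saves '{model}-{dataset}-{epoch}.pth' under the checkpoint directory):

    from recbole.quick_start import run_recbole

    # stage 1: pre-train the actor (the VAE) with the critic frozen
    run_recbole(model='RaCT', dataset='ml-100k',
                config_dict={'train_stage': 'actor_pretrain',
                             'training_neg_sample_num': 0,
                             'pretrain_epochs': 50})

    # stage 2: pre-train the critic to predict NDCG@k with the actor frozen
    run_recbole(model='RaCT', dataset='ml-100k',
                config_dict={'train_stage': 'critic_pretrain',
                             'training_neg_sample_num': 0,
                             'pre_model_path': './saved/RaCT-ml-100k-50.pth'})

    # stage 3: fine-tune the actor to maximize the frozen critic's score
    run_recbole(model='RaCT', dataset='ml-100k',
                config_dict={'train_stage': 'finetune',
                             'training_neg_sample_num': 0,
                             'pre_model_path': './saved/RaCT-ml-100k-50.pth'})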