diff --git a/docs/source/user_guide/config/training_settings.rst b/docs/source/user_guide/config/training_settings.rst
index 1d4ff583f..ae9769158 100644
--- a/docs/source/user_guide/config/training_settings.rst
+++ b/docs/source/user_guide/config/training_settings.rst
@@ -28,3 +28,5 @@ Training settings are designed to set parameters about model training.
 - ``loss_decimal_place(int)``: The decimal place of training loss. Defaults to ``4``.
 - ``weight_decay (float)`` : The weight decay (L2 penalty), used for `optimizer `_. Default to ``0.0``.
 - ``require_pow(bool)``: The sign identifies whether the power operation is performed based on the norm in EmbLoss. Defaults to ``False``.
+- ``enable_amp(bool)``: Whether to use mixed precision training. Defaults to ``False``.
+- ``enable_scaler(bool)``: Whether to use GradScaler, which is commonly paired with mixed precision training to avoid underflow of small float16 gradients. Defaults to ``False``.
\ No newline at end of file
diff --git a/recbole/model/context_aware_recommender/afm.py b/recbole/model/context_aware_recommender/afm.py
index d470a9769..e7257caa0 100644
--- a/recbole/model/context_aware_recommender/afm.py
+++ b/recbole/model/context_aware_recommender/afm.py
@@ -39,7 +39,7 @@ def __init__(self, config, dataset):
         self.p = nn.Parameter(torch.randn(self.embedding_size), requires_grad=True)
         self.dropout_layer = nn.Dropout(p=self.dropout_prob)
         self.sigmoid = nn.Sigmoid()
-        self.loss = nn.BCELoss()
+        self.loss = nn.BCEWithLogitsLoss()

         # parameters initialization
         self.apply(self._init_weights)
@@ -101,7 +101,7 @@ def afm_layer(self, infeature):

     def forward(self, interaction):
         afm_all_embeddings = self.concat_embed_input_fields(interaction)  # [batch_size, num_field, embed_dim]
-        output = self.sigmoid(self.first_order_linear(interaction) + self.afm_layer(afm_all_embeddings))
+        output = self.first_order_linear(interaction) + self.afm_layer(afm_all_embeddings)
         return output.squeeze(-1)

     def calculate_loss(self, interaction):
@@ -112,4 +112,4 @@ def calculate_loss(self, interaction):
         return self.loss(output, label) + l2_loss

     def predict(self, interaction):
-        return self.forward(interaction)
+        return self.sigmoid(self.forward(interaction))
diff --git a/recbole/model/context_aware_recommender/autoint.py b/recbole/model/context_aware_recommender/autoint.py
index 893d5d8f2..8276e1e5c 100644
--- a/recbole/model/context_aware_recommender/autoint.py
+++ b/recbole/model/context_aware_recommender/autoint.py
@@ -56,7 +56,7 @@ def __init__(self, config, dataset):
         self.dropout_layer = nn.Dropout(p=self.dropout_probs[2])
         self.sigmoid = nn.Sigmoid()
-        self.loss = nn.BCELoss()
+        self.loss = nn.BCEWithLogitsLoss()

         # parameters initialization
         self.apply(self._init_weights)
@@ -97,7 +97,7 @@ def autoint_layer(self, infeature):
     def forward(self, interaction):
         autoint_all_embeddings = self.concat_embed_input_fields(interaction)  # [batch_size, num_field, embed_dim]
         output = self.first_order_linear(interaction) + self.autoint_layer(autoint_all_embeddings)
-        return self.sigmoid(output.squeeze(1))
+        return output.squeeze(1)

     def calculate_loss(self, interaction):
         label = interaction[self.LABEL]
@@ -105,4 +105,4 @@ def calculate_loss(self, interaction):
         return self.loss(output, label)

     def predict(self, interaction):
-        return self.forward(interaction)
+        return self.sigmoid(self.forward(interaction))
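
Note on the recurring pattern above, which repeats across all the model files in this diff: forward() now returns raw logits, calculate_loss() feeds those logits to nn.BCEWithLogitsLoss, and predict() applies the sigmoid explicitly. BCEWithLogitsLoss fuses the sigmoid and the binary cross entropy through the log-sum-exp trick, so it stays numerically stable in float16, where a separately computed sigmoid saturates to exactly 0 or 1; PyTorch also refuses to run nn.BCELoss inside an autocast region for exactly this reason. A minimal standalone sketch of the difference (illustrative only, not RecBole code):

    import torch
    import torch.nn as nn

    logits = torch.tensor([12.0])
    label = torch.tensor([0.0])

    # Under float16 the probability saturates to exactly 1.0,
    # so BCELoss would be asked to evaluate log(1 - 1).
    print(torch.sigmoid(logits.half()))           # tensor([1.], dtype=torch.float16)

    # The fused loss never materializes the probability and stays finite.
    print(nn.BCEWithLogitsLoss()(logits, label))  # tensor(12.0000)
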
diff --git a/recbole/model/context_aware_recommender/dcn.py b/recbole/model/context_aware_recommender/dcn.py
index 0edfc6ad8..b3354a235 100644
--- a/recbole/model/context_aware_recommender/dcn.py
+++ b/recbole/model/context_aware_recommender/dcn.py
@@ -62,7 +62,7 @@ def __init__(self, config, dataset):
         self.predict_layer = nn.Linear(in_feature_num, 1)
         self.reg_loss = RegLoss()
         self.sigmoid = nn.Sigmoid()
-        self.loss = nn.BCELoss()
+        self.loss = nn.BCEWithLogitsLoss()

         # parameters initialization
         self.apply(self._init_weights)
@@ -108,7 +108,7 @@ def forward(self, interaction):
         # Cross Network
         cross_output = self.cross_network(dcn_all_embeddings)
         stack = torch.cat([cross_output, deep_output], dim=-1)
-        output = self.sigmoid(self.predict_layer(stack))
+        output = self.predict_layer(stack)

         return output.squeeze(1)
@@ -119,4 +119,4 @@ def calculate_loss(self, interaction):
         return self.loss(output, label) + l2_loss

     def predict(self, interaction):
-        return self.forward(interaction)
+        return self.sigmoid(self.forward(interaction))
diff --git a/recbole/model/context_aware_recommender/deepfm.py b/recbole/model/context_aware_recommender/deepfm.py
index 6c7c47053..34dd3977d 100644
--- a/recbole/model/context_aware_recommender/deepfm.py
+++ b/recbole/model/context_aware_recommender/deepfm.py
@@ -42,7 +42,7 @@ def __init__(self, config, dataset):
         self.mlp_layers = MLPLayers(size_list, self.dropout_prob)
         self.deep_predict_layer = nn.Linear(self.mlp_hidden_size[-1], 1)  # Linear product to the final score
         self.sigmoid = nn.Sigmoid()
-        self.loss = nn.BCELoss()
+        self.loss = nn.BCEWithLogitsLoss()

         # parameters initialization
         self.apply(self._init_weights)
@@ -61,7 +61,7 @@ def forward(self, interaction):
         y_fm = self.first_order_linear(interaction) + self.fm(deepfm_all_embeddings)
         y_deep = self.deep_predict_layer(self.mlp_layers(deepfm_all_embeddings.view(batch_size, -1)))
-        y = self.sigmoid(y_fm + y_deep)
+        y = y_fm + y_deep
         return y.squeeze(-1)

     def calculate_loss(self, interaction):
@@ -70,4 +70,4 @@ def calculate_loss(self, interaction):
         return self.loss(output, label)

     def predict(self, interaction):
-        return self.forward(interaction)
+        return self.sigmoid(self.forward(interaction))
diff --git a/recbole/model/context_aware_recommender/dssm.py b/recbole/model/context_aware_recommender/dssm.py
index d4b6b67e4..bc382e6cb 100644
--- a/recbole/model/context_aware_recommender/dssm.py
+++ b/recbole/model/context_aware_recommender/dssm.py
@@ -41,7 +41,7 @@ def __init__(self, config, dataset):
         self.user_mlp_layers = MLPLayers(user_size_list, self.dropout_prob, activation='tanh', bn=True)
         self.item_mlp_layers = MLPLayers(item_size_list, self.dropout_prob, activation='tanh', bn=True)

-        self.loss = nn.BCELoss()
+        self.loss = nn.BCEWithLogitsLoss()
         self.sigmoid = nn.Sigmoid()

         # parameters initialization
@@ -84,9 +84,7 @@ def forward(self, interaction):
         user_dnn_out = self.user_mlp_layers(embed_user.view(batch_size, -1))
         item_dnn_out = self.item_mlp_layers(embed_item.view(batch_size, -1))
         score = torch.cosine_similarity(user_dnn_out, item_dnn_out, dim=1)
-
-        sig_score = self.sigmoid(score)
-        return sig_score.squeeze(-1)
+        return score.squeeze(-1)

     def calculate_loss(self, interaction):
         label = interaction[self.LABEL]
@@ -94,4 +92,4 @@ def calculate_loss(self, interaction):
         return self.loss(output, label)

     def predict(self, interaction):
-        return self.forward(interaction)
+        return self.sigmoid(self.forward(interaction))
diff --git a/recbole/model/context_aware_recommender/ffm.py b/recbole/model/context_aware_recommender/ffm.py
index c16e8e90e..3e334e3e4 100644
--- a/recbole/model/context_aware_recommender/ffm.py
+++ b/recbole/model/context_aware_recommender/ffm.py
@@ -53,7 +53,7 @@ def __init__(self, config, dataset):
             self.feature_names, self.feature_dims, self.feature2id, self.feature2field, self.num_fields,
             self.embedding_size, self.device
         )
-        self.loss = nn.BCELoss()
+        self.loss = nn.BCEWithLogitsLoss()

         # parameters initialization
         self.apply(self._init_weights)
@@ -116,7 +116,7 @@ def get_ffm_input(self, interaction):

     def forward(self, interaction):
         ffm_input = self.get_ffm_input(interaction)
         ffm_output = torch.sum(torch.sum(self.ffm(ffm_input), dim=1), dim=1, keepdim=True)
-        output = self.sigmoid(self.first_order_linear(interaction) + ffm_output)
+        output = self.first_order_linear(interaction) + ffm_output

         return output.squeeze(-1)
@@ -127,7 +127,7 @@ def calculate_loss(self, interaction):
         return self.loss(output, label)

     def predict(self, interaction):
-        return self.forward(interaction)
+        return self.sigmoid(self.forward(interaction))


 class FieldAwareFactorizationMachine(nn.Module):
diff --git a/recbole/model/context_aware_recommender/fm.py b/recbole/model/context_aware_recommender/fm.py
index e8a31b2c9..63f712aec 100644
--- a/recbole/model/context_aware_recommender/fm.py
+++ b/recbole/model/context_aware_recommender/fm.py
@@ -35,7 +35,7 @@ def __init__(self, config, dataset):
         # define layers and loss
         self.fm = BaseFactorizationMachine(reduce_sum=True)
         self.sigmoid = nn.Sigmoid()
-        self.loss = nn.BCELoss()
+        self.loss = nn.BCEWithLogitsLoss()

         # parameters initialization
         self.apply(self._init_weights)
@@ -46,7 +46,7 @@ def _init_weights(self, module):

     def forward(self, interaction):
         fm_all_embeddings = self.concat_embed_input_fields(interaction)  # [batch_size, num_field, embed_dim]
-        y = self.sigmoid(self.first_order_linear(interaction) + self.fm(fm_all_embeddings))
+        y = self.first_order_linear(interaction) + self.fm(fm_all_embeddings)
         return y.squeeze(-1)

     def calculate_loss(self, interaction):
@@ -56,4 +56,4 @@ def calculate_loss(self, interaction):
         return self.loss(output, label)

     def predict(self, interaction):
-        return self.forward(interaction)
+        return self.sigmoid(self.forward(interaction))
diff --git a/recbole/model/context_aware_recommender/fnn.py b/recbole/model/context_aware_recommender/fnn.py
index 51de84860..f7334d127 100644
--- a/recbole/model/context_aware_recommender/fnn.py
+++ b/recbole/model/context_aware_recommender/fnn.py
@@ -44,7 +44,7 @@ def __init__(self, config, dataset):
         self.predict_layer = nn.Linear(self.mlp_hidden_size[-1], 1, bias=True)

         self.sigmoid = nn.Sigmoid()
-        self.loss = nn.BCELoss()
+        self.loss = nn.BCEWithLogitsLoss()

         # parameters initialization
         self.apply(self._init_weights)
@@ -62,7 +62,6 @@ def forward(self, interaction):
         batch_size = fnn_all_embeddings.shape[0]

         output = self.predict_layer(self.mlp_layers(fnn_all_embeddings.view(batch_size, -1)))
-        output = self.sigmoid(output)
         return output.squeeze(-1)

     def calculate_loss(self, interaction):
@@ -72,4 +71,4 @@ def calculate_loss(self, interaction):
         return self.loss(output, label)

     def predict(self, interaction):
-        return self.forward(interaction)
+        return self.sigmoid(self.forward(interaction))
diff --git a/recbole/model/context_aware_recommender/fwfm.py b/recbole/model/context_aware_recommender/fwfm.py
index deccb63df..d3db3988b 100644
--- a/recbole/model/context_aware_recommender/fwfm.py
+++ b/recbole/model/context_aware_recommender/fwfm.py
@@ -51,7 +51,7 @@ def __init__(self, config, dataset):
         self.num_fields = len(set(self.feature2field.values()))  # the number of fields
         self.num_pair = self.num_fields * self.num_fields

-        self.loss = nn.BCELoss()
+        self.loss = nn.BCEWithLogitsLoss()

         # parameters initialization
         self.apply(self._init_weights)
@@ -122,7 +122,7 @@ def fwfm_layer(self, infeature):

     def forward(self, interaction):
         fwfm_all_embeddings = self.concat_embed_input_fields(interaction)  # [batch_size, num_field, embed_dim]
-        output = self.sigmoid(self.first_order_linear(interaction) + self.fwfm_layer(fwfm_all_embeddings))
+        output = self.first_order_linear(interaction) + self.fwfm_layer(fwfm_all_embeddings)

         return output.squeeze(-1)
@@ -133,4 +133,4 @@ def calculate_loss(self, interaction):
         return self.loss(output, label)

     def predict(self, interaction):
-        return self.forward(interaction)
+        return self.sigmoid(self.forward(interaction))
diff --git a/recbole/model/context_aware_recommender/lr.py b/recbole/model/context_aware_recommender/lr.py
index 45d2c8df1..1e5a802d8 100644
--- a/recbole/model/context_aware_recommender/lr.py
+++ b/recbole/model/context_aware_recommender/lr.py
@@ -32,7 +32,7 @@ def __init__(self, config, dataset):
         super(LR, self).__init__(config, dataset)

         self.sigmoid = nn.Sigmoid()
-        self.loss = nn.BCELoss()
+        self.loss = nn.BCEWithLogitsLoss()

         # parameters initialization
         self.apply(self._init_weights)
@@ -42,7 +42,7 @@ def _init_weights(self, module):
             xavier_normal_(module.weight.data)

     def forward(self, interaction):
-        output = self.sigmoid(self.first_order_linear(interaction))
+        output = self.first_order_linear(interaction)
         return output.squeeze(-1)

     def calculate_loss(self, interaction):
@@ -52,4 +52,4 @@ def calculate_loss(self, interaction):
         return self.loss(output, label)

     def predict(self, interaction):
-        return self.forward(interaction)
+        return self.sigmoid(self.forward(interaction))
diff --git a/recbole/model/context_aware_recommender/nfm.py b/recbole/model/context_aware_recommender/nfm.py
index 4548fb667..d9f47a84c 100644
--- a/recbole/model/context_aware_recommender/nfm.py
+++ b/recbole/model/context_aware_recommender/nfm.py
@@ -37,7 +37,7 @@ def __init__(self, config, dataset):
         self.mlp_layers = MLPLayers(size_list, self.dropout_prob, activation='sigmoid', bn=True)
         self.predict_layer = nn.Linear(self.mlp_hidden_size[-1], 1, bias=False)
         self.sigmoid = nn.Sigmoid()
-        self.loss = nn.BCELoss()
+        self.loss = nn.BCEWithLogitsLoss()

         # parameters initialization
         self.apply(self._init_weights)
@@ -55,7 +55,6 @@ def forward(self, interaction):
         bn_nfm_all_embeddings = self.bn(self.fm(nfm_all_embeddings))

         output = self.predict_layer(self.mlp_layers(bn_nfm_all_embeddings)) + self.first_order_linear(interaction)
-        output = self.sigmoid(output)
         return output.squeeze(-1)

     def calculate_loss(self, interaction):
@@ -64,4 +63,4 @@ def calculate_loss(self, interaction):
         return self.loss(output, label)

     def predict(self, interaction):
-        return self.forward(interaction)
+        return self.sigmoid(self.forward(interaction))
diff --git a/recbole/model/context_aware_recommender/pnn.py b/recbole/model/context_aware_recommender/pnn.py
index d4c5f4fbb..5423edb66 100644
--- a/recbole/model/context_aware_recommender/pnn.py
+++ b/recbole/model/context_aware_recommender/pnn.py
@@ -56,7 +56,7 @@ def __init__(self, config, dataset):
         self.predict_layer = nn.Linear(self.mlp_hidden_size[-1], 1)
         self.relu = nn.ReLU()
         self.sigmoid = nn.Sigmoid()
-        self.loss = nn.BCELoss()
+        self.loss = nn.BCEWithLogitsLoss()

         # parameters initialization
         self.apply(self._init_weights)
@@ -98,7 +98,6 @@ def forward(self, interaction):
         output = torch.cat(output, dim=1)  # [batch_size,d]

         output = self.predict_layer(self.mlp_layers(output))  # [batch_size,1]
-        output = self.sigmoid(output)
         return output.squeeze(-1)

     def calculate_loss(self, interaction):
@@ -108,7 +107,7 @@ def calculate_loss(self, interaction):
         return self.loss(output, label) + self.reg_loss()

     def predict(self, interaction):
-        return self.forward(interaction)
+        return self.sigmoid(self.forward(interaction))


 class InnerProductLayer(nn.Module):
diff --git a/recbole/model/context_aware_recommender/widedeep.py b/recbole/model/context_aware_recommender/widedeep.py
index 6b2fda950..ea3658389 100644
--- a/recbole/model/context_aware_recommender/widedeep.py
+++ b/recbole/model/context_aware_recommender/widedeep.py
@@ -39,7 +39,7 @@ def __init__(self, config, dataset):
         self.mlp_layers = MLPLayers(size_list, self.dropout_prob)
         self.deep_predict_layer = nn.Linear(self.mlp_hidden_size[-1], 1)
         self.sigmoid = nn.Sigmoid()
-        self.loss = nn.BCELoss()
+        self.loss = nn.BCEWithLogitsLoss()

         # parameters initialization
         self.apply(self._init_weights)
@@ -58,7 +58,7 @@ def forward(self, interaction):
         fm_output = self.first_order_linear(interaction)

         deep_output = self.deep_predict_layer(self.mlp_layers(widedeep_all_embeddings.view(batch_size, -1)))
-        output = self.sigmoid(fm_output + deep_output)
+        output = fm_output + deep_output
         return output.squeeze(-1)

     def calculate_loss(self, interaction):
@@ -67,4 +67,4 @@ def calculate_loss(self, interaction):
         return self.loss(output, label)

     def predict(self, interaction):
-        return self.forward(interaction)
+        return self.sigmoid(self.forward(interaction))
diff --git a/recbole/model/context_aware_recommender/xdeepfm.py b/recbole/model/context_aware_recommender/xdeepfm.py
index 0e5085788..a4cabc2e4 100644
--- a/recbole/model/context_aware_recommender/xdeepfm.py
+++ b/recbole/model/context_aware_recommender/xdeepfm.py
@@ -76,7 +76,7 @@ def __init__(self, config, dataset):
         self.cin_linear = nn.Linear(self.final_len, 1)
         self.sigmoid = nn.Sigmoid()
-        self.loss = nn.BCELoss()
+        self.loss = nn.BCEWithLogitsLoss()
         self.apply(self._init_weights)

     def _init_weights(self, module):
@@ -180,9 +180,9 @@ def forward(self, interaction):
         # Get predicted score.
         y_p = self.first_order_linear(interaction) + cin_output + dnn_output
-        y = self.sigmoid(y_p)
-        return y.squeeze(1)
+
+        return y_p.squeeze(1)

     def calculate_loss(self, interaction):
         label = interaction[self.LABEL]
@@ -191,4 +191,4 @@ def calculate_loss(self, interaction):
         return self.loss(output, label) + self.reg_weight * l2_reg

     def predict(self, interaction):
-        return self.forward(interaction)
+        return self.sigmoid(self.forward(interaction))
diff --git a/recbole/model/general_recommender/cdae.py b/recbole/model/general_recommender/cdae.py
index 93c8791aa..938498839 100644
--- a/recbole/model/general_recommender/cdae.py
+++ b/recbole/model/general_recommender/cdae.py
@@ -75,7 +75,7 @@ def forward(self, x_items, x_users):
         h = torch.add(h_u, h_i)
         h = self.h_act(h)
         out = self.out_layer(h)
-        return self.o_act(out)
+        return out

     def get_rating_matrix(self, user):
         r"""Get a batch of user's feature with the user's id and history interaction matrix.
@@ -100,12 +100,12 @@ def calculate_loss(self, interaction):
         predict = self.forward(x_items, x_users)

         if self.loss_type == 'MSE':
+            predict = self.o_act(predict)
             loss_func = nn.MSELoss(reduction='sum')
         elif self.loss_type == 'BCE':
-            loss_func = nn.BCELoss(reduction='sum')
+            loss_func = nn.BCEWithLogitsLoss(reduction='sum')
         else:
             raise ValueError('Invalid loss_type, loss_type must in [MSE, BCE]')
-
         loss = loss_func(predict, x_items)
         # l1-regularization
         loss += self.reg_weight_1 * (self.h_user.weight.norm(p=1) + self.h_item.weight.norm(p=1))
@@ -120,7 +120,7 @@ def predict(self, interaction):

         items = self.get_rating_matrix(users)
         scores = self.forward(items, users)
-
+        scores = self.o_act(scores)
         return scores[[torch.arange(len(predict_items)).to(self.device), predict_items]]

     def full_sort_predict(self, interaction):
@@ -128,4 +128,5 @@ def full_sort_predict(self, interaction):
         items = self.get_rating_matrix(users)

         predict = self.forward(items, users)
+        predict = self.o_act(predict)
         return predict.view(-1)
diff --git a/recbole/model/general_recommender/dmf.py b/recbole/model/general_recommender/dmf.py
index c74836d01..049b199d9 100644
--- a/recbole/model/general_recommender/dmf.py
+++ b/recbole/model/general_recommender/dmf.py
@@ -79,7 +79,7 @@ def __init__(self, config, dataset):
         self.user_fc_layers = MLPLayers([self.user_embedding_size] + self.user_hidden_size_list)
         self.item_fc_layers = MLPLayers([self.item_embedding_size] + self.item_hidden_size_list)
         self.sigmoid = nn.Sigmoid()
-        self.bce_loss = nn.BCELoss()
+        self.bce_loss = nn.BCEWithLogitsLoss()

         # Save the item embedding before dot product layer to speed up evaluation
         self.i_embedding = None
@@ -114,7 +114,6 @@ def forward(self, user, item):

         # cosine distance is replaced by dot product according the result of our experiments.
         vector = torch.mul(user, item).sum(dim=1)
-        vector = self.sigmoid(vector)

         return vector
@@ -130,7 +129,7 @@ def calculate_loss(self, interaction):
         elif self.inter_matrix_type == 'rating':
             label = interaction[self.RATING] * interaction[self.LABEL]
         output = self.forward(user, item)
-
+        label = label / self.max_rating  # normalize the label to calculate BCE loss.
         loss = self.bce_loss(output, label)
         return loss
@@ -138,7 +137,8 @@ def calculate_loss(self, interaction):
     def predict(self, interaction):
         user = interaction[self.USER_ID]
         item = interaction[self.ITEM_ID]
-        return self.forward(user, item)
+        predict = self.sigmoid(self.forward(user, item))
+        return predict

     def get_user_embedding(self, user):
         r"""Get a batch of user's embedding with the user's id and history interaction matrix.
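
A side note on the DMF hunk above: its targets can be explicit ratings rather than 0/1 labels, so the added label = label / self.max_rating line rescales them into [0, 1] before the loss; BCEWithLogitsLoss, like BCELoss, accepts such soft targets. A small illustration with made-up values:

    import torch
    import torch.nn as nn

    ratings = torch.tensor([5.0, 3.0, 0.0])    # hypothetical ratings on a 0-5 scale
    targets = ratings / 5.0                    # tensor([1.0000, 0.6000, 0.0000])
    logits = torch.tensor([2.0, 0.5, -3.0])    # raw scores from forward()
    loss = nn.BCEWithLogitsLoss()(logits, targets)
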
diff --git a/recbole/model/general_recommender/fism.py b/recbole/model/general_recommender/fism.py
index fdfecc216..9199d0bdc 100644
--- a/recbole/model/general_recommender/fism.py
+++ b/recbole/model/general_recommender/fism.py
@@ -59,7 +59,7 @@ def __init__(self, config, dataset):
         self.item_dst_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
         self.user_bias = nn.Parameter(torch.zeros(self.n_users))
         self.item_bias = nn.Parameter(torch.zeros(self.n_items))
-        self.bceloss = nn.BCELoss()
+        self.bceloss = nn.BCEWithLogitsLoss()

         # parameters initialization
         self.apply(self._init_weights)
@@ -147,7 +147,7 @@ def user_forward(self, user_input, item_num, user_bias, repeats=None, pred_slc=None):
             item_bias = self.item_bias[pred_slc]
         similarity = torch.bmm(user_history, targets.unsqueeze(2)).squeeze(2)  # inter_num x target_items
         coeff = torch.pow(item_num.squeeze(1), -self.alpha)
-        scores = torch.sigmoid(coeff.float() * torch.sum(similarity, dim=1) + user_bias + item_bias)
+        scores = coeff.float() * torch.sum(similarity, dim=1) + user_bias + item_bias
         return scores

     def forward(self, user, item):
@@ -187,5 +187,5 @@ def full_sort_predict(self, interaction):
     def predict(self, interaction):
         user = interaction[self.USER_ID]
         item = interaction[self.ITEM_ID]
-        output = self.forward(user, item)
+        output = torch.sigmoid(self.forward(user, item))
         return output
diff --git a/recbole/model/general_recommender/nais.py b/recbole/model/general_recommender/nais.py
index c59cf594b..541e627b3 100644
--- a/recbole/model/general_recommender/nais.py
+++ b/recbole/model/general_recommender/nais.py
@@ -80,7 +80,7 @@ def __init__(self, config, dataset):
         else:
             raise ValueError("NAIS just support attention type in ['concat', 'prod'] but get {}".format(self.algorithm))
         self.weight_layer = nn.Parameter(torch.ones(self.weight_size, 1))
-        self.bceloss = nn.BCELoss()
+        self.bceloss = nn.BCEWithLogitsLoss()

         # parameters initialization
         if self.pretrain_path is not None:
@@ -194,7 +194,7 @@ def mask_softmax(self, similarity, logits, bias, item_num, batch_mask_mat):
         weights = torch.div(exp_logits, exp_sum)

         coeff = torch.pow(item_num.squeeze(1), -self.alpha)
-        output = torch.sigmoid(coeff.float() * torch.sum(weights * similarity, dim=1) + bias)
+        output = coeff.float() * torch.sum(weights * similarity, dim=1) + bias

         return output
@@ -297,5 +297,5 @@ def full_sort_predict(self, interaction):
     def predict(self, interaction):
         user = interaction[self.USER_ID]
         item = interaction[self.ITEM_ID]
-        output = self.forward(user, item)
+        output = torch.sigmoid(self.forward(user, item))
         return output
diff --git a/recbole/model/general_recommender/neumf.py b/recbole/model/general_recommender/neumf.py
index d595c1b21..1bd105050 100644
--- a/recbole/model/general_recommender/neumf.py
+++ b/recbole/model/general_recommender/neumf.py
@@ -66,7 +66,7 @@ def __init__(self, config, dataset):
         elif self.mlp_train:
             self.predict_layer = nn.Linear(self.mlp_hidden_size[-1], 1)
         self.sigmoid = nn.Sigmoid()
-        self.loss = nn.BCELoss()
+        self.loss = nn.BCEWithLogitsLoss()

         # parameters initialization
         if self.use_pretrain:
@@ -110,11 +110,11 @@ def forward(self, user, item):
         if self.mlp_train:
             mlp_output = self.mlp_layers(torch.cat((user_mlp_e, item_mlp_e), -1))  # [batch_size, layers[-1]]
         if self.mf_train and self.mlp_train:
-            output = self.sigmoid(self.predict_layer(torch.cat((mf_output, mlp_output), -1)))
+            output = self.predict_layer(torch.cat((mf_output, mlp_output), -1))
         elif self.mf_train:
-            output = self.sigmoid(self.predict_layer(mf_output))
+            output = self.predict_layer(mf_output)
         elif self.mlp_train:
-            output = self.sigmoid(self.predict_layer(mlp_output))
+            output = self.predict_layer(mlp_output)
         else:
             raise RuntimeError('mf_train and mlp_train can not be False at the same time')
         return output.squeeze(-1)
@@ -123,14 +123,15 @@ def calculate_loss(self, interaction):
         user = interaction[self.USER_ID]
         item = interaction[self.ITEM_ID]
         label = interaction[self.LABEL]
-
-        output = self.forward(user, item)
+
+        output = self.forward(user, item)
         return self.loss(output, label)

     def predict(self, interaction):
         user = interaction[self.USER_ID]
         item = interaction[self.ITEM_ID]
-        return self.forward(user, item)
+        predict = self.sigmoid(self.forward(user, item))
+        return predict

     def dump_parameters(self):
         r"""A simple implementation of dumping model parameters for pretrain.
diff --git a/recbole/model/general_recommender/nncf.py b/recbole/model/general_recommender/nncf.py
index 1860cfc28..52bcb8def 100644
--- a/recbole/model/general_recommender/nncf.py
+++ b/recbole/model/general_recommender/nncf.py
@@ -69,10 +69,10 @@ def __init__(self, config, dataset):
             [2 * pooled_size * self.num_conv_kernel + self.ui_embedding_size] + self.mlp_hidden_size, config['dropout']
         )
-        self.out_layer = nn.Sequential(nn.Linear(self.mlp_hidden_size[-1], 1), nn.Sigmoid())
+        self.out_layer = nn.Linear(self.mlp_hidden_size[-1], 1)
         self.dropout_layer = torch.nn.Dropout(p=config['dropout'])
-        self.loss = nn.BCELoss()
-
+        self.loss = nn.BCEWithLogitsLoss()
+
         # choose the method to use neighborhood information
         if self.neigh_info_method == "random":
             self.u_neigh, self.i_neigh = self.get_neigh_random()
@@ -353,10 +353,10 @@ def calculate_loss(self, interaction):
         item = interaction[self.ITEM_ID]
         label = interaction[self.LABEL]

-        output = self.forward(user, item)
-        return self.loss(output, label)
+        output = self.forward(user, item)
+        return self.loss(output, label)

     def predict(self, interaction):
         user = interaction[self.USER_ID]
         item = interaction[self.ITEM_ID]
-        return self.forward(user, item)
+        return torch.sigmoid(self.forward(user, item))
diff --git a/recbole/model/sequential_recommender/dien.py b/recbole/model/sequential_recommender/dien.py
index e7f53e066..28e1d86df 100644
--- a/recbole/model/sequential_recommender/dien.py
+++ b/recbole/model/sequential_recommender/dien.py
@@ -83,7 +83,7 @@ def __init__(self, config, dataset):
         self.dnn_mlp_layers = MLPLayers(self.dnn_mlp_list, activation='Dice', dropout=self.dropout_prob, bn=True)
         self.dnn_predict_layer = nn.Linear(self.mlp_hidden_size[-1], 1)
         self.sigmoid = nn.Sigmoid()
-        self.loss = nn.BCELoss()
+        self.loss = nn.BCEWithLogitsLoss()
         self.apply(self._init_weights)
         self.other_parameter_name = ['embedding_layer']
@@ -130,8 +130,6 @@ def forward(self, user, item_seq, neg_item_seq, item_seq_len, next_items):
         # input the DNN to get the prediction score
         dien_out = self.dnn_mlp_layers(dien_in)
         preds = self.dnn_predict_layer(dien_out)
-        preds = self.sigmoid(preds)
-
         return preds.squeeze(1), aux_loss

     def calculate_loss(self, interaction):
@@ -152,7 +150,7 @@ def predict(self, interaction):
         item_seq_len = interaction[self.ITEM_SEQ_LEN]
         next_items = interaction[self.POS_ITEM_ID]
         scores, _ = self.forward(user, item_seq, neg_item_seq, item_seq_len, next_items)
-        return scores
+        return self.sigmoid(scores)


 class InterestExtractorNetwork(nn.Module):
@@ -164,7 +162,7 @@ class InterestExtractorNetwork(nn.Module):
     def __init__(self, input_size, hidden_size, mlp_size):
         super(InterestExtractorNetwork, self).__init__()
         self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, batch_first=True)
-        self.auxiliary_net = MLPLayers(layers=mlp_size, activation='Sigmoid')
+        self.auxiliary_net = MLPLayers(layers=mlp_size, activation='none')

     def forward(self, keys, keys_length, neg_keys=None):
         batch_size, hist_len, embedding_size = keys.shape
@@ -208,7 +206,7 @@ def auxiliary_loss(self, h_states, click_seq, noclick_seq, keys_length):
         # non-click label
         noclick_target = torch.zeros(noclick_prop.shape, device=noclick_input.device)

-        loss = F.binary_cross_entropy(
+        loss = F.binary_cross_entropy_with_logits(
             torch.cat([click_prop, noclick_prop], dim=0), torch.cat([click_target, noclick_target], dim=0)
         )
@@ -424,5 +422,5 @@ def forward(self, input, att_scores=None, hidden_output=None):
             outputs[begin:begin + batch] = new_hx
             hidden_output = new_hx
             begin += batch
-
+
         return PackedSequence(outputs, batch_sizes, sorted_indices, unsorted_indices)
diff --git a/recbole/model/sequential_recommender/din.py b/recbole/model/sequential_recommender/din.py
index 101a2b4d4..47aa1ff51 100644
--- a/recbole/model/sequential_recommender/din.py
+++ b/recbole/model/sequential_recommender/din.py
@@ -76,7 +76,7 @@ def __init__(self, config, dataset):
         self.embedding_layer = ContextSeqEmbLayer(dataset, self.embedding_size, self.pooling_mode, self.device)
         self.dnn_predict_layers = nn.Linear(self.mlp_hidden_size[-1], 1)
         self.sigmoid = nn.Sigmoid()
-        self.loss = nn.BCELoss()
+        self.loss = nn.BCEWithLogitsLoss()

         # parameters initialization
         self.apply(self._init_weights)
@@ -122,7 +122,6 @@ def forward(self, user, item_seq, item_seq_len, next_items):
         din_in = torch.cat([user_emb, target_item_feat_emb, user_emb * target_item_feat_emb], dim=-1)
         din_out = self.dnn_mlp_layers(din_in)
         preds = self.dnn_predict_layers(din_out)
-        preds = self.sigmoid(preds)

         return preds.squeeze(1)
@@ -141,5 +140,5 @@ def predict(self, interaction):
         user = interaction[self.USER_ID]
         item_seq_len = interaction[self.ITEM_SEQ_LEN]
         next_items = interaction[self.POS_ITEM_ID]
-        scores = self.forward(user, item_seq, item_seq_len, next_items)
+        scores = self.sigmoid(self.forward(user, item_seq, item_seq_len, next_items))
         return scores
diff --git a/recbole/model/sequential_recommender/s3rec.py b/recbole/model/sequential_recommender/s3rec.py
index d61cccca5..9b2dd8168 100644
--- a/recbole/model/sequential_recommender/s3rec.py
+++ b/recbole/model/sequential_recommender/s3rec.py
@@ -95,7 +95,7 @@ def __init__(self, config, dataset):
         self.mip_norm = nn.Linear(self.hidden_size, self.hidden_size)
         self.map_norm = nn.Linear(self.hidden_size, self.hidden_size)
         self.sp_norm = nn.Linear(self.hidden_size, self.hidden_size)
-        self.loss_fct = nn.BCELoss(reduction='none')
+        self.loss_fct = nn.BCEWithLogitsLoss(reduction='none')

         # modules for finetune
         if self.loss_type == 'BPR' and self.train_stage == 'finetune':
@@ -132,7 +132,7 @@ def _associated_attribute_prediction(self, sequence_output, feature_embedding):
         sequence_output = sequence_output.view([-1, sequence_output.size(-1), 1])  # [B*L H 1]
         # [feature_num H] [B*L H 1] -> [B*L feature_num 1]
         score = torch.matmul(feature_embedding, sequence_output)
-        return torch.sigmoid(score.squeeze(-1))  # [B*L feature_num]
+        return score.squeeze(-1)  # [B*L feature_num]

     def _masked_item_prediction(self, sequence_output, target_item_emb):
         sequence_output = self.mip_norm(sequence_output.view([-1, sequence_output.size(-1)]))  # [B*L H]
@@ -145,7 +145,7 @@ def _masked_attribute_prediction(self, sequence_output, feature_embedding):
         sequence_output = sequence_output.view([-1, sequence_output.size(-1), 1])  # [B*L H 1]
         # [feature_num H] [B*L H 1] -> [B*L feature_num 1]
         score = torch.matmul(feature_embedding, sequence_output)
-        return torch.sigmoid(score.squeeze(-1))  # [B*L feature_num]
+        return score.squeeze(-1)  # [B*L feature_num]

     def _segment_prediction(self, context, segment_emb):
         context = self.sp_norm(context)
@@ -196,7 +196,7 @@ def pretrain(
         neg_item_embs = self.item_embedding(neg_items)
         pos_score = self._masked_item_prediction(sequence_output, pos_item_embs)
         neg_score = self._masked_item_prediction(sequence_output, neg_item_embs)
-        mip_distance = torch.sigmoid(pos_score - neg_score)
+        mip_distance = pos_score - neg_score
         mip_loss = self.loss_fct(mip_distance, torch.ones_like(mip_distance, dtype=torch.float32))
         mip_mask = (masked_item_sequence == self.mask_token).float()
         mip_loss = torch.sum(mip_loss * mip_mask.flatten())
@@ -215,7 +215,7 @@ def pretrain(
         neg_segment_emb = self.forward(neg_segment)[:, -1, :]  # [B H]
         pos_segment_score = self._segment_prediction(segment_context, pos_segment_emb)
         neg_segment_score = self._segment_prediction(segment_context, neg_segment_emb)
-        sp_distance = torch.sigmoid(pos_segment_score - neg_segment_score)
+        sp_distance = pos_segment_score - neg_segment_score
         sp_loss = torch.sum(self.loss_fct(sp_distance, torch.ones_like(sp_distance, dtype=torch.float32)))

         pretrain_loss = self.aap_weight * aap_loss \
diff --git a/recbole/properties/overall.yaml b/recbole/properties/overall.yaml
index 4951d1fbf..98d966613 100644
--- a/recbole/properties/overall.yaml
+++ b/recbole/properties/overall.yaml
@@ -29,6 +29,8 @@ clip_grad_norm: ~
 weight_decay: 0.0
 loss_decimal_place: 4
 require_pow: False
+enable_amp: False
+enable_scaler: False
 shuffle: True

 # evaluation settings
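
With the two keys added to overall.yaml, mixed precision can be switched on per run without touching model code. A sketch of one way to do it, assuming the usual quick-start entry point run_recbole and its config_dict override:

    from recbole.quick_start import run_recbole

    # enable_amp runs the forward pass under torch.cuda.amp.autocast;
    # enable_scaler additionally scales the loss with GradScaler so that
    # small float16 gradients do not underflow to zero.
    run_recbole(
        model='DeepFM',
        dataset='ml-100k',
        config_dict={'enable_amp': True, 'enable_scaler': True},
    )
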
diff --git a/recbole/trainer/trainer.py b/recbole/trainer/trainer.py
index 75e1abb8b..8e684dcd0 100644
--- a/recbole/trainer/trainer.py
+++ b/recbole/trainer/trainer.py
@@ -30,6 +30,7 @@
 import torch.optim as optim
 from torch.nn.utils.clip_grad import clip_grad_norm_
 from tqdm import tqdm
+import torch.cuda.amp as amp

 from recbole.data.interaction import Interaction
 from recbole.data.dataloader import FullSortEvalDataLoader
@@ -38,6 +39,7 @@
     EvaluatorType, KGDataLoaderState, get_tensorboard, set_color, get_gpu_usage, WandbLogger
 from torch.nn.parallel import DistributedDataParallel

+
 class AbstractTrainer(object):
     r"""Trainer Class is used to manage the training and evaluation processes of recommender system models.
     AbstractTrainer is an abstract class in which the fit() and evaluate() method should be implemented according
@@ -117,6 +119,8 @@ def __init__(self, config, model):
         self.gpu_available = torch.cuda.is_available() and config['use_gpu']
         self.device = config['device']
         self.checkpoint_dir = config['checkpoint_dir']
+        self.enable_amp = config['enable_amp']
+        self.enable_scaler = config['enable_scaler']
         ensure_dir(self.checkpoint_dir)
         saved_model_file = '{}-{}.pth'.format(self.config['model'], get_local_time())
         self.saved_model_file = os.path.join(self.checkpoint_dir, saved_model_file)
@@ -204,7 +208,7 @@ def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
         if not self.config['single_spec'] and train_data.shuffle:
             train_data.sampler.set_epoch(epoch_idx)

-
+        scaler = amp.GradScaler(enabled=self.enable_scaler)
         for batch_idx, interaction in enumerate(iter_data):
             interaction = interaction.to(self.device)
             self.optimizer.zero_grad()
@@ -212,7 +216,10 @@ def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
             if not self.config['single_spec']:
                 self.set_reduce_hook()
                 sync_loss = self.sync_grad_loss()
-            losses = loss_func(interaction)
+
+            with amp.autocast(enabled=self.enable_amp):
+                losses = loss_func(interaction)
+
             if isinstance(losses, tuple):
                 loss = sum(losses)
                 loss_tuple = tuple(per_loss.item() for per_loss in losses)
@@ -221,10 +228,11 @@ def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
                 loss = losses
                 total_loss = losses.item() if total_loss is None else total_loss + losses.item()
             self._check_nan(loss)
-            (loss + sync_loss).backward()
+            scaler.scale(loss + sync_loss).backward()
             if self.clip_grad_norm:
                 clip_grad_norm_(self.model.parameters(), **self.clip_grad_norm)
-            self.optimizer.step()
+            scaler.step(self.optimizer)
+            scaler.update()
             if self.gpu_available and show_progress:
                 iter_data.set_postfix_str(set_color('GPU RAM: ' + get_gpu_usage(self.device), 'yellow'))
         return total_loss
@@ -1222,6 +1230,7 @@ def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
                 desc=set_color(f"Train {epoch_idx:>5}", 'pink'),
             ) if show_progress else train_data
         )
+        scaler = amp.GradScaler(enabled=self.enable_scaler)
         if not self.config['single_spec'] and train_data.shuffle:
             train_data.sampler.set_epoch(epoch_idx)
@@ -1233,7 +1242,10 @@ def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
             if not self.config['single_spec']:
                 self.set_reduce_hook()
                 sync_loss = self.sync_grad_loss()
-            losses = loss_func(interaction)
+
+            with amp.autocast(enabled=self.enable_amp):
+                losses = loss_func(interaction)
+
             if isinstance(losses, tuple):
                 if epoch_idx < self.config['warm_up_step']:
                     losses = losses[:-1]
@@ -1244,10 +1256,12 @@ def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
                 loss = losses
                 total_loss = losses.item() if total_loss is None else total_loss + losses.item()
             self._check_nan(loss)
-            (loss + sync_loss).backward()
+            scaler.scale(loss + sync_loss).backward()
+
             if self.clip_grad_norm:
                 clip_grad_norm_(self.model.parameters(), **self.clip_grad_norm)
-            self.optimizer.step()
+            scaler.step(self.optimizer)
+            scaler.update()
             if self.gpu_available and show_progress:
                 iter_data.set_postfix_str(set_color('GPU RAM: ' + get_gpu_usage(self.device), 'yellow'))
         return total_loss
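
The trainer hunks follow PyTorch's standard AMP recipe: run the forward pass under autocast, scale the loss before backward(), and let the scaler drive the optimizer step (scaler.step() skips the update when it detects inf/nan gradients, and scaler.update() then adjusts the scale factor). One caveat: when gradient clipping is enabled, the documented PyTorch recipe calls scaler.unscale_(optimizer) before clip_grad_norm_ so the threshold applies to unscaled gradients, whereas the hunks above clip the still-scaled gradients. A minimal standalone sketch of the full recipe (model, loader, loss_func, and optimizer are placeholders):

    import torch
    import torch.cuda.amp as amp

    scaler = amp.GradScaler(enabled=True)
    for inputs, labels in loader:                      # placeholder dataloader
        optimizer.zero_grad()
        with amp.autocast(enabled=True):
            loss = loss_func(model(inputs), labels)    # forward in mixed precision
        scaler.scale(loss).backward()                  # scaled loss -> scaled gradients
        scaler.unscale_(optimizer)                     # unscale before clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        scaler.step(optimizer)                         # skipped if gradients overflowed
        scaler.update()
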