Skip to content

Commit

Permalink
Merge pull request #1 from RUCAIBox/1.1.x
Browse files Browse the repository at this point in the history
1.1.x
  • Loading branch information
Ethan-TZ authored Jul 8, 2022
2 parents ee14ba3 + d5644c3 commit ebb138f
Show file tree
Hide file tree
Showing 26 changed files with 106 additions and 94 deletions.
2 changes: 2 additions & 0 deletions docs/source/user_guide/config/training_settings.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,5 @@ Training settings are designed to set parameters about model training.
- ``loss_decimal_place(int)``: The decimal place of training loss. Defaults to ``4``.
- ``weight_decay (float)`` : The weight decay (L2 penalty), used for `optimizer <https://pytorch.org/docs/stable/optim.html?highlight=weight_decay>`_. Default to ``0.0``.
- ``require_pow(bool)``: The sign identifies whether the power operation is performed based on the norm in EmbLoss. Defaults to ``False``.
- ``enable_amp(bool)``: The parameter determines whether to use mixed precision training . Defaults to ``False``.
- ``enable_scaler(bool)``: The parameter determines whether to use GradScaler that is often used with mixed precision training to avoid gradient precision overflow. Defaults to ``False``.
6 changes: 3 additions & 3 deletions recbole/model/context_aware_recommender/afm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, config, dataset):
self.p = nn.Parameter(torch.randn(self.embedding_size), requires_grad=True)
self.dropout_layer = nn.Dropout(p=self.dropout_prob)
self.sigmoid = nn.Sigmoid()
self.loss = nn.BCELoss()
self.loss = nn.BCEWithLogitsLoss()

# parameters initialization
self.apply(self._init_weights)
Expand Down Expand Up @@ -101,7 +101,7 @@ def afm_layer(self, infeature):
def forward(self, interaction):
afm_all_embeddings = self.concat_embed_input_fields(interaction) # [batch_size, num_field, embed_dim]

output = self.sigmoid(self.first_order_linear(interaction) + self.afm_layer(afm_all_embeddings))
output = self.first_order_linear(interaction) + self.afm_layer(afm_all_embeddings)
return output.squeeze(-1)

def calculate_loss(self, interaction):
Expand All @@ -112,4 +112,4 @@ def calculate_loss(self, interaction):
return self.loss(output, label) + l2_loss

def predict(self, interaction):
return self.forward(interaction)
return self.sigmoid(self.forward(interaction))
6 changes: 3 additions & 3 deletions recbole/model/context_aware_recommender/autoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, config, dataset):

self.dropout_layer = nn.Dropout(p=self.dropout_probs[2])
self.sigmoid = nn.Sigmoid()
self.loss = nn.BCELoss()
self.loss = nn.BCEWithLogitsLoss()

# parameters initialization
self.apply(self._init_weights)
Expand Down Expand Up @@ -97,12 +97,12 @@ def autoint_layer(self, infeature):
def forward(self, interaction):
autoint_all_embeddings = self.concat_embed_input_fields(interaction) # [batch_size, num_field, embed_dim]
output = self.first_order_linear(interaction) + self.autoint_layer(autoint_all_embeddings)
return self.sigmoid(output.squeeze(1))
return output.squeeze(1)

def calculate_loss(self, interaction):
label = interaction[self.LABEL]
output = self.forward(interaction)
return self.loss(output, label)

def predict(self, interaction):
return self.forward(interaction)
return self.sigmoid(self.forward(interaction))
6 changes: 3 additions & 3 deletions recbole/model/context_aware_recommender/dcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self, config, dataset):
self.predict_layer = nn.Linear(in_feature_num, 1)
self.reg_loss = RegLoss()
self.sigmoid = nn.Sigmoid()
self.loss = nn.BCELoss()
self.loss = nn.BCEWithLogitsLoss()

# parameters initialization
self.apply(self._init_weights)
Expand Down Expand Up @@ -108,7 +108,7 @@ def forward(self, interaction):
# Cross Network
cross_output = self.cross_network(dcn_all_embeddings)
stack = torch.cat([cross_output, deep_output], dim=-1)
output = self.sigmoid(self.predict_layer(stack))
output = self.predict_layer(stack)

return output.squeeze(1)

Expand All @@ -119,4 +119,4 @@ def calculate_loss(self, interaction):
return self.loss(output, label) + l2_loss

def predict(self, interaction):
return self.forward(interaction)
return self.sigmoid(self.forward(interaction))
6 changes: 3 additions & 3 deletions recbole/model/context_aware_recommender/deepfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, config, dataset):
self.mlp_layers = MLPLayers(size_list, self.dropout_prob)
self.deep_predict_layer = nn.Linear(self.mlp_hidden_size[-1], 1) # Linear product to the final score
self.sigmoid = nn.Sigmoid()
self.loss = nn.BCELoss()
self.loss = nn.BCEWithLogitsLoss()

# parameters initialization
self.apply(self._init_weights)
Expand All @@ -61,7 +61,7 @@ def forward(self, interaction):
y_fm = self.first_order_linear(interaction) + self.fm(deepfm_all_embeddings)

y_deep = self.deep_predict_layer(self.mlp_layers(deepfm_all_embeddings.view(batch_size, -1)))
y = self.sigmoid(y_fm + y_deep)
y = y_fm + y_deep
return y.squeeze(-1)

def calculate_loss(self, interaction):
Expand All @@ -70,4 +70,4 @@ def calculate_loss(self, interaction):
return self.loss(output, label)

def predict(self, interaction):
return self.forward(interaction)
return self.sigmoid(self.forward(interaction))
8 changes: 3 additions & 5 deletions recbole/model/context_aware_recommender/dssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, config, dataset):
self.user_mlp_layers = MLPLayers(user_size_list, self.dropout_prob, activation='tanh', bn=True)
self.item_mlp_layers = MLPLayers(item_size_list, self.dropout_prob, activation='tanh', bn=True)

self.loss = nn.BCELoss()
self.loss = nn.BCEWithLogitsLoss()
self.sigmoid = nn.Sigmoid()

# parameters initialization
Expand Down Expand Up @@ -84,14 +84,12 @@ def forward(self, interaction):
user_dnn_out = self.user_mlp_layers(embed_user.view(batch_size, -1))
item_dnn_out = self.item_mlp_layers(embed_item.view(batch_size, -1))
score = torch.cosine_similarity(user_dnn_out, item_dnn_out, dim=1)

sig_score = self.sigmoid(score)
return sig_score.squeeze(-1)
return score.squeeze(-1)

def calculate_loss(self, interaction):
label = interaction[self.LABEL]
output = self.forward(interaction)
return self.loss(output, label)

def predict(self, interaction):
return self.forward(interaction)
return self.sigmoid(self.forward(interaction))
6 changes: 3 additions & 3 deletions recbole/model/context_aware_recommender/ffm.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self, config, dataset):
self.feature_names, self.feature_dims, self.feature2id, self.feature2field, self.num_fields,
self.embedding_size, self.device
)
self.loss = nn.BCELoss()
self.loss = nn.BCEWithLogitsLoss()

# parameters initialization
self.apply(self._init_weights)
Expand Down Expand Up @@ -116,7 +116,7 @@ def get_ffm_input(self, interaction):
def forward(self, interaction):
ffm_input = self.get_ffm_input(interaction)
ffm_output = torch.sum(torch.sum(self.ffm(ffm_input), dim=1), dim=1, keepdim=True)
output = self.sigmoid(self.first_order_linear(interaction) + ffm_output)
output = self.first_order_linear(interaction) + ffm_output

return output.squeeze(-1)

Expand All @@ -127,7 +127,7 @@ def calculate_loss(self, interaction):
return self.loss(output, label)

def predict(self, interaction):
return self.forward(interaction)
return self.sigmoid(self.forward(interaction))


class FieldAwareFactorizationMachine(nn.Module):
Expand Down
6 changes: 3 additions & 3 deletions recbole/model/context_aware_recommender/fm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, config, dataset):
# define layers and loss
self.fm = BaseFactorizationMachine(reduce_sum=True)
self.sigmoid = nn.Sigmoid()
self.loss = nn.BCELoss()
self.loss = nn.BCEWithLogitsLoss()

# parameters initialization
self.apply(self._init_weights)
Expand All @@ -46,7 +46,7 @@ def _init_weights(self, module):

def forward(self, interaction):
fm_all_embeddings = self.concat_embed_input_fields(interaction) # [batch_size, num_field, embed_dim]
y = self.sigmoid(self.first_order_linear(interaction) + self.fm(fm_all_embeddings))
y = self.first_order_linear(interaction) + self.fm(fm_all_embeddings)
return y.squeeze(-1)

def calculate_loss(self, interaction):
Expand All @@ -56,4 +56,4 @@ def calculate_loss(self, interaction):
return self.loss(output, label)

def predict(self, interaction):
return self.forward(interaction)
return self.sigmoid(self.forward(interaction))
5 changes: 2 additions & 3 deletions recbole/model/context_aware_recommender/fnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, config, dataset):
self.predict_layer = nn.Linear(self.mlp_hidden_size[-1], 1, bias=True)

self.sigmoid = nn.Sigmoid()
self.loss = nn.BCELoss()
self.loss = nn.BCEWithLogitsLoss()

# parameters initialization
self.apply(self._init_weights)
Expand All @@ -62,7 +62,6 @@ def forward(self, interaction):
batch_size = fnn_all_embeddings.shape[0]

output = self.predict_layer(self.mlp_layers(fnn_all_embeddings.view(batch_size, -1)))
output = self.sigmoid(output)
return output.squeeze(-1)

def calculate_loss(self, interaction):
Expand All @@ -72,4 +71,4 @@ def calculate_loss(self, interaction):
return self.loss(output, label)

def predict(self, interaction):
return self.forward(interaction)
return self.sigmoid(self.forward(interaction))
6 changes: 3 additions & 3 deletions recbole/model/context_aware_recommender/fwfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, config, dataset):
self.num_fields = len(set(self.feature2field.values())) # the number of fields
self.num_pair = self.num_fields * self.num_fields

self.loss = nn.BCELoss()
self.loss = nn.BCEWithLogitsLoss()

# parameters initialization
self.apply(self._init_weights)
Expand Down Expand Up @@ -122,7 +122,7 @@ def fwfm_layer(self, infeature):
def forward(self, interaction):
fwfm_all_embeddings = self.concat_embed_input_fields(interaction) # [batch_size, num_field, embed_dim]

output = self.sigmoid(self.first_order_linear(interaction) + self.fwfm_layer(fwfm_all_embeddings))
output = self.first_order_linear(interaction) + self.fwfm_layer(fwfm_all_embeddings)

return output.squeeze(-1)

Expand All @@ -133,4 +133,4 @@ def calculate_loss(self, interaction):
return self.loss(output, label)

def predict(self, interaction):
return self.forward(interaction)
return self.sigmoid(self.forward(interaction))
6 changes: 3 additions & 3 deletions recbole/model/context_aware_recommender/lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, config, dataset):
super(LR, self).__init__(config, dataset)

self.sigmoid = nn.Sigmoid()
self.loss = nn.BCELoss()
self.loss = nn.BCEWithLogitsLoss()

# parameters initialization
self.apply(self._init_weights)
Expand All @@ -42,7 +42,7 @@ def _init_weights(self, module):
xavier_normal_(module.weight.data)

def forward(self, interaction):
output = self.sigmoid(self.first_order_linear(interaction))
output = self.first_order_linear(interaction)
return output.squeeze(-1)

def calculate_loss(self, interaction):
Expand All @@ -52,4 +52,4 @@ def calculate_loss(self, interaction):
return self.loss(output, label)

def predict(self, interaction):
return self.forward(interaction)
return self.sigmoid(self.forward(interaction))
5 changes: 2 additions & 3 deletions recbole/model/context_aware_recommender/nfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, config, dataset):
self.mlp_layers = MLPLayers(size_list, self.dropout_prob, activation='sigmoid', bn=True)
self.predict_layer = nn.Linear(self.mlp_hidden_size[-1], 1, bias=False)
self.sigmoid = nn.Sigmoid()
self.loss = nn.BCELoss()
self.loss = nn.BCEWithLogitsLoss()

# parameters initialization
self.apply(self._init_weights)
Expand All @@ -55,7 +55,6 @@ def forward(self, interaction):
bn_nfm_all_embeddings = self.bn(self.fm(nfm_all_embeddings))

output = self.predict_layer(self.mlp_layers(bn_nfm_all_embeddings)) + self.first_order_linear(interaction)
output = self.sigmoid(output)
return output.squeeze(-1)

def calculate_loss(self, interaction):
Expand All @@ -64,4 +63,4 @@ def calculate_loss(self, interaction):
return self.loss(output, label)

def predict(self, interaction):
return self.forward(interaction)
return self.sigmoid(self.forward(interaction))
5 changes: 2 additions & 3 deletions recbole/model/context_aware_recommender/pnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, config, dataset):
self.predict_layer = nn.Linear(self.mlp_hidden_size[-1], 1)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
self.loss = nn.BCELoss()
self.loss = nn.BCEWithLogitsLoss()

# parameters initialization
self.apply(self._init_weights)
Expand Down Expand Up @@ -98,7 +98,6 @@ def forward(self, interaction):
output = torch.cat(output, dim=1) # [batch_size,d]

output = self.predict_layer(self.mlp_layers(output)) # [batch_size,1]
output = self.sigmoid(output)
return output.squeeze(-1)

def calculate_loss(self, interaction):
Expand All @@ -108,7 +107,7 @@ def calculate_loss(self, interaction):
return self.loss(output, label) + self.reg_loss()

def predict(self, interaction):
return self.forward(interaction)
return self.sigmoid(self.forward(interaction))


class InnerProductLayer(nn.Module):
Expand Down
6 changes: 3 additions & 3 deletions recbole/model/context_aware_recommender/widedeep.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, config, dataset):
self.mlp_layers = MLPLayers(size_list, self.dropout_prob)
self.deep_predict_layer = nn.Linear(self.mlp_hidden_size[-1], 1)
self.sigmoid = nn.Sigmoid()
self.loss = nn.BCELoss()
self.loss = nn.BCEWithLogitsLoss()

# parameters initialization
self.apply(self._init_weights)
Expand All @@ -58,7 +58,7 @@ def forward(self, interaction):
fm_output = self.first_order_linear(interaction)

deep_output = self.deep_predict_layer(self.mlp_layers(widedeep_all_embeddings.view(batch_size, -1)))
output = self.sigmoid(fm_output + deep_output)
output = fm_output + deep_output
return output.squeeze(-1)

def calculate_loss(self, interaction):
Expand All @@ -67,4 +67,4 @@ def calculate_loss(self, interaction):
return self.loss(output, label)

def predict(self, interaction):
return self.forward(interaction)
return self.sigmoid(self.forward(interaction))
8 changes: 4 additions & 4 deletions recbole/model/context_aware_recommender/xdeepfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(self, config, dataset):

self.cin_linear = nn.Linear(self.final_len, 1)
self.sigmoid = nn.Sigmoid()
self.loss = nn.BCELoss()
self.loss = nn.BCEWithLogitsLoss()
self.apply(self._init_weights)

def _init_weights(self, module):
Expand Down Expand Up @@ -180,9 +180,9 @@ def forward(self, interaction):

# Get predicted score.
y_p = self.first_order_linear(interaction) + cin_output + dnn_output
y = self.sigmoid(y_p)

return y.squeeze(1)

return y_p.squeeze(1)

def calculate_loss(self, interaction):
label = interaction[self.LABEL]
Expand All @@ -191,4 +191,4 @@ def calculate_loss(self, interaction):
return self.loss(output, label) + self.reg_weight * l2_reg

def predict(self, interaction):
return self.forward(interaction)
return self.sigmoid(self.forward(interaction))
9 changes: 5 additions & 4 deletions recbole/model/general_recommender/cdae.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def forward(self, x_items, x_users):
h = torch.add(h_u, h_i)
h = self.h_act(h)
out = self.out_layer(h)
return self.o_act(out)
return out

def get_rating_matrix(self, user):
r"""Get a batch of user's feature with the user's id and history interaction matrix.
Expand All @@ -100,12 +100,12 @@ def calculate_loss(self, interaction):
predict = self.forward(x_items, x_users)

if self.loss_type == 'MSE':
predict=self.o_act(predict)
loss_func = nn.MSELoss(reduction='sum')
elif self.loss_type == 'BCE':
loss_func = nn.BCELoss(reduction='sum')
loss_func = nn.BCEWithLogitsLoss(reduction='sum')
else:
raise ValueError('Invalid loss_type, loss_type must in [MSE, BCE]')

loss = loss_func(predict, x_items)
# l1-regularization
loss += self.reg_weight_1 * (self.h_user.weight.norm(p=1) + self.h_item.weight.norm(p=1))
Expand All @@ -120,12 +120,13 @@ def predict(self, interaction):

items = self.get_rating_matrix(users)
scores = self.forward(items, users)

scores=self.o_act(scores)
return scores[[torch.arange(len(predict_items)).to(self.device), predict_items]]

def full_sort_predict(self, interaction):
users = interaction[self.USER_ID]

items = self.get_rating_matrix(users)
predict = self.forward(items, users)
predict=self.o_act(predict)
return predict.view(-1)
Loading

0 comments on commit ebb138f

Please sign in to comment.