diff --git a/recbole/model/general_recommender/ngcf.py b/recbole/model/general_recommender/ngcf.py index d315b85ba..aa26961cd 100644 --- a/recbole/model/general_recommender/ngcf.py +++ b/recbole/model/general_recommender/ngcf.py @@ -57,6 +57,7 @@ def __init__(self, config, dataset): self.sparse_dropout = SparseDropout(self.node_dropout) self.user_embedding = nn.Embedding(self.n_users, self.embedding_size) self.item_embedding = nn.Embedding(self.n_items, self.embedding_size) + self.emb_dropout = nn.Dropout(self.message_dropout) self.GNNlayers = torch.nn.ModuleList() for idx, (input_size, output_size) in enumerate( zip(self.hidden_size_list[:-1], self.hidden_size_list[1:]) @@ -157,7 +158,7 @@ def forward(self): for gnn in self.GNNlayers: all_embeddings = gnn(A_hat, self.eye_matrix, all_embeddings) all_embeddings = nn.LeakyReLU(negative_slope=0.2)(all_embeddings) - all_embeddings = nn.Dropout(self.message_dropout)(all_embeddings) + all_embeddings = self.emb_dropout(all_embeddings) all_embeddings = F.normalize(all_embeddings, p=2, dim=1) embeddings_list += [ all_embeddings diff --git a/recbole/model/general_recommender/simplex.py b/recbole/model/general_recommender/simplex.py index c47602591..9fa9e12c3 100644 --- a/recbole/model/general_recommender/simplex.py +++ b/recbole/model/general_recommender/simplex.py @@ -74,7 +74,7 @@ def __init__(self, config, dataset): if self.aggregator == "self_attention": self.W_q = nn.Linear(self.embedding_size, 1, bias=False) # dropout - self.dropout = nn.Dropout(0.1) + self.dropout_prob = nn.Dropout(config["dropout_prob"]) self.require_pow = config["require_pow"] # l2 regularization loss self.reg_loss = EmbLoss() diff --git a/recbole/model/knowledge_aware_recommender/mkr.py b/recbole/model/knowledge_aware_recommender/mkr.py index d032b09db..5a1fe4e98 100644 --- a/recbole/model/knowledge_aware_recommender/mkr.py +++ b/recbole/model/knowledge_aware_recommender/mkr.py @@ -73,7 +73,7 @@ def __init__(self, config, dataset): self.kge_pred_mlp = MLPLayers( [self.embedding_size * 2, self.embedding_size], self.dropout_prob, "sigmoid" ) - if self.use_inner_product == False: + if not self.use_inner_product: self.rs_pred_mlp = MLPLayers( [self.embedding_size * 2, 1], self.dropout_prob, "sigmoid" ) diff --git a/recbole/model/sequential_recommender/gru4reckg.py b/recbole/model/sequential_recommender/gru4reckg.py index 7aceb58bb..4e297cf13 100644 --- a/recbole/model/sequential_recommender/gru4reckg.py +++ b/recbole/model/sequential_recommender/gru4reckg.py @@ -47,6 +47,8 @@ def __init__(self, config, dataset): self.entity_embedding = nn.Embedding( self.n_items, self.embedding_size, padding_idx=0 ) + self.item_emb_dropout = nn.Dropout(self.dropout) + self.entity_emb_dropout = nn.Dropout(self.dropout) self.entity_embedding.weight.requires_grad = not self.freeze_kg self.item_gru_layers = nn.GRU( input_size=self.embedding_size, @@ -79,8 +81,8 @@ def __init__(self, config, dataset): def forward(self, item_seq, item_seq_len): item_emb = self.item_embedding(item_seq) entity_emb = self.entity_embedding(item_seq) - item_emb = nn.Dropout(self.dropout)(item_emb) - entity_emb = nn.Dropout(self.dropout)(entity_emb) + item_emb = self.item_emb_dropout(item_emb) + entity_emb = self.entity_emb_dropout(entity_emb) item_gru_output, _ = self.item_gru_layers(item_emb) # [B Len H] entity_gru_output, _ = self.entity_gru_layers(entity_emb) diff --git a/recbole/properties/model/SimpleX.yaml b/recbole/properties/model/SimpleX.yaml index c16360217..2f6ff7b75 100644 --- a/recbole/properties/model/SimpleX.yaml +++ b/recbole/properties/model/SimpleX.yaml @@ -4,4 +4,5 @@ negative_weight: 10 # (int) Weight to balance between positive-sampl gamma: 0.5 # (float) Weight for fusion of user' and interacted items' representations. aggregator: 'mean' # (str) The item aggregator ranging in ['mean', 'user_attention', 'self_attention']. history_len: 50 # (int) The length of the user's historical interaction items. -reg_weight: 1e-05 # (float) The L2 regularization weights. \ No newline at end of file +reg_weight: 1e-05 # (float) The L2 regularization weights. +dropout_prob: 0.1 # (float) Dropout probability for fusion of user' and interacted items' representations. \ No newline at end of file