From 758c19163639201b78a58f1c12f2d27e9de27b25 Mon Sep 17 00:00:00 2001 From: Andreas Peintner Date: Tue, 1 Oct 2024 09:30:09 +0200 Subject: [PATCH 1/2] init trimlp --- .../model/sequential_recommender/trimlp.py | 162 ++++++++++++++++++ recbole/properties/model/TriMLP.yaml | 5 + trimlp.yaml | 2 + 3 files changed, 169 insertions(+) create mode 100644 recbole/model/sequential_recommender/trimlp.py create mode 100644 recbole/properties/model/TriMLP.yaml create mode 100644 trimlp.yaml diff --git a/recbole/model/sequential_recommender/trimlp.py b/recbole/model/sequential_recommender/trimlp.py new file mode 100644 index 000000000..4f2021985 --- /dev/null +++ b/recbole/model/sequential_recommender/trimlp.py @@ -0,0 +1,162 @@ +# -*- coding: utf-8 -*- +# @Time : 2024/09/26 12:19 +# @Author : Andreas Peintner +# @Email : anpeintner@gmail.com + +r""" +TriMLP +################################################ + +Reference: + Jiang et al. "TriMLP: A Foundational MLP-like Architecture for Sequential Recommendation" in TOIS 2024. + +Reference code: + https://github.com/jiangyiheng1/TriMLP/ +""" + +import torch +from torch import nn +from torch.nn.init import xavier_uniform_, xavier_normal_, constant_ + +from recbole.model.abstract_recommender import SequentialRecommender +from recbole.model.loss import BPRLoss + +def global_kernel(seq_len): + mask = torch.triu(torch.ones([seq_len, seq_len])) + matrix = torch.ones([seq_len, seq_len]) + matrix = matrix.masked_fill(mask == 0.0, -1e9) + kernel = nn.parameter.Parameter(matrix, requires_grad=True) + return kernel + + +def local_kernel(seq_len, n_session): + mask = torch.zeros([seq_len, seq_len]) + for i in range(0, seq_len, seq_len // n_session): + mask[i:i + seq_len // n_session, i:i + seq_len // n_session] = torch.ones( + [seq_len // n_session, seq_len // n_session]) + mask = torch.triu(mask) + matrix = torch.ones([seq_len, seq_len]) + matrix = matrix.masked_fill(mask == 0.0, -1e9) + kernel = nn.parameter.Parameter(matrix, requires_grad=True) + return kernel + +class TriMixer(nn.Module): + def __init__(self, seq_len, n_session, act=nn.Sigmoid()): + super().__init__() + assert seq_len % n_session == 0 + self.l = seq_len + self.n_s = n_session + self.act = act + self.local_mixing = local_kernel(self.l, self.n_s) + self.global_mixing = global_kernel(self.l) + + def forward(self, x): + x = torch.matmul(x.permute(0, 2, 1), self.global_mixing.softmax(dim=-1)) + if self.act: + x = self.act(x) + + x = torch.matmul(x, self.local_mixing.softmax(dim=-1)).permute(0, 2, 1) + if self.act: + x = self.act(x) + + return x + + def extra_repr(self): + return f"seq_len={self.l}, n_session={self.n_s}, act={self.act}" + + +class TriMLP(SequentialRecommender): + r"""TriMLP: A Foundational MLP-like Architecture for Sequential Recommendation + """ + + def __init__(self, config, dataset): + super(TriMLP, self).__init__(config, dataset) + + # load parameters info + self.embedding_size = config["embedding_size"] + self.loss_type = config["loss_type"] + + if config["act_fct"] == "sigmoid": + self.act_fct = nn.Sigmoid() + elif config["act_fct"] == "tanh": + self.act_fct = nn.Tanh() + else: + self.act_fct = None + + self.dropout_prob = config["dropout_prob"] + + self.num_session = config["num_session"] + + # define layers and loss + self.item_embedding = nn.Embedding( + self.n_items, self.embedding_size, padding_idx=0 + ) + self.emb_dropout = nn.Dropout(self.dropout_prob) + self.mixer = TriMixer(self.max_seq_length, self.num_session, act=self.act_fct) + self.final_layer = 
nn.Linear(self.embedding_size, self.embedding_size) + + if self.loss_type == "BPR": + self.loss_fct = BPRLoss() + elif self.loss_type == "CE": + self.loss_fct = nn.CrossEntropyLoss() + else: + raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!") + + # parameters initialization + self.apply(self._init_weights) + + def _init_weights(self, module): + if isinstance(module, nn.Embedding): + xavier_normal_(module.weight.data) + elif isinstance(module, nn.Linear): + xavier_normal_(module.weight.data) + if module.bias is not None: + constant_(module.bias.data, 0) + + def forward(self, item_seq, item_seq_len): + item_seq_emb = self.item_embedding(item_seq) + item_seq_emb_dropout = self.emb_dropout(item_seq_emb) + + mixer_output = self.mixer(item_seq_emb_dropout) + seq_output = self.gather_indexes(mixer_output, item_seq_len - 1) + seq_output = self.final_layer(seq_output) + + return seq_output + + def calculate_loss(self, interaction): + item_seq = interaction[self.ITEM_SEQ] + item_seq_len = interaction[self.ITEM_SEQ_LEN] + seq_output = self.forward(item_seq, item_seq_len) + pos_items = interaction[self.POS_ITEM_ID] + if self.loss_type == "BPR": + neg_items = interaction[self.NEG_ITEM_ID] + pos_items_emb = self.item_embedding(pos_items) + neg_items_emb = self.item_embedding(neg_items) + pos_score = torch.sum(seq_output * pos_items_emb, dim=-1) # [B] + neg_score = torch.sum(seq_output * neg_items_emb, dim=-1) # [B] + loss = self.loss_fct(pos_score, neg_score) + return loss + else: # self.loss_type = 'CE' + test_item_emb = self.item_embedding.weight + logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1)) + loss = self.loss_fct(logits, pos_items) + return loss + + def predict(self, interaction): + item_seq = interaction[self.ITEM_SEQ] + item_seq_len = interaction[self.ITEM_SEQ_LEN] + test_item = interaction[self.ITEM_ID] + seq_output = self.forward(item_seq, item_seq_len) + test_item_emb = self.item_embedding(test_item) + scores = torch.mul(seq_output, test_item_emb).sum(dim=1) # [B] + return scores + + def full_sort_predict(self, interaction): + item_seq = interaction[self.ITEM_SEQ] + item_seq_len = interaction[self.ITEM_SEQ_LEN] + seq_output = self.forward(item_seq, item_seq_len) + test_items_emb = self.item_embedding.weight + scores = torch.matmul( + seq_output, test_items_emb.transpose(0, 1) + ) # [B, n_items] + return scores diff --git a/recbole/properties/model/TriMLP.yaml b/recbole/properties/model/TriMLP.yaml new file mode 100644 index 000000000..eeef94beb --- /dev/null +++ b/recbole/properties/model/TriMLP.yaml @@ -0,0 +1,5 @@ +embedding_size: 64 +act_fct: sigmoid # None or sigmoid or tanh +num_session: 2 +dropout_prob: 0.5 +loss_type: 'CE' \ No newline at end of file diff --git a/trimlp.yaml b/trimlp.yaml new file mode 100644 index 000000000..cb9eb5236 --- /dev/null +++ b/trimlp.yaml @@ -0,0 +1,2 @@ +train_neg_sample_args: +act_fct: None \ No newline at end of file From fee904fb9f1fd6667cb0179c64edb6c1579b1f47 Mon Sep 17 00:00:00 2001 From: Andreas Peintner Date: Tue, 1 Oct 2024 10:37:47 +0200 Subject: [PATCH 2/2] final layer --- .../model/sequential_recommender/trimlp.py | 52 ++++--------------- recbole/properties/model/TriMLP.yaml | 2 +- trimlp.yaml | 2 - 3 files changed, 10 insertions(+), 46 deletions(-) delete mode 100644 trimlp.yaml diff --git a/recbole/model/sequential_recommender/trimlp.py b/recbole/model/sequential_recommender/trimlp.py index 4f2021985..2ff58299b 100644 --- a/recbole/model/sequential_recommender/trimlp.py +++ 
b/recbole/model/sequential_recommender/trimlp.py @@ -16,10 +16,8 @@ import torch from torch import nn -from torch.nn.init import xavier_uniform_, xavier_normal_, constant_ from recbole.model.abstract_recommender import SequentialRecommender -from recbole.model.loss import BPRLoss def global_kernel(seq_len): mask = torch.triu(torch.ones([seq_len, seq_len])) @@ -84,7 +82,7 @@ def __init__(self, config, dataset): self.act_fct = None self.dropout_prob = config["dropout_prob"] - + self.final_softmax = config["final_softmax"] self.num_session = config["num_session"] # define layers and loss @@ -93,25 +91,9 @@ def __init__(self, config, dataset): ) self.emb_dropout = nn.Dropout(self.dropout_prob) self.mixer = TriMixer(self.max_seq_length, self.num_session, act=self.act_fct) - self.final_layer = nn.Linear(self.embedding_size, self.embedding_size) - - if self.loss_type == "BPR": - self.loss_fct = BPRLoss() - elif self.loss_type == "CE": - self.loss_fct = nn.CrossEntropyLoss() - else: - raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!") - - # parameters initialization - self.apply(self._init_weights) + self.final_layer = nn.Linear(self.embedding_size, self.n_items) - def _init_weights(self, module): - if isinstance(module, nn.Embedding): - xavier_normal_(module.weight.data) - elif isinstance(module, nn.Linear): - xavier_normal_(module.weight.data) - if module.bias is not None: - constant_(module.bias.data, 0) + self.loss_fct = nn.CrossEntropyLoss(ignore_index=0) def forward(self, item_seq, item_seq_len): item_seq_emb = self.item_embedding(item_seq) @@ -126,37 +108,21 @@ def forward(self, item_seq, item_seq_len): def calculate_loss(self, interaction): item_seq = interaction[self.ITEM_SEQ] item_seq_len = interaction[self.ITEM_SEQ_LEN] - seq_output = self.forward(item_seq, item_seq_len) + scores = self.forward(item_seq, item_seq_len) pos_items = interaction[self.POS_ITEM_ID] - if self.loss_type == "BPR": - neg_items = interaction[self.NEG_ITEM_ID] - pos_items_emb = self.item_embedding(pos_items) - neg_items_emb = self.item_embedding(neg_items) - pos_score = torch.sum(seq_output * pos_items_emb, dim=-1) # [B] - neg_score = torch.sum(seq_output * neg_items_emb, dim=-1) # [B] - loss = self.loss_fct(pos_score, neg_score) - return loss - else: # self.loss_type = 'CE' - test_item_emb = self.item_embedding.weight - logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1)) - loss = self.loss_fct(logits, pos_items) - return loss + loss = self.loss_fct(scores, pos_items) + return loss def predict(self, interaction): item_seq = interaction[self.ITEM_SEQ] item_seq_len = interaction[self.ITEM_SEQ_LEN] test_item = interaction[self.ITEM_ID] - seq_output = self.forward(item_seq, item_seq_len) - test_item_emb = self.item_embedding(test_item) - scores = torch.mul(seq_output, test_item_emb).sum(dim=1) # [B] + scores = self.forward(item_seq, item_seq_len).unsqueeze(-1) + scores = self.gather_indexes(scores, test_item).squeeze(-1) return scores def full_sort_predict(self, interaction): item_seq = interaction[self.ITEM_SEQ] item_seq_len = interaction[self.ITEM_SEQ_LEN] - seq_output = self.forward(item_seq, item_seq_len) - test_items_emb = self.item_embedding.weight - scores = torch.matmul( - seq_output, test_items_emb.transpose(0, 1) - ) # [B, n_items] + scores = self.forward(item_seq, item_seq_len) return scores diff --git a/recbole/properties/model/TriMLP.yaml b/recbole/properties/model/TriMLP.yaml index eeef94beb..4ccac72c9 100644 --- a/recbole/properties/model/TriMLP.yaml +++ 
b/recbole/properties/model/TriMLP.yaml @@ -1,5 +1,5 @@ embedding_size: 64 -act_fct: sigmoid # None or sigmoid or tanh +act_fct: None # None or sigmoid or tanh num_session: 2 dropout_prob: 0.5 loss_type: 'CE' \ No newline at end of file diff --git a/trimlp.yaml b/trimlp.yaml deleted file mode 100644 index cb9eb5236..000000000 --- a/trimlp.yaml +++ /dev/null @@ -1,2 +0,0 @@ -train_neg_sample_args: -act_fct: None \ No newline at end of file
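
Usage sketch (not part of the patch): a minimal way to smoke-test the new TriMixer module and launch an end-to-end run once both commits are applied. The dataset name (ml-100k, RecBole's bundled default) and the concrete hyperparameter values below are illustrative assumptions, not settings taken from this patch; run_recbole and MAX_ITEM_LIST_LENGTH are existing RecBole entry points, and RecBole's model lookup resolves model="TriMLP" from recbole/model/sequential_recommender/trimlp.py without any registry change.

    # Illustrative sketch only; ml-100k and the values below are assumptions.
    import torch

    from recbole.model.sequential_recommender.trimlp import TriMixer
    from recbole.quick_start import run_recbole

    # Standalone shape check of the mixer. seq_len must be divisible by
    # n_session (enforced by the assert in TriMixer.__init__).
    mixer = TriMixer(seq_len=50, n_session=2, act=torch.nn.Sigmoid())
    x = torch.randn(4, 50, 64)   # [batch, seq_len, embedding_size]
    out = mixer(x)               # global (causal) then local (session) mixing
    assert out.shape == (4, 50, 64)

    # End-to-end training through RecBole's quick-start entry point.
    # MAX_ITEM_LIST_LENGTH becomes the model's max_seq_length, so it has to
    # stay divisible by num_session.
    run_recbole(
        model="TriMLP",
        dataset="ml-100k",
        config_dict={
            "MAX_ITEM_LIST_LENGTH": 50,
            "num_session": 2,
            "train_neg_sample_args": None,  # full-softmax CE training
        },
    )

Disabling train_neg_sample_args mirrors the override in the (now deleted) trimlp.yaml and is needed because, after the second commit, the model always trains with full-softmax cross-entropy over all items.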