# -*- coding: utf-8 -*-
# @Time : 2020/8/17 19:38
# @Author : Yujie Lu
# @Email : [email protected]
# UPDATE:
# @Time : 2020/8/19, 2020/10/2
# @Author : Yupeng Hou, Yujie Lu
# @Email : [email protected], [email protected]
r"""
GRU4Rec
################################################
Reference:
Yong Kiam Tan et al. "Improved Recurrent Neural Networks for Session-based Recommendations." in DLRS 2016.
"""
import torch
from torch import nn
from torch.nn.init import xavier_uniform_, xavier_normal_
from recbole.model.abstract_recommender import SequentialRecommender
from recbole.model.loss import BPRLoss


class GRU4Rec(SequentialRecommender):
    r"""GRU4Rec is a model that applies RNNs to session-based recommendation.

    Note:
        Of the techniques proposed in the paper, we only implement the data
        augmentation, and the model directly outputs item embeddings, so that
        the way sequence representations are generated stays consistent with
        the other sequential models in this library.
    """

    def __init__(self, config, dataset):
        super(GRU4Rec, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.hidden_size = config['hidden_size']
        self.loss_type = config['loss_type']
        self.num_layers = config['num_layers']
        self.dropout_prob = config['dropout_prob']

        # define layers and loss
        self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
        self.emb_dropout = nn.Dropout(self.dropout_prob)
        self.gru_layers = nn.GRU(
            input_size=self.embedding_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            bias=False,
            batch_first=True,
        )
        self.dense = nn.Linear(self.hidden_size, self.embedding_size)
        if self.loss_type == 'BPR':
            self.loss_fct = BPRLoss()
        elif self.loss_type == 'CE':
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

        # parameters initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Embedding):
            xavier_normal_(module.weight)
        elif isinstance(module, nn.GRU):
            xavier_uniform_(module.weight_hh_l0)
            xavier_uniform_(module.weight_ih_l0)
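    # Note: the xavier initialization above only covers the first GRU layer
    # (weight_ih_l0 / weight_hh_l0); when num_layers > 1, the weights of the
    # higher layers keep PyTorch's default uniform initialization.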

    def forward(self, item_seq, item_seq_len):
        item_seq_emb = self.item_embedding(item_seq)
        item_seq_emb_dropout = self.emb_dropout(item_seq_emb)
        gru_output, _ = self.gru_layers(item_seq_emb_dropout)
        gru_output = self.dense(gru_output)
        # the embedding of the predicted item, shape of (batch_size, embedding_size)
        seq_output = self.gather_indexes(gru_output, item_seq_len - 1)
        return seq_output
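    # Illustration: gather_indexes (a helper inherited from
    # SequentialRecommender) selects, per sequence, the hidden state at the
    # last valid position. E.g. with item_seq_len = [3, 1] it stacks
    # gru_output[0, 2] and gru_output[1, 0] into a (batch_size,
    # embedding_size) tensor, ignoring the right-padding.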

    def calculate_loss(self, interaction):
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(item_seq, item_seq_len)
        pos_items = interaction[self.POS_ITEM_ID]
        if self.loss_type == 'BPR':
            neg_items = interaction[self.NEG_ITEM_ID]
            pos_items_emb = self.item_embedding(pos_items)
            neg_items_emb = self.item_embedding(neg_items)
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  # [B]
            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)  # [B]
            loss = self.loss_fct(pos_score, neg_score)
            return loss
        else:  # self.loss_type == 'CE'
            test_item_emb = self.item_embedding.weight
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
            loss = self.loss_fct(logits, pos_items)
            return loss
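    # Under 'BPR' the loss contrasts the dot-product score of the positive
    # item against one sampled negative per step; under 'CE' the logits span
    # the full item vocabulary, so no negative sampling is required.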

    def predict(self, interaction):
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        test_item = interaction[self.ITEM_ID]
        seq_output = self.forward(item_seq, item_seq_len)
        test_item_emb = self.item_embedding(test_item)
        scores = torch.mul(seq_output, test_item_emb).sum(dim=1)  # [B]
        return scores
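    # predict() scores only the candidate items given in the interaction
    # (used for sampled evaluation), whereas full_sort_predict() below ranks
    # the entire item set.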

    def full_sort_predict(self, interaction):
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(item_seq, item_seq_len)
        test_items_emb = self.item_embedding.weight
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  # [B, n_items]
        return scores
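

# ---------------------------------------------------------------------------
# Minimal standalone sketch (illustrative, not part of the RecBole API):
# it reproduces the forward computation of GRU4Rec above with plain PyTorch
# so the tensor shapes can be checked without building a Config/Dataset.
# The hyperparameters and toy sequences below are made up for demonstration.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    n_items, embedding_size, hidden_size, num_layers = 100, 64, 128, 1
    item_embedding = nn.Embedding(n_items, embedding_size, padding_idx=0)
    gru_layers = nn.GRU(
        input_size=embedding_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        bias=False,
        batch_first=True,
    )
    dense = nn.Linear(hidden_size, embedding_size)

    # two right-padded sessions of max length 5, with true lengths 3 and 5
    item_seq = torch.tensor([[5, 9, 2, 0, 0], [7, 1, 4, 3, 8]])
    item_seq_len = torch.tensor([3, 5])

    gru_output, _ = gru_layers(item_embedding(item_seq))  # (2, 5, hidden_size)
    gru_output = dense(gru_output)                        # (2, 5, embedding_size)

    # emulate gather_indexes: take the hidden state at each last valid step
    idx = (item_seq_len - 1).view(-1, 1, 1).expand(-1, 1, embedding_size)
    seq_output = gru_output.gather(dim=1, index=idx).squeeze(1)  # (2, embedding_size)

    # full-sort prediction: dot product against every item embedding
    scores = torch.matmul(seq_output, item_embedding.weight.transpose(0, 1))
    print(seq_output.shape, scores.shape)  # torch.Size([2, 64]) torch.Size([2, 100])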