-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathduorec.py
executable file
·302 lines (243 loc) · 13.2 KB
/
duorec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
# -*- coding: utf-8 -*-
# @Time : 2020/9/18 11:33
# @Author : Shuqing Bian
# @Email : [email protected]
"""
DuoRec
################################################
The sequence encoder in this module follows SASRec; see the references below.
Reference:
Wang-Cheng Kang et al. "Self-Attentive Sequential Recommendation." in ICDM 2018.
Reference:
https://github.com/kang205/SASRec
"""
import torch
from torch import nn
from recbole.model.abstract_recommender import SequentialRecommender
from recbole.model.layers import TransformerEncoder
from recbole.model.loss import BPRLoss
class DuoRec(SequentialRecommender):
r"""
DuoRec reuses the SASRec backbone (SASRec is the first sequential recommender based on the
self-attentive mechanism) and adds InfoNCE contrastive terms to the training loss.
NOTE:
In the SASRec author's implementation, the Point-Wise Feed-Forward Network (PFFN) is implemented
by a CNN with a 1x1 kernel. In this implementation, we follow the original BERT implementation
and use a Fully Connected Layer to implement the PFFN.
"""
def __init__(self, config, dataset):
    """Build the layers, losses and contrastive-learning settings of the model.

    Args:
        config: Configuration object indexed with string keys
            (``config['n_layers']`` and so on).
        dataset: Forwarded unchanged to the parent constructor.

    Raises:
        NotImplementedError: If ``config['loss_type']`` is neither ``'BPR'``
            nor ``'CE'``.
    """
    super(DuoRec, self).__init__(config, dataset)

    # Hyper-parameters that are stored under the same name they have in config.
    for key in (
        'n_layers',
        'n_heads',
        'hidden_size',  # same as embedding_size
        'inner_size',  # the dimensionality in feed-forward layer
        'hidden_dropout_prob',
        'attn_dropout_prob',
        'hidden_act',
        'layer_norm_eps',
        'lmd',
        'lmd_sem',
        'initializer_range',
        'loss_type',
    ):
        setattr(self, key, config[key])

    # Layers. The creation order is kept fixed because every constructor and the
    # final ``apply`` below draw from the global random generator.
    # NOTE(review): ``n_items`` and ``max_seq_length`` are not assigned in this
    # method; presumably the parent constructor provides them - confirm.
    self.item_embedding = nn.Embedding(self.n_items, self.hidden_size, padding_idx=0)
    self.position_embedding = nn.Embedding(self.max_seq_length, self.hidden_size)
    encoder_kwargs = dict(
        n_layers=self.n_layers,
        n_heads=self.n_heads,
        hidden_size=self.hidden_size,
        inner_size=self.inner_size,
        hidden_dropout_prob=self.hidden_dropout_prob,
        attn_dropout_prob=self.attn_dropout_prob,
        hidden_act=self.hidden_act,
        layer_norm_eps=self.layer_norm_eps,
    )
    self.trm_encoder = TransformerEncoder(**encoder_kwargs)
    self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
    self.dropout = nn.Dropout(self.hidden_dropout_prob)

    # Recommendation loss.
    if self.loss_type not in ('BPR', 'CE'):
        raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")
    self.loss_fct = BPRLoss() if self.loss_type == 'BPR' else nn.CrossEntropyLoss()

    # Contrastive-learning settings.
    self.ssl = config['contrast']
    self.tau = config['tau']
    self.sim = config['sim']
    self.batch_size = config['train_batch_size']
    # Negative-pair mask for a full training batch; kept as a plain attribute
    # (it is not registered as a buffer).
    self.mask_default = self.mask_correlated_samples(batch_size=self.batch_size)
    self.aug_nce_fct = nn.CrossEntropyLoss()
    self.sem_aug_nce_fct = nn.CrossEntropyLoss()

    # parameters initialization
    self.apply(self._init_weights)
def _init_weights(self, module):
    """Initialize the weights of a single sub-module in place.

    ``nn.LayerNorm`` gets zero bias and unit weight. ``nn.Linear`` and
    ``nn.Embedding`` weights are drawn from N(0, ``self.initializer_range``);
    a ``nn.Linear`` bias, when present, is zeroed. Other modules are untouched.
    """
    if isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
        return

    if not isinstance(module, (nn.Linear, nn.Embedding)):
        return

    # Slightly different from the TF version which uses truncated_normal for initialization
    # cf https://github.com/pytorch/pytorch/pull/5617
    module.weight.data.normal_(mean=0.0, std=self.initializer_range)

    has_bias = isinstance(module, nn.Linear) and module.bias is not None
    if has_bias:
        module.bias.data.zero_()
def truncated_normal_(self, tensor, mean=0, std=0.09):
    """Fill ``tensor`` in place with samples from a normal truncated at two std.

    Four standard-normal candidates are drawn per element and one that lies
    strictly inside (-2, 2) is kept. The result is then scaled by ``std`` and
    shifted by ``mean``, so every value ends up within ``mean +/- 2 * std``.

    Args:
        tensor: Tensor to overwrite.
        mean: Mean of the underlying normal distribution.
        std: Standard deviation of the underlying normal distribution.

    Returns:
        The same ``tensor`` object, modified in place.
    """
    with torch.no_grad():
        candidates = tensor.new_empty(tuple(tensor.shape) + (4,)).normal_()
        in_range = (candidates < 2) & (candidates > -2)
        # Index of an in-range candidate; the index is 0 when no candidate qualifies.
        chosen = in_range.int().max(-1, keepdim=True)[1]
        picked = candidates.gather(-1, chosen).squeeze(-1)
        # When all four candidates are out of range the pick above is out of
        # range too; clamp it so the truncation bound always holds.
        picked.clamp_(-2, 2)
        tensor.data.copy_(picked)
        tensor.data.mul_(std).add_(mean)
    return tensor
def get_attention_mask(self, item_seq):
    """Generate left-to-right uni-directional attention mask for multi-head attention.

    A query position may look at a key position only if the key is not padding
    (item id > 0) and does not lie to the right of the query. Visible pairs
    get 0, all other pairs get -10000, in the dtype of the model parameters.

    Returns:
        Additive mask that broadcasts to ``[B, 1, L, L]`` for a ``[B, L]`` input.
    """
    seq_len = item_seq.size(-1)
    param_dtype = next(self.parameters()).dtype  # fp16 compatibility

    # 1 for real items, 0 for padding, shaped to broadcast over heads and queries.
    key_is_item = (item_seq > 0).long().unsqueeze(1).unsqueeze(2)

    # Lower-triangular ones (diagonal included): position i sees positions <= i.
    causal = torch.ones((1, seq_len, seq_len)).tril().long()
    causal = causal.unsqueeze(1).to(item_seq.device)

    visible = (key_is_item * causal).to(dtype=param_dtype)
    return (1.0 - visible) * -10000.0
def get_bi_attention_mask(self, item_seq):
    """Generate bidirectional attention mask for multi-head attention.

    Only padding keys (item id 0) are hidden: they get -10000 while real items
    get 0, in the dtype of the model parameters.

    Returns:
        Additive mask of shape ``[B, 1, 1, L]`` for a ``[B, L]`` input.
    """
    param_dtype = next(self.parameters()).dtype  # fp16 compatibility
    key_is_item = item_seq.gt(0).long().unsqueeze(1).unsqueeze(2)
    key_is_item = key_is_item.to(dtype=param_dtype)
    return (1.0 - key_is_item) * -10000.0
def forward(self, item_seq, item_seq_len):
    """Encode each sequence and return the hidden state at its last position.

    Args:
        item_seq: Item-id tensor; position ids are generated along dimension 1.
        item_seq_len: Length of every sequence; ``item_seq_len - 1`` selects
            the position whose hidden state is returned.

    Returns:
        The result of ``gather_indexes`` on the last encoder layer's output
        (``[B H]`` according to the original annotation).
    """
    seq_len = item_seq.size(1)
    positions = torch.arange(seq_len, dtype=torch.long, device=item_seq.device)
    positions = positions.unsqueeze(0).expand_as(item_seq)

    # Sum of item and position embeddings, normalized and regularized.
    hidden = self.item_embedding(item_seq) + self.position_embedding(positions)
    hidden = self.dropout(self.LayerNorm(hidden))

    # Causal mask; get_bi_attention_mask would be the bidirectional alternative.
    causal_mask = self.get_attention_mask(item_seq)
    layer_outputs = self.trm_encoder(hidden, causal_mask, output_all_encoded_layers=True)
    last_layer = layer_outputs[-1]
    return self.gather_indexes(last_layer, item_seq_len - 1)  # [B H]
def calculate_loss(self, interaction):
    """Compute the training loss: recommendation loss plus contrastive terms.

    The recommendation loss is BPR or cross-entropy depending on
    ``self.loss_type``. ``self.ssl`` selects which InfoNCE terms are added:

    * ``'un'``   - unsupervised: two forward passes over the same sequence.
    * ``'su'``   - supervised: the sequence against ``interaction['sem_aug']``.
    * ``'us'``   - both of the above.
    * ``'us_x'`` - a second forward pass of the sequence against ``sem_aug``.

    Any other value adds no contrastive term.

    Args:
        interaction: Batch container indexed by field name.

    Returns:
        Scalar loss tensor.
    """
    item_seq = interaction[self.ITEM_SEQ]
    item_seq_len = interaction[self.ITEM_SEQ_LEN]
    seq_output = self.forward(item_seq, item_seq_len)
    pos_items = interaction[self.POS_ITEM_ID]
    batch_size = item_seq_len.shape[0]

    if self.loss_type == 'BPR':
        neg_items = interaction[self.NEG_ITEM_ID]
        pos_items_emb = self.item_embedding(pos_items)
        neg_items_emb = self.item_embedding(neg_items)
        pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  # [B]
        neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)  # [B]
        loss = self.loss_fct(pos_score, neg_score)
    else:  # self.loss_type = 'CE'
        test_item_emb = self.item_embedding.weight
        logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
        loss = self.loss_fct(logits, pos_items)

    # Unsupervised NCE: a second forward pass over the same input is the
    # augmented view.
    if self.ssl in ['us', 'un']:
        aug_seq_output = self.forward(item_seq, item_seq_len)
        nce_logits, nce_labels = self.info_nce(
            seq_output, aug_seq_output, temp=self.tau, batch_size=batch_size, sim=self.sim
        )
        loss = loss + self.lmd * self.aug_nce_fct(nce_logits, nce_labels)

    # Supervised NCE: the semantic augmentation shipped with the batch is the
    # positive view. It uses the dedicated semantic criterion.
    if self.ssl in ['us', 'su']:
        sem_aug, sem_aug_lengths = interaction['sem_aug'], interaction['sem_aug_lengths']
        sem_aug_seq_output = self.forward(sem_aug, sem_aug_lengths)
        sem_nce_logits, sem_nce_labels = self.info_nce(
            seq_output, sem_aug_seq_output, temp=self.tau, batch_size=batch_size, sim=self.sim
        )
        loss = loss + self.lmd_sem * self.sem_aug_nce_fct(sem_nce_logits, sem_nce_labels)

    # Cross variant: contrast a fresh forward pass with the semantic view.
    if self.ssl == 'us_x':
        aug_seq_output = self.forward(item_seq, item_seq_len)
        sem_aug, sem_aug_lengths = interaction['sem_aug'], interaction['sem_aug_lengths']
        sem_aug_seq_output = self.forward(sem_aug, sem_aug_lengths)
        sem_nce_logits, sem_nce_labels = self.info_nce(
            aug_seq_output, sem_aug_seq_output, temp=self.tau, batch_size=batch_size, sim=self.sim
        )
        loss = loss + self.lmd_sem * self.sem_aug_nce_fct(sem_nce_logits, sem_nce_labels)

    return loss
def mask_correlated_samples(self, batch_size):
    """Build the boolean mask that selects negative pairs for InfoNCE.

    For ``N = 2 * batch_size`` stacked views, entry ``(r, c)`` is ``False``
    when ``r == c`` (self-similarity) or when ``r`` and ``c`` differ by
    exactly ``batch_size`` (the two views of one sample); every other entry
    is ``True``.

    Args:
        batch_size: Number of samples per view.

    Returns:
        Boolean tensor of shape ``[2 * batch_size, 2 * batch_size]``.
    """
    n_views = 2 * batch_size
    mask = torch.ones((n_views, n_views), dtype=torch.bool)
    mask.fill_diagonal_(False)
    # Clear both positive-pair off-diagonals with a single indexed write each
    # instead of a Python loop over the batch.
    first_view = torch.arange(batch_size)
    second_view = first_view + batch_size
    mask[first_view, second_view] = False
    mask[second_view, first_view] = False
    return mask
def info_nce(self, z_i, z_j, temp, batch_size, sim='dot'):
    """Build InfoNCE logits and labels for two views of one batch.

    We do not sample negative examples explicitly. Instead, given a positive
    pair, similar to (Chen et al., 2017), we treat the other 2(N - 1)
    augmented examples within a minibatch as negative examples.

    Args:
        z_i: First view, one row per sample (``batch_size`` rows).
        z_j: Second view, row-aligned with ``z_i``.
        temp: Temperature that divides every similarity.
        batch_size: Number of rows in each view.
        sim: ``'dot'`` for inner product or ``'cos'`` for cosine similarity.

    Returns:
        ``(logits, labels)``. ``logits`` has ``2 * batch_size`` rows with the
        positive similarity in column 0 followed by the negatives; ``labels``
        is an all-zero long tensor, so the positive is the target class.

    Raises:
        ValueError: If ``sim`` is neither ``'cos'`` nor ``'dot'``.
    """
    n_views = 2 * batch_size
    z = torch.cat((z_i, z_j), dim=0)

    if sim == 'cos':
        # Normalizing the rows first yields the same cosine matrix as a
        # broadcast cosine_similarity, without building a [2B, 2B, H]
        # intermediate tensor.
        z_unit = nn.functional.normalize(z, dim=1)
        sim_matrix = torch.mm(z_unit, z_unit.t()) / temp
    elif sim == 'dot':
        sim_matrix = torch.mm(z, z.t()) / temp
    else:
        raise ValueError("Make sure 'sim' in ['cos', 'dot']!")

    # The +/- batch_size diagonals hold the similarities of the paired views.
    sim_i_j = torch.diag(sim_matrix, batch_size)
    sim_j_i = torch.diag(sim_matrix, -batch_size)
    positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(n_views, 1)

    # Reuse the cached mask unless this batch has a different size.
    if batch_size != self.batch_size:
        mask = self.mask_correlated_samples(batch_size)
    else:
        mask = self.mask_default
    mask = mask.to(sim_matrix.device)
    negative_samples = sim_matrix[mask].reshape(n_views, -1)

    labels = torch.zeros(n_views, dtype=torch.long, device=positive_samples.device)
    logits = torch.cat((positive_samples, negative_samples), dim=1)
    return logits, labels
def decompose(self, z_i, z_j, origin_z, batch_size):
    """Compute alignment and uniformity statistics from pairwise L2 distances.

    Args:
        z_i: First view, ``batch_size`` rows.
        z_j: Second view, row-aligned with ``z_i``.
        origin_z: Representations used for the uniformity term, ``batch_size`` rows.
        batch_size: Number of rows in each input.

    Returns:
        ``(alignment, uniformity)``: the mean L2 distance between paired rows
        of ``z_i`` and ``z_j``, and the log of the mean of ``exp(-2 * d)``
        over the distances ``d`` between distinct rows of ``origin_z``.
    """
    n_views = 2 * batch_size

    # Alignment: pairwise L2 distance between the stacked views; the
    # +/- batch_size diagonals hold the distances of the paired rows.
    stacked = torch.cat((z_i, z_j), dim=0)
    view_dist = torch.cdist(stacked, stacked, p=2)
    paired = torch.cat(
        (torch.diag(view_dist, batch_size), torch.diag(view_dist, -batch_size)), dim=0
    ).reshape(n_views, 1)
    alignment = paired.mean()

    # Uniformity: pairwise L2 distance inside origin_z, self-distances excluded.
    origin_dist = torch.cdist(origin_z, origin_z, p=2)
    off_diagonal = torch.ones((batch_size, batch_size), dtype=torch.bool)
    off_diagonal.fill_diagonal_(False)
    others = origin_dist[off_diagonal].reshape(batch_size, -1)
    uniformity = torch.exp(-2 * others).mean().log()

    return alignment, uniformity
def predict(self, interaction):
    """Score one candidate item per sequence.

    Args:
        interaction: Batch container holding the item sequences, their
            lengths and the candidate item ids.

    Returns:
        Inner product between every sequence representation and the
        embedding of its candidate item ([B]).
    """
    item_seq = interaction[self.ITEM_SEQ]
    item_seq_len = interaction[self.ITEM_SEQ_LEN]
    candidates = interaction[self.ITEM_ID]

    seq_repr = self.forward(item_seq, item_seq_len)
    candidate_emb = self.item_embedding(candidates)
    return (seq_repr * candidate_emb).sum(dim=1)  # [B]
def full_sort_predict(self, interaction):
    """Score every row of the item-embedding table for each sequence.

    Args:
        interaction: Batch container holding the item sequences and their lengths.

    Returns:
        Matrix of inner products between the sequence representations and
        all item embeddings ([B n_items]).
    """
    seq_repr = self.forward(interaction[self.ITEM_SEQ], interaction[self.ITEM_SEQ_LEN])
    all_item_emb = self.item_embedding.weight
    return torch.matmul(seq_repr, all_item_emb.transpose(0, 1))  # [B n_items]