# @Time   : 2022/4/6
# @Author : Zeyu Zhang
# @Email  : [email protected]
"""
recbole.MetaModule.model.MetaEmb
##########################
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy

from recbole.utils import InputType, FeatureSource
from recbole_metarec.MetaRecommender import MetaRecommender
from recbole_metarec.MetaUtils import GradCollector, EmbeddingTable


class ModelRec(nn.Module):
    """Base CTR model f: ID embedding + user/item feature embeddings -> MLP -> 2-class softmax."""

    def __init__(self, indexEmbDim, embeddingSize, dataset, hiddenDim):
        super(ModelRec, self).__init__()
        self.userEmbedding = EmbeddingTable(embeddingSize, dataset, source=[FeatureSource.USER])
        self.itemEmbedding = EmbeddingTable(embeddingSize, dataset, source=[FeatureSource.ITEM])
        self.hiddenLayer = nn.Linear(indexEmbDim + self.userEmbedding.getAllDim() + self.itemEmbedding.getAllDim(), hiddenDim)
        self.outputLayer = nn.Linear(hiddenDim, 2)

    def forward(self, indexEmb, userFeatures, itemFeatures):
        # Concatenate the ID embedding with all user- and item-side feature embeddings.
        input_x = torch.cat([indexEmb, self.userEmbedding.embeddingAllFields(userFeatures),
                             self.itemEmbedding.embeddingAllFields(itemFeatures)], dim=1)
        return F.softmax(self.outputLayer(F.relu(self.hiddenLayer(input_x))), dim=1)


class PreTrainModel(nn.Module):
    """Pretraining wrapper: a per-user ID embedding table plus the base model f."""

    def __init__(self, config, dataset):
        super(PreTrainModel, self).__init__()
        self.embedding_size = config['embedding']
        self.indexEmbDim = config['indexEmbDim']
        # One trainable ID embedding per user, learned during the pretraining phase.
        self.indexEmbedding = nn.Embedding(dataset.num(config['USER_ID_FIELD']), self.indexEmbDim)
        self.f = ModelRec(self.indexEmbDim, self.embedding_size, dataset, config['modelRecHiddenDim'])


class EmbeddingGenerator(nn.Module):
    """Meta-embedding generator: maps a user's side features to an initial ID embedding."""

    def __init__(self, userEmbedding, embeddingDim, hiddenDim, indexEmbDim):
        super(EmbeddingGenerator, self).__init__()
        # A copy of the pretrained user feature embedding table; only the MLP below is meta-trained.
        self.userEmbedding = deepcopy(userEmbedding)
        self.mlp = nn.Sequential(
            nn.Linear(self.userEmbedding.getAllDim(), hiddenDim),
            nn.ReLU(),
            nn.Linear(hiddenDim, indexEmbDim)
        )

    def forward(self, userFeatures):
        return self.mlp(self.userEmbedding.embeddingAllFields(userFeatures))


class MetaEmb(MetaRecommender):
    '''
    This is the recommender implementation of MetaEmb.

    Pan F, Li S, Ao X, et al. Warm Up Cold-Start Advertisements: Improving CTR Predictions via Learning to Learn ID Embeddings[C]//
    Proceedings of the 42nd International ACM SIGIR Conference on Research and Development in Information Retrieval. 2019: 695-704.
    https://doi.org/10.1145/3331184.3331268
    '''
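    # Training proceeds in two phases (see the paper above):
    #   1. pretrain(): learn the base model f and the per-user ID embeddings
    #      on existing (warm) users.
    #   2. calculate_loss(): meta-train the embedding generator so that its
    #      output phi both predicts well directly (cold/support loss) and
    #      adapts well after one gradient step (warm/query loss), combined
    #      as alpha * spt_loss + (1 - alpha) * qrt_loss; only the generator
    #      MLP receives meta-gradients here.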
    input_type = InputType.POINTWISE

    def __init__(self, config, dataset):
        super(MetaEmb, self).__init__(config, dataset)

        self.device = self.config.final_config_dict['device']
        self.embedding_size = self.config['embedding']
        self.indexEmbDim = self.config['indexEmbDim']
        self.embeddingGeneratorHiddenDim = self.config['embeddingGeneratorHiddenDim']
        self.localLr = self.config['localLr']  # learning rate of the local (warm-up) step
        self.alpha = self.config['alpha']      # weight between cold (support) and warm (query) losses

        self.pretrainModel = PreTrainModel(config, dataset)
        self.pretrainOpt = torch.optim.SGD(self.pretrainModel.parameters(), lr=self.config['pretrainLr'])
        self.embeddingGenerator = EmbeddingGenerator(self.pretrainModel.f.userEmbedding, self.embedding_size,
                                                     self.embeddingGeneratorHiddenDim, self.indexEmbDim)
        # Accumulates per-task gradients of the generator MLP across a task batch.
        self.metaGradCollector = GradCollector(list(self.embeddingGenerator.mlp.state_dict().keys()))
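
        # The config keys consumed above, as a sketch of the expected YAML.
        # The values here are illustrative assumptions, not the project's defaults:
        #
        #   embedding: 32
        #   indexEmbDim: 16
        #   embeddingGeneratorHiddenDim: 64
        #   localLr: 0.001
        #   alpha: 0.5
        #   pretrainLr: 0.01
        #   modelRecHiddenDim: 32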

    def pretrain(self, taskBatch):
        '''
        Phase 1: jointly train the per-user ID embeddings and the base model f
        on both the support and query sets of each task.
        '''
        for task in taskBatch:
            (spt_x_userid, spt_x_user, spt_x_item), spt_y, (qrt_x_userid, qrt_x_user, qrt_x_item), qrt_y = task

            spt_x_indexEmbedding = self.pretrainModel.indexEmbedding(spt_x_userid)
            predict_spt_y = self.pretrainModel.f(spt_x_indexEmbedding, spt_x_user, spt_x_item)
            spt_y = spt_y - 1  # shift labels down by one so they are valid class indices
            spt_loss = F.cross_entropy(predict_spt_y, spt_y)

            qrt_x_indexEmbedding = self.pretrainModel.indexEmbedding(qrt_x_userid)
            predict_qrt_y = self.pretrainModel.f(qrt_x_indexEmbedding, qrt_x_user, qrt_x_item)
            qrt_y = qrt_y - 1
            qrt_loss = F.cross_entropy(predict_qrt_y, qrt_y)

            loss = spt_loss + qrt_loss
            self.pretrainOpt.zero_grad()
            loss.backward()
            self.pretrainOpt.step()

    def forward(self, spt_x, spt_y, qrt_x):
        (spt_x_userid, spt_x_user, spt_x_item), spt_y, (qrt_x_userid, qrt_x_user, qrt_x_item) = spt_x, spt_y, qrt_x

        # Cold phase: generate the initial ID embedding phi from the user's side features.
        phi_init = self.embeddingGenerator(spt_x_user)
        predict_spt_y = self.pretrainModel.f(phi_init, spt_x_user, spt_x_item)
        spt_y = spt_y - 1
        spt_loss = F.cross_entropy(predict_spt_y, spt_y)

        # Warm phase: one gradient step on the support loss, averaged over the support set.
        grad = torch.autograd.grad(spt_loss, phi_init)
        avgGrad = torch.sum(grad[0], dim=0) / grad[0].shape[0]
        # Each task holds a single user, so row 0 of phi_init is that user's embedding;
        # adding a zero tensor broadcasts the updated embedding to the query batch size.
        phi_prime = (phi_init[0] - self.localLr * avgGrad) + torch.zeros(size=(qrt_x_userid.shape[0], avgGrad.shape[0])).to(self.device)
        predict_qrt_y = self.pretrainModel.f(phi_prime, qrt_x_user, qrt_x_item)
        return predict_qrt_y
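
    # In the paper's notation, forward() performs the one-step warm-up
    #     phi' = phi - localLr * (1 / |S|) * sum_{i in S} d(spt_loss) / d(phi_i)
    # over the task's support set S, then reuses phi' for every query row.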

    def calculate_loss(self, taskBatch):
        totalLoss = torch.tensor(0.0).to(self.device)
        for task in taskBatch:
            (spt_x_userid, spt_x_user, spt_x_item), spt_y, (qrt_x_userid, qrt_x_user, qrt_x_item), qrt_y = task

            # Cold loss: support-set prediction with the generated embedding.
            phi_init = self.embeddingGenerator(spt_x_user)
            predict_spt_y = self.pretrainModel.f(phi_init, spt_x_user, spt_x_item)
            spt_y = spt_y - 1
            spt_loss = F.cross_entropy(predict_spt_y, spt_y)

            # create_graph keeps the local update differentiable, so the meta-gradient
            # can flow back through it into the generator.
            grad = torch.autograd.grad(spt_loss, phi_init, create_graph=True, retain_graph=True)
            avgGrad = torch.sum(grad[0].to(self.device), dim=0) / grad[0].shape[0]
            phi_prime = (phi_init[0] - self.localLr * avgGrad) + torch.zeros(size=(qrt_x_userid.shape[0], avgGrad.shape[0])).to(self.device)

            # Warm loss: query-set prediction with the locally updated embedding.
            predict_qrt_y = self.pretrainModel.f(phi_prime, qrt_x_user, qrt_x_item)
            qrt_y = qrt_y - 1
            qrt_loss = F.cross_entropy(predict_qrt_y, qrt_y)

            # Meta objective: alpha-weighted mix of cold and warm losses; only the
            # generator MLP receives the resulting gradients.
            loss = self.alpha * spt_loss + (1 - self.alpha) * qrt_loss
            grad = torch.autograd.grad(loss, self.embeddingGenerator.mlp.parameters())
            self.metaGradCollector.addGrad(grad)
            totalLoss += loss.detach()

        self.metaGradCollector.averageGrad(self.config['train_batch_size'])
        totalLoss /= self.config['train_batch_size']
        return totalLoss, self.metaGradCollector.dumpGrad()

    def predict(self, spt_x, spt_y, qrt_x):
        # Score each query interaction by the predicted probability of the positive class.
        predict_qrt_y = self.forward(spt_x, spt_y, qrt_x)[:, 1]
        return predict_qrt_y
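

# ----------------------------------------------------------------------
# A minimal, self-contained sketch of the local update used in forward()
# and calculate_loss(), with plain tensors standing in for the RecBole
# batches. The shapes, the quadratic stand-in loss, and the use of
# .expand() instead of the `+ torch.zeros(...)` broadcast are
# illustrative assumptions, not part of the model.
if __name__ == '__main__':
    torch.manual_seed(0)
    indexEmbDim, sptSize, qrtSize, localLr = 8, 4, 6, 0.1

    # phi_init plays the role of the generated cold-start ID embedding.
    phi_init = torch.randn(sptSize, indexEmbDim, requires_grad=True)
    spt_loss = phi_init.pow(2).mean()  # stand-in for the support cross-entropy

    # One gradient step, averaged over the support set, then broadcast
    # to the query batch size.
    grad = torch.autograd.grad(spt_loss, phi_init)[0]
    avgGrad = grad.sum(dim=0) / grad.shape[0]
    phi_prime = (phi_init[0] - localLr * avgGrad).expand(qrtSize, -1)
    print(phi_prime.shape)  # torch.Size([6, 8])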