# -*- coding:utf-8 -*-
"""
Author:
Yuef Zhang
Reference:
[1] Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068. (https://arxiv.org/pdf/1706.06978.pdf)
"""
from .basemodel import BaseModel
from ..inputs import *
from ..layers import *
from ..layers.sequence import AttentionSequencePoolingLayer
class DIN(BaseModel):
"""Instantiates the Deep Interest Network architecture.
:param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
:param history_feature_list: list,to indicate sequence sparse field
:param dnn_use_bn: bool. Whether use BatchNormalization before activation or not in deep net
:param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of deep net
:param dnn_activation: Activation function to use in deep net
:param att_hidden_size: list,list of positive integer , the layer number and units in each layer of attention net
:param att_activation: Activation function to use in attention net
:param att_weight_normalization: bool. Whether normalize the attention score of local activation unit.
:param l2_reg_dnn: float. L2 regularizer strength applied to DNN
:param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
:param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
:param init_std: float,to use as the initialize std of embedding vector
:param seed: integer ,to use as random seed.
:param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
:param device: str, ``"cpu"`` or ``"cuda:0"``
:param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`.
:return: A PyTorch model instance.
"""

    def __init__(self, dnn_feature_columns, history_feature_list, dnn_use_bn=False,
                 dnn_hidden_units=(256, 128), dnn_activation='relu', att_hidden_size=(64, 16),
                 att_activation='Dice', att_weight_normalization=False, l2_reg_dnn=0.0,
                 l2_reg_embedding=1e-6, dnn_dropout=0, init_std=0.0001,
                 seed=1024, task='binary', device='cpu', gpus=None):
        super(DIN, self).__init__([], dnn_feature_columns, l2_reg_linear=0, l2_reg_embedding=l2_reg_embedding,
                                  init_std=init_std, seed=seed, task=task, device=device, gpus=gpus)
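
        # Split the DNN feature columns into plain sparse features and
        # variable-length (sequence) sparse features.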
        self.sparse_feature_columns = list(
            filter(lambda x: isinstance(x, SparseFeat), dnn_feature_columns)) if dnn_feature_columns else []
        self.varlen_sparse_feature_columns = list(
            filter(lambda x: isinstance(x, VarLenSparseFeat), dnn_feature_columns)) if dnn_feature_columns else []

        self.history_feature_list = history_feature_list

        self.history_feature_columns = []
        self.sparse_varlen_feature_columns = []
        self.history_fc_names = list(map(lambda x: "hist_" + x, history_feature_list))

        for fc in self.varlen_sparse_feature_columns:
            feature_name = fc.name
            if feature_name in self.history_fc_names:
                self.history_feature_columns.append(fc)
            else:
                self.sparse_varlen_feature_columns.append(fc)
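
        # Local activation unit: attends the candidate (query) embeddings over the
        # user's behavior sequence (keys); att_emb_dim is the total embedding
        # dimension of the query-side behavior features.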
        att_emb_dim = self._compute_interest_dim()

        self.attention = AttentionSequencePoolingLayer(att_hidden_units=att_hidden_size,
                                                       embedding_dim=att_emb_dim,
                                                       att_activation=att_activation,
                                                       return_score=False,
                                                       supports_masking=False,
                                                       weight_normalization=att_weight_normalization)
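
        # Deep part: an MLP over the concatenated feature embeddings
        # (including the attention-pooled interest vector) and dense features.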
        self.dnn = DNN(inputs_dim=self.compute_input_dim(dnn_feature_columns),
                       hidden_units=dnn_hidden_units,
                       activation=dnn_activation,
                       dropout_rate=dnn_dropout,
                       l2_reg=l2_reg_dnn,
                       use_bn=dnn_use_bn)
        self.dnn_linear = nn.Linear(dnn_hidden_units[-1], 1, bias=False).to(device)

        self.to(device)

    def forward(self, X):
        _, dense_value_list = self.input_from_feature_columns(X, self.dnn_feature_columns, self.embedding_dict)

        # sequence pooling part
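        # query: embeddings of the candidate fields in history_feature_list;
        # keys: embeddings of the matching "hist_*" behavior sequences.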
        query_emb_list = embedding_lookup(X, self.embedding_dict, self.feature_index, self.sparse_feature_columns,
                                          return_feat_list=self.history_feature_list, to_list=True)
        keys_emb_list = embedding_lookup(X, self.embedding_dict, self.feature_index, self.history_feature_columns,
                                         return_feat_list=self.history_fc_names, to_list=True)
        dnn_input_emb_list = embedding_lookup(X, self.embedding_dict, self.feature_index, self.sparse_feature_columns,
                                              to_list=True)

        sequence_embed_dict = varlen_embedding_lookup(X, self.embedding_dict, self.feature_index,
                                                      self.sparse_varlen_feature_columns)
        sequence_embed_list = get_varlen_pooling_list(sequence_embed_dict, X, self.feature_index,
                                                      self.sparse_varlen_feature_columns, self.device)

        dnn_input_emb_list += sequence_embed_list
        deep_input_emb = torch.cat(dnn_input_emb_list, dim=-1)

        # concatenate
        query_emb = torch.cat(query_emb_list, dim=-1)  # [B, 1, E]
        keys_emb = torch.cat(keys_emb_list, dim=-1)    # [B, T, E]
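
        # The true behavior-sequence lengths are used to mask padded positions
        # inside the attention layer.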
        keys_length_feature_name = [feat.length_name for feat in self.varlen_sparse_feature_columns if
                                    feat.length_name is not None]
        keys_length = torch.squeeze(maxlen_lookup(X, self.feature_index, keys_length_feature_name), 1)  # [B, 1]

        hist = self.attention(query_emb, keys_emb, keys_length)  # [B, 1, E]

        # deep part
        deep_input_emb = torch.cat((deep_input_emb, hist), dim=-1)
        deep_input_emb = deep_input_emb.view(deep_input_emb.size(0), -1)

        dnn_input = combined_dnn_input([deep_input_emb], dense_value_list)
        dnn_output = self.dnn(dnn_input)
        dnn_logit = self.dnn_linear(dnn_output)

        y_pred = self.out(dnn_logit)

        return y_pred

    def _compute_interest_dim(self):
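        # Total embedding dimension of the query-side behavior features; this is
        # the embedding size expected by the attention layer.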
        interest_dim = 0
        for feat in self.sparse_feature_columns:
            if feat.name in self.history_feature_list:
                interest_dim += feat.embedding_dim
        return interest_dim


if __name__ == '__main__':
    pass
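
    # ----------------------------------------------------------------------
    # Minimal usage sketch (an illustrative addition, not part of the original
    # training code). It assumes the SparseFeat / DenseFeat / VarLenSparseFeat /
    # get_feature_names helpers exported by ``..inputs`` and the compile()/fit()
    # methods of BaseModel; the field names ('user', 'item', 'hist_item',
    # 'seq_length') and the toy data are made-up examples only.
    import numpy as np

    feature_columns = [
        SparseFeat('user', vocabulary_size=3, embedding_dim=8),
        SparseFeat('item', vocabulary_size=4, embedding_dim=8),
        DenseFeat('score', 1),
        # 'hist_item' is the behavior sequence of 'item'; 'seq_length' holds its true length.
        VarLenSparseFeat(SparseFeat('hist_item', vocabulary_size=4, embedding_dim=8),
                         maxlen=4, length_name='seq_length'),
    ]
    behavior_feature_list = ['item']

    feature_dict = {
        'user': np.array([0, 1, 2]),
        'item': np.array([1, 2, 3]),  # 0 is reserved as the padding value
        'score': np.array([0.1, 0.2, 0.3]),
        'hist_item': np.array([[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 0]]),
        'seq_length': np.array([3, 3, 2]),
    }
    x = {name: feature_dict[name] for name in get_feature_names(feature_columns)}
    y = np.array([1, 0, 1])

    model = DIN(feature_columns, behavior_feature_list, device='cpu')
    model.compile('adagrad', 'binary_crossentropy', metrics=['auc'])
    model.fit(x, y, batch_size=3, epochs=1, verbose=1)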