Skip to content

Commit

Permalink
[feat]:support only sequence feature && fix neg sampler bug for seque…
Browse files Browse the repository at this point in the history
…nce feature (#264)

* fix neg sampler  bug for sequence feature
* add need_key_feature
  • Loading branch information
lgqfhwy authored Aug 5, 2022
1 parent d0ac898 commit 160ed9c
Show file tree
Hide file tree
Showing 7 changed files with 984 additions and 2 deletions.
3 changes: 3 additions & 0 deletions docs/source/feature/feature.rst
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ Sequense类特征格式一般为“XX\|XX\|XX”,如用户行为序列特征
sequence_features: {
group_name: "seq_fea"
allow_key_search: true
need_key_feature:true
seq_att_map: {
key: "brand"
key: "cate_id"
Expand All @@ -281,6 +282,8 @@ Sequense类特征格式一般为“XX\|XX\|XX”,如用户行为序列特征
- sequence_features: 序列特征组的名称
- allow_key_search: 当 key 对应的特征没有在 feature_groups 里面时,需要设置为 true, 将会复用对应特征的 embedding.
- need_key_feature : 默认为 true, 指过完 target attention 之后的特征会和 key 对应的特征 concat 之后返回。
设置为 false 时,将会只返回过完 target attention 之后的特征。
- seq_att_map: 具体细节可以参考排序里的 DIN 模型。
- NOTE:SequenceFeature一般放在 user 组里面。

Expand Down
11 changes: 9 additions & 2 deletions easy_rec/python/layers/input_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,14 @@ def __init__(self,
def has_group(self, group_name):
return group_name in self._feature_groups

def target_attention(self, dnn_config, deep_fea, name):
def target_attention(self, dnn_config, deep_fea, name, need_key_feature=True):
cur_id, hist_id_col, seq_len = deep_fea['key'], deep_fea[
'hist_seq_emb'], deep_fea['hist_seq_len']

seq_max_len = tf.shape(hist_id_col)[1]
emb_dim = hist_id_col.shape[2]

cur_id = cur_id[:tf.shape(hist_id_col)[0], ...] # for negative sampler
cur_ids = tf.tile(cur_id, [1, seq_max_len])
cur_ids = tf.reshape(cur_ids,
tf.shape(hist_id_col)) # (B, seq_max_len, emb_dim)
Expand All @@ -96,6 +97,8 @@ def target_attention(self, dnn_config, deep_fea, name):
scores = tf.nn.softmax(scores) # (B, 1, seq_max_len)
hist_din_emb = tf.matmul(scores, hist_id_col) # [B, 1, emb_dim]
hist_din_emb = tf.reshape(hist_din_emb, [-1, emb_dim]) # [B, emb_dim]
if not need_key_feature:
return hist_din_emb
din_output = tf.concat([hist_din_emb, cur_id], axis=1)
return din_output

Expand All @@ -108,6 +111,7 @@ def call_seq_input_layer(self,
for seq_att_map_config in all_seq_att_map_config:
group_name = seq_att_map_config.group_name
allow_key_search = seq_att_map_config.allow_key_search
need_key_feature = seq_att_map_config.need_key_feature
seq_features = self._seq_input_layer(features, group_name,
feature_name_to_output_tensors,
allow_key_search)
Expand All @@ -128,7 +132,10 @@ def call_seq_input_layer(self,
seq_dnn_config.hidden_units.extend([128, 64, 32, 1])
cur_target_attention_name = 'seq_dnn' + group_name
seq_fea = self.target_attention(
seq_dnn_config, seq_features, name=cur_target_attention_name)
seq_dnn_config,
seq_features,
name=cur_target_attention_name,
need_key_feature=need_key_feature)
all_seq_fea.append(seq_fea)
# concat all seq_fea
all_seq_fea = tf.concat(all_seq_fea, axis=1)
Expand Down
1 change: 1 addition & 0 deletions easy_rec/python/protos/feature_config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -148,4 +148,5 @@ message SeqAttGroupConfig {
optional bool tf_summary = 3 [default = false];
optional DNN seq_dnn = 4;
optional bool allow_key_search = 5 [default = false];
optional bool need_key_feature = 6 [default = true];
}
20 changes: 20 additions & 0 deletions easy_rec/python/test/train_eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,26 @@ def test_distribute_eval_esmm(self):
cur_eval_path, self._test_dir)
self.assertTrue(self._success)

@unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_dssm_neg_sampler_sequence_feature(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/dssm_neg_sampler_sequence_feature.config',
self._test_dir)
self.assertTrue(self._success)

@unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_dssm_neg_sampler_need_key_feature(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/dssm_neg_sampler_need_key_feature.config',
self._test_dir)
self.assertTrue(self._success)

def test_dbmtl_on_multi_numeric_boundary_need_key_feature(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/dbmtl_on_multi_numeric_boundary_need_key_feature_taobao.config',
self._test_dir)
self.assertTrue(self._success)


if __name__ == '__main__':
tf.test.main()
Loading

0 comments on commit 160ed9c

Please sign in to comment.