Skip to content

Commit

Permalink
Bugfix:fix bug of embedding regularization (#336)
Browse files Browse the repository at this point in the history
* [bug fix]: fix bug of embedding regularization
  • Loading branch information
yangxudong authored Feb 1, 2023
1 parent 3981cd7 commit 6c8591b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
16 changes: 12 additions & 4 deletions easy_rec/python/layers/input_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
from easy_rec.python.protos.feature_config_pb2 import WideOrDeep
from easy_rec.python.utils import shape_utils

from easy_rec.python.compat.feature_column.feature_column import _SharedEmbeddingColumn # NOQA
from easy_rec.python.compat.feature_column.feature_column_v2 import EmbeddingColumn # NOQA
from easy_rec.python.compat.feature_column.feature_column_v2 import SharedEmbeddingColumn # NOQA

from easy_rec.python.compat.feature_column.feature_column import _SharedEmbeddingColumn # NOQA


class InputLayer(object):
Expand Down Expand Up @@ -167,7 +169,12 @@ def single_call_input_layer(self,
group_columns,
cols_to_output_tensors=cols_to_output_tensors,
feature_name_to_output_tensors=feature_name_to_output_tensors)
embedding_reg_lst = [output_features]
# embedding_reg_lst = [output_features]
embedding_reg_lst = []
for col, val in cols_to_output_tensors.items():
if isinstance(col, EmbeddingColumn) or isinstance(col,
SharedEmbeddingColumn):
embedding_reg_lst.append(val)
builder = feature_column._LazyBuilder(features)
seq_features = []
for column in sorted(group_seq_columns, key=lambda x: x.name):
Expand Down Expand Up @@ -227,8 +234,9 @@ def single_call_input_layer(self,
group_features = [cols_to_output_tensors[x] for x in group_columns] + \
[cols_to_output_tensors[x] for x in group_seq_columns]

regularizers.apply_regularization(
self._embedding_regularizer, weights_list=embedding_reg_lst)
if embedding_reg_lst:
regularizers.apply_regularization(
self._embedding_regularizer, weights_list=embedding_reg_lst)
return concat_features, group_features

def get_wide_deep_dict(self):
Expand Down
2 changes: 1 addition & 1 deletion easy_rec/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
__version__ = '0.6.0'
__version__ = '0.6.1'

0 comments on commit 6c8591b

Please sign in to comment.