diff --git a/easy_rec/python/layers/input_layer.py b/easy_rec/python/layers/input_layer.py index caeda193d..731f47c82 100644 --- a/easy_rec/python/layers/input_layer.py +++ b/easy_rec/python/layers/input_layer.py @@ -16,8 +16,10 @@ from easy_rec.python.protos.feature_config_pb2 import WideOrDeep from easy_rec.python.utils import shape_utils -from easy_rec.python.compat.feature_column.feature_column import _SharedEmbeddingColumn # NOQA from easy_rec.python.compat.feature_column.feature_column_v2 import EmbeddingColumn # NOQA +from easy_rec.python.compat.feature_column.feature_column_v2 import SharedEmbeddingColumn # NOQA + +from easy_rec.python.compat.feature_column.feature_column import _SharedEmbeddingColumn # NOQA class InputLayer(object): @@ -167,7 +169,12 @@ def single_call_input_layer(self, group_columns, cols_to_output_tensors=cols_to_output_tensors, feature_name_to_output_tensors=feature_name_to_output_tensors) - embedding_reg_lst = [output_features] + # embedding_reg_lst = [output_features] + embedding_reg_lst = [] + for col, val in cols_to_output_tensors.items(): + if isinstance(col, EmbeddingColumn) or isinstance(col, + SharedEmbeddingColumn): + embedding_reg_lst.append(val) builder = feature_column._LazyBuilder(features) seq_features = [] for column in sorted(group_seq_columns, key=lambda x: x.name): @@ -227,8 +234,9 @@ def single_call_input_layer(self, group_features = [cols_to_output_tensors[x] for x in group_columns] + \ [cols_to_output_tensors[x] for x in group_seq_columns] - regularizers.apply_regularization( - self._embedding_regularizer, weights_list=embedding_reg_lst) + if embedding_reg_lst: + regularizers.apply_regularization( + self._embedding_regularizer, weights_list=embedding_reg_lst) return concat_features, group_features def get_wide_deep_dict(self): diff --git a/easy_rec/version.py b/easy_rec/version.py index 8127003c0..6e00ca21f 100644 --- a/easy_rec/version.py +++ b/easy_rec/version.py @@ -1,3 +1,3 @@ # -*- encoding:utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. -__version__ = '0.6.0' +__version__ = '0.6.1'