diff --git a/easy_rec/python/layers/fm.py b/easy_rec/python/layers/fm.py index c638456a4..1929e00aa 100644 --- a/easy_rec/python/layers/fm.py +++ b/easy_rec/python/layers/fm.py @@ -19,8 +19,7 @@ def __init__(self, name='fm'): def __call__(self, fm_fea): with tf.name_scope(self._name): - fm_feas = tf.concat(fm_fea, axis=1) - fm_feas = tf.expand_dims(fm_feas, axis=1) + fm_feas = tf.stack(fm_fea, axis=1) sum_square = tf.square(tf.reduce_sum(fm_feas, 1)) square_sum = tf.reduce_sum(tf.square(fm_feas), 1) y_v = 0.5 * tf.subtract(sum_square, square_sum)