diff --git a/easy_rec/python/loss/softmax_loss_with_negative_mining.py b/easy_rec/python/loss/softmax_loss_with_negative_mining.py
index 417aad527..b3b7210bb 100644
--- a/easy_rec/python/loss/softmax_loss_with_negative_mining.py
+++ b/easy_rec/python/loss/softmax_loss_with_negative_mining.py
@@ -2,8 +2,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import tensorflow as tf
 
-from easy_rec.python.utils.shape_utils import get_shape_list
-
 if tf.__version__ >= '2.0':
   tf = tf.compat.v1
 
@@ -65,33 +63,36 @@ def softmax_loss_with_negative_mining(user_emb,
   Return:
     support vector guided softmax loss of positive labels
   """
-  batch_size = get_shape_list(item_emb)[0]
-  assert 0 < num_negative_samples < batch_size, '`num_negative_samples` should be in range [1, batch_size)'
+  assert 0 < num_negative_samples, '`num_negative_samples` should be greater than 0'
 
-  if not embed_normed:
-    user_emb = tf.nn.l2_normalize(user_emb, axis=-1)
-    item_emb = tf.nn.l2_normalize(item_emb, axis=-1)
+  batch_size = tf.shape(item_emb)[0]
+  is_valid = tf.assert_less(num_negative_samples, batch_size,
+                            message='`num_negative_samples` should be less than batch_size')
+  with tf.control_dependencies([is_valid]):
+    if not embed_normed:
+      user_emb = tf.nn.l2_normalize(user_emb, axis=-1)
+      item_emb = tf.nn.l2_normalize(item_emb, axis=-1)
 
-  vectors = [item_emb]
-  for i in range(num_negative_samples):
-    shift = tf.random_uniform([], 1, batch_size, dtype=tf.int32)
-    neg_item_emb = tf.roll(item_emb, shift, axis=0)
-    vectors.append(neg_item_emb)
-  # all_embeddings's shape: (batch_size, num_negative_samples + 1, vec_dim)
-  all_embeddings = tf.stack(vectors, axis=1)
+    vectors = [item_emb]
+    for i in range(num_negative_samples):
+      shift = tf.random_uniform([], 1, batch_size, dtype=tf.int32)
+      neg_item_emb = tf.roll(item_emb, shift, axis=0)
+      vectors.append(neg_item_emb)
+    # all_embeddings's shape: (batch_size, num_negative_samples + 1, vec_dim)
+    all_embeddings = tf.stack(vectors, axis=1)
 
-  mask = tf.greater(labels, 0)
-  mask_user_emb = tf.boolean_mask(user_emb, mask)
-  mask_item_emb = tf.boolean_mask(all_embeddings, mask)
-  if isinstance(weights, tf.Tensor):
-    weights = tf.boolean_mask(weights, mask)
+    mask = tf.greater(labels, 0)
+    mask_user_emb = tf.boolean_mask(user_emb, mask)
+    mask_item_emb = tf.boolean_mask(all_embeddings, mask)
+    if isinstance(weights, tf.Tensor):
+      weights = tf.boolean_mask(weights, mask)
 
-  # sim_scores's shape: (num_of_pos_label_in_batch_size, num_negative_samples + 1)
-  sim_scores = tf.keras.backend.batch_dot(
-      mask_user_emb, mask_item_emb, axes=(1, 2))
-  pos_score = tf.slice(sim_scores, [0, 0], [-1, 1])
-  neg_scores = tf.slice(sim_scores, [0, 1], [-1, -1])
+    # sim_scores's shape: (num_of_pos_label_in_batch_size, num_negative_samples + 1)
+    sim_scores = tf.keras.backend.batch_dot(
+        mask_user_emb, mask_item_emb, axes=(1, 2))
+    pos_score = tf.slice(sim_scores, [0, 0], [-1, 1])
+    neg_scores = tf.slice(sim_scores, [0, 1], [-1, -1])
 
-  loss = support_vector_guided_softmax_loss(
-      pos_score, neg_scores, margin=margin, t=t, smooth=gamma, weights=weights)
+    loss = support_vector_guided_softmax_loss(
+        pos_score, neg_scores, margin=margin, t=t, smooth=gamma, weights=weights)
 
   return loss
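Note on the loss change: the old guard compared `num_negative_samples` against a batch size from `get_shape_list`, which only yields a Python int when the batch dimension is statically known at graph-construction time. With `tf.shape(item_emb)[0]` the batch size is a scalar tensor, so the bound check has to move into the graph: `tf.assert_less` builds an assert op, and `tf.control_dependencies` forces it to execute before the ops created inside the block. Below is a minimal sketch of that pattern in isolation (the placeholder and values are illustrative, not part of the patch):

    import tensorflow as tf

    if tf.__version__ >= '2.0':
      tf = tf.compat.v1
      tf.disable_eager_execution()  # the sketch builds a v1-style graph

    # Illustrative tensor with a dynamic batch dimension.
    item_emb = tf.placeholder(tf.float32, shape=[None, 3])
    num_negative_samples = 2

    batch_size = tf.shape(item_emb)[0]  # scalar int32 tensor, unknown until run time
    check = tf.assert_less(
        num_negative_samples, batch_size,
        message='`num_negative_samples` should be less than batch_size')

    # Ops created under the dependency run only after the check executes; a
    # violation surfaces as tf.errors.InvalidArgumentError inside sess.run,
    # not as a Python AssertionError at graph-build time.
    with tf.control_dependencies([check]):
      item_emb = tf.nn.l2_normalize(item_emb, axis=-1)

One consequence of the trade-off: the `0 < num_negative_samples` half of the old assert still fails fast in Python, while the upper-bound check is deferred to session run time.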
diff --git a/easy_rec/python/test/loss_test.py b/easy_rec/python/test/loss_test.py
index ddd2dc842..27248fe31 100644
--- a/easy_rec/python/test/loss_test.py
+++ b/easy_rec/python/test/loss_test.py
@@ -26,19 +26,20 @@ def test_f1_reweighted_loss(self):
 
   def test_softmax_loss_with_negative_mining(self):
     print('test_softmax_loss_with_negative_mining')
-    user_emb = tf.constant([[0.1, 0.5, 0.3], [0.8, -0.1, 0.3], [0.28, 0.3, 0.9],
-                            [0.37, 0.45, 0.93], [-0.7, 0.15, 0.03],
-                            [0.18, 0.9, -0.3]])
+    user_emb = tf.constant([[0.1, 0.5, 0.3], [0.8, -0.1, 0.3],
+                            [0.28, 0.3, 0.9], [0.37, 0.45, 0.93],
+                            [-0.7, 0.15, 0.03], [0.18, 0.9, -0.3]])
     item_emb = tf.constant([[0.1, -0.5, 0.3], [0.8, -0.31, 0.3],
                             [0.7, -0.45, 0.15], [0.08, -0.31, -0.9],
                             [-0.7, 0.85, 0.03], [0.18, 0.89, -0.3]])
     label = tf.constant([1, 1, 0, 0, 1, 1])
+    tf.random.set_random_seed(1)
     loss = softmax_loss_with_negative_mining(
-        user_emb, item_emb, label, num_negative_samples=1)
+        user_emb, item_emb, label, num_negative_samples=2)
     with self.test_session() as sess:
       loss_val = sess.run(loss)
-      self.assertAlmostEqual(loss_val, 0.5240243, delta=1e-5)
+      self.assertAlmostEqual(loss_val, 0.76977473, delta=1e-5)
 
   def test_circle_loss(self):
     print('test_circle_loss')
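Note on the test change: the negatives inside the loss come from `tf.random_uniform` shifts, so without a fixed seed the expected value would vary between runs; `tf.random.set_random_seed(1)` pins the shifts down, which is what makes the hard-coded 0.76977473 reproducible, and `num_negative_samples=2` now exercises more than one negative per positive. For intuition, here is the sampling step pulled out of the loss on its own, with a made-up four-row batch (values are illustrative):

    import tensorflow as tf

    if tf.__version__ >= '2.0':
      tf = tf.compat.v1
      tf.disable_eager_execution()

    tf.random.set_random_seed(1)  # graph-level seed: sampled shifts repeat across runs

    item_emb = tf.constant([[1., 0.], [0., 1.], [1., 1.], [0., 0.]])  # made-up batch
    batch_size = tf.shape(item_emb)[0]

    # Each negative set is the batch rotated by a random non-zero shift, so
    # every row is paired with some other row's item as its negative.
    shift = tf.random_uniform([], 1, batch_size, dtype=tf.int32)
    neg_item_emb = tf.roll(item_emb, shift, axis=0)

    with tf.Session() as sess:
      print(sess.run([shift, neg_item_emb]))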