[bugfix]: fix bug of softmax_loss_with_negative_mining loss (#237)
yangxudong authored Jul 18, 2022
1 parent e604fcc commit e6622a2
Showing 2 changed files with 33 additions and 31 deletions.
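The underlying problem: `get_shape_list(item_emb)[0]` returns a symbolic Tensor whenever the batch dimension is not statically known, so the old Python `assert` ends up trying to evaluate `num_negative_samples < <Tensor>` at graph-construction time. Below is a minimal sketch of that failure mode, assuming TF 1.x graph semantics as used in this file; the placeholder and values are illustrative, not code from the commit.

    # Illustrative failure sketch (hypothetical, not part of this repo):
    import tensorflow as tf

    if tf.__version__ >= '2.0':
      tf = tf.compat.v1

    item_emb = tf.placeholder(tf.float32, shape=[None, 3])  # batch dim unknown at build time
    batch_size = tf.shape(item_emb)[0]  # a scalar int32 Tensor, not a Python int

    num_negative_samples = 4
    try:
      # The chained comparison ends up calling bool() on a symbolic Tensor.
      assert 0 < num_negative_samples < batch_size
    except TypeError as e:
      print(e)  # "Using a `tf.Tensor` as a Python `bool` is not allowed..."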
53 changes: 27 additions & 26 deletions easy_rec/python/loss/softmax_loss_with_negative_mining.py
@@ -2,8 +2,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import tensorflow as tf
 
-from easy_rec.python.utils.shape_utils import get_shape_list
-
 if tf.__version__ >= '2.0':
   tf = tf.compat.v1
 
@@ -65,33 +63,36 @@ def softmax_loss_with_negative_mining(user_emb,
   Return:
     support vector guided softmax loss of positive labels
   """
-  batch_size = get_shape_list(item_emb)[0]
-  assert 0 < num_negative_samples < batch_size, '`num_negative_samples` should be in range [1, batch_size)'
+  assert 0 < num_negative_samples, '`num_negative_samples` should be greater than 0'
 
-  if not embed_normed:
-    user_emb = tf.nn.l2_normalize(user_emb, axis=-1)
-    item_emb = tf.nn.l2_normalize(item_emb, axis=-1)
+  batch_size = tf.shape(item_emb)[0]
+  is_valid = tf.assert_less(num_negative_samples, batch_size,
+                            message='`num_negative_samples` should be less than batch_size')
+  with tf.control_dependencies([is_valid]):
+    if not embed_normed:
+      user_emb = tf.nn.l2_normalize(user_emb, axis=-1)
+      item_emb = tf.nn.l2_normalize(item_emb, axis=-1)
 
-  vectors = [item_emb]
-  for i in range(num_negative_samples):
-    shift = tf.random_uniform([], 1, batch_size, dtype=tf.int32)
-    neg_item_emb = tf.roll(item_emb, shift, axis=0)
-    vectors.append(neg_item_emb)
-  # all_embeddings's shape: (batch_size, num_negative_samples + 1, vec_dim)
-  all_embeddings = tf.stack(vectors, axis=1)
+    vectors = [item_emb]
+    for i in range(num_negative_samples):
+      shift = tf.random_uniform([], 1, batch_size, dtype=tf.int32)
+      neg_item_emb = tf.roll(item_emb, shift, axis=0)
+      vectors.append(neg_item_emb)
+    # all_embeddings's shape: (batch_size, num_negative_samples + 1, vec_dim)
+    all_embeddings = tf.stack(vectors, axis=1)
 
-  mask = tf.greater(labels, 0)
-  mask_user_emb = tf.boolean_mask(user_emb, mask)
-  mask_item_emb = tf.boolean_mask(all_embeddings, mask)
-  if isinstance(weights, tf.Tensor):
-    weights = tf.boolean_mask(weights, mask)
+    mask = tf.greater(labels, 0)
+    mask_user_emb = tf.boolean_mask(user_emb, mask)
+    mask_item_emb = tf.boolean_mask(all_embeddings, mask)
+    if isinstance(weights, tf.Tensor):
+      weights = tf.boolean_mask(weights, mask)
 
-  # sim_scores's shape: (num_of_pos_label_in_batch_size, num_negative_samples + 1)
-  sim_scores = tf.keras.backend.batch_dot(
-      mask_user_emb, mask_item_emb, axes=(1, 2))
-  pos_score = tf.slice(sim_scores, [0, 0], [-1, 1])
-  neg_scores = tf.slice(sim_scores, [0, 1], [-1, -1])
+    # sim_scores's shape: (num_of_pos_label_in_batch_size, num_negative_samples + 1)
+    sim_scores = tf.keras.backend.batch_dot(
+        mask_user_emb, mask_item_emb, axes=(1, 2))
+    pos_score = tf.slice(sim_scores, [0, 0], [-1, 1])
+    neg_scores = tf.slice(sim_scores, [0, 1], [-1, -1])
 
-  loss = support_vector_guided_softmax_loss(
-      pos_score, neg_scores, margin=margin, t=t, smooth=gamma, weights=weights)
+    loss = support_vector_guided_softmax_loss(
+        pos_score, neg_scores, margin=margin, t=t, smooth=gamma, weights=weights)
   return loss
11 changes: 6 additions & 5 deletions easy_rec/python/test/loss_test.py
@@ -26,19 +26,20 @@ def test_f1_reweighted_loss(self):
 
   def test_softmax_loss_with_negative_mining(self):
     print('test_softmax_loss_with_negative_mining')
-    user_emb = tf.constant([[0.1, 0.5, 0.3], [0.8, -0.1, 0.3], [0.28, 0.3, 0.9],
-                            [0.37, 0.45, 0.93], [-0.7, 0.15, 0.03],
-                            [0.18, 0.9, -0.3]])
+    user_emb = tf.constant([[0.1, 0.5, 0.3], [0.8, -0.1, 0.3],
+                            [0.28, 0.3, 0.9], [0.37, 0.45, 0.93],
+                            [-0.7, 0.15, 0.03], [0.18, 0.9, -0.3]])
     item_emb = tf.constant([[0.1, -0.5, 0.3], [0.8, -0.31, 0.3],
                             [0.7, -0.45, 0.15], [0.08, -0.31, -0.9],
                             [-0.7, 0.85, 0.03], [0.18, 0.89, -0.3]])
 
     label = tf.constant([1, 1, 0, 0, 1, 1])
+    tf.random.set_random_seed(1)
     loss = softmax_loss_with_negative_mining(
-        user_emb, item_emb, label, num_negative_samples=1)
+        user_emb, item_emb, label, num_negative_samples=2)
     with self.test_session() as sess:
      loss_val = sess.run(loss)
-      self.assertAlmostEqual(loss_val, 0.5240243, delta=1e-5)
+      self.assertAlmostEqual(loss_val, 0.76977473, delta=1e-5)
 
   def test_circle_loss(self):
     print('test_circle_loss')
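The test now pins the graph-level random seed: negatives are drawn by rolling the in-batch items through a random shift, so the expected loss value 0.76977473 is only reproducible once the seed is fixed. A small sketch of that sampling step, with hypothetical values and the same TF 1.x assumptions as above.

    # Hypothetical sketch of the negative-sampling step the seed makes deterministic:
    import tensorflow as tf

    if tf.__version__ >= '2.0':
      tf = tf.compat.v1

    tf.random.set_random_seed(1)  # same graph-level seed the updated test pins

    item_emb = tf.constant([[0.1, -0.5, 0.3], [0.8, -0.31, 0.3],
                            [0.7, -0.45, 0.15], [0.08, -0.31, -0.9]])
    batch_size = tf.shape(item_emb)[0]
    shift = tf.random_uniform([], 1, batch_size, dtype=tf.int32)  # shift in [1, batch_size)
    neg_item_emb = tf.roll(item_emb, shift, axis=0)  # pairs each row with another row's item

    with tf.Session() as sess:
      # Reproducible across runs of this script because the graph seed is fixed.
      print(sess.run([shift, neg_item_emb]))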
