From d1bbbec38ad43c167b23673c617fa3b8456185c5 Mon Sep 17 00:00:00 2001 From: Julian Qian Date: Tue, 9 Apr 2024 20:18:51 -0700 Subject: [PATCH] enable ShadowVariable look up for safe_embedding_lookup_sparse and support raggedtensor --- .gitignore | 8 +- .../python/keras/layers/embedding.py | 21 +- .../dynamic_embedding_ops_test.py | 282 +++++++--- .../kernel_tests/ragged_embedding_ops_test.py | 27 + .../kernel_tests/shadow_embedding_ops_test.py | 47 ++ .../python/ops/dynamic_embedding_ops.py | 68 ++- .../python/ops/ragged_embedding_ops.py | 509 ++++++++++++++++++ .../python/ops/shadow_embedding_ops.py | 59 +- 8 files changed, 881 insertions(+), 140 deletions(-) create mode 100644 tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/ragged_embedding_ops_test.py create mode 100644 tensorflow_recommenders_addons/dynamic_embedding/python/ops/ragged_embedding_ops.py diff --git a/.gitignore b/.gitignore index 90ed219bd..4cf744f93 100644 --- a/.gitignore +++ b/.gitignore @@ -12,10 +12,10 @@ artifacts/# File patterns to ignore; see `git help ignore` for more information. # Lines that start with '#' are comments. *.whl -/bazel-bin/ -/bazel-out/ -/bazel-recommenders-addons/ -/bazel-testlogs/ +bazel-bin +bazel-out +bazel-recommenders-addons +bazel-testlogs /tensorflow_recommenders_addons/dynamic_embedding/core/*.so bazel-genfiles diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py index 7131be263..cfd7e5ac6 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py @@ -18,16 +18,12 @@ See [Sparse Domain Isolation](https://github.com/tensorflow/community/pull/237) """ -import pickle - import tensorflow as tf from tensorflow.python.eager import context -from tensorflow.python.ops import init_ops from tensorflow_recommenders_addons import dynamic_embedding as de from tensorflow_recommenders_addons.dynamic_embedding.python.ops import dynamic_embedding_variable as devar -from tensorflow.python.distribute import distribute_lib from tensorflow.python.keras.utils import tf_utils try: # tf version >= 2.14.0 from tensorflow.python.distribute import distribute_lib as distribute_ctx @@ -268,20 +264,9 @@ def call(self, ids): Returns: A embedding output with shape (shape(ids), embedding_size). 
""" - ids = tf.convert_to_tensor(ids) - input_shape = tf.shape(ids) - embeddings_shape = tf.concat([input_shape, [self.embedding_size]], 0) - ids_flat = tf.reshape(ids, (-1,)) - if self.with_unique: - with tf.name_scope(self.name + "/EmbeddingWithUnique"): - unique_ids, idx = tf.unique(ids_flat) - unique_embeddings = de.shadow_ops.embedding_lookup( - self.shadow, unique_ids) - embeddings_flat = tf.gather(unique_embeddings, idx) - else: - embeddings_flat = de.shadow_ops.embedding_lookup(self.shadow, ids_flat) - embeddings = tf.reshape(embeddings_flat, embeddings_shape) - return embeddings + return de.shadow_ops.embedding_lookup_unique(self.shadow, ids, + self.embedding_size, + self.with_unique, self.name) def get_config(self): _initializer = self.params.initializer diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_ops_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_ops_test.py index 48e58edc3..a777551b3 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_ops_test.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_ops_test.py @@ -23,9 +23,14 @@ import math import numpy as np import os + import tensorflow as tf +from absl.testing import parameterized +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow_recommenders_addons import dynamic_embedding as de +from tensorflow_recommenders_addons.dynamic_embedding.python.ops.ragged_embedding_ops import embedding_lookup_sparse, \ + safe_embedding_lookup_sparse try: from tensorflow.python.keras.initializers import initializers_v2 as kinit2 @@ -125,7 +130,7 @@ def embedding_result(params, id_vals, weight_vals=None): return values, weights, weights_squared -def ids_and_weights_2d(embed_dim=4): +def _ids_and_weights_2d(embed_dim=4, ragged=False): # Each row demonstrates a test case: # Row 0: multiple valid ids, 1 invalid id, weighted mean # Row 1: all ids are invalid (leaving no valid ids after pruning) @@ -148,11 +153,14 @@ def ids_and_weights_2d(embed_dim=4): constant_op.constant(weights, dtypes.float32), constant_op.constant(shape, dtypes.int64), ) - + if ragged: + sparse_ids = ragged_tensor.RaggedTensor.from_sparse(sparse_ids) + sparse_weights = ragged_tensor.RaggedTensor.from_sparse(sparse_weights) return sparse_ids, sparse_weights -def ids_and_weights_3d(embed_dim=4): +def _ids_and_weights_3d( + embed_dim=4) -> (sparse_tensor.SparseTensor, sparse_tensor.SparseTensor): # Each (2-D) index demonstrates a test case: # Index 0, 0: multiple valid ids, 1 invalid id, weighted mean # Index 0, 1: all ids are invalid (leaving no valid ids after pruning) @@ -743,10 +751,15 @@ def test_embedding_lookup_unique(self): np.testing.assert_almost_equal(embedded_np, embedded_de) -@test_util.deprecated_graph_mode_only -class EmbeddingLookupSparseTest(test.TestCase): - - def _random_ids_and_weights(self, batch_size, vocab_size, k_type, d_type): +@test_util.run_all_in_graph_and_eager_modes +class EmbeddingLookupSparseTest(test.TestCase, parameterized.TestCase): + + def _random_ids_and_weights(self, + batch_size, + vocab_size, + k_type, + d_type, + ragged=False): max_val_per_entry = 6 vals_per_batch_entry = np.random.randint(1, max_val_per_entry, @@ -775,7 +788,9 @@ def _random_ids_and_weights(self, batch_size, vocab_size, k_type, d_type): constant_op.constant(weights, d_type), constant_op.constant(shape, dtypes.int64), ) - + if ragged: + sp_ids = 
ragged_tensor.RaggedTensor.from_sparse(sp_ids) + sp_weights = ragged_tensor.RaggedTensor.from_sparse(sp_weights) return sp_ids, sp_weights, ids, weights, vals_per_batch_entry def _group_by_batch_entry(self, vals, vals_per_batch_entry): @@ -786,8 +801,8 @@ def _group_by_batch_entry(self, vals, vals_per_batch_entry): index += num_val return grouped_vals - def test_embedding_lookup_sparse(self): - + @parameterized.parameters(itertools.product([True, False])) + def test_embedding_lookup_sparse(self, ragged): var_id = 0 for ( num_shards, @@ -821,7 +836,8 @@ def test_embedding_lookup_sparse(self): ids, weights, vals_per_batch_entry, - ) = self._random_ids_and_weights(batch_size, vocab_size, k_dtype, d_dtype) + ) = self._random_ids_and_weights(batch_size, vocab_size, k_dtype, d_dtype, + ragged) grouped_ids = self._group_by_batch_entry(ids, vals_per_batch_entry) grouped_weights = self._group_by_batch_entry(weights, @@ -843,18 +859,30 @@ def test_embedding_lookup_sparse(self): random_init = params.lookup(ids) init_op = params.upsert(ids, random_init) self.evaluate(init_op) - np_params = random_init.eval() + np_params = random_init.numpy() if context.executing_eagerly( + ) else random_init.eval() grouped_params = self._group_by_batch_entry(np_params, vals_per_batch_entry) - embedding_sum = de.embedding_lookup_sparse( - params, - sp_ids, - None if ignore_weights else sp_weights, - combiner=combiner, - ) + if context.executing_eagerly(): + params = de.shadow_ops.ShadowVariable(params) + if ragged: + embedding_sum = embedding_lookup_sparse( + params, + sp_ids, + None if ignore_weights else sp_weights, + combiner=combiner, + ) + else: + embedding_sum = de.embedding_lookup_sparse( + params, + sp_ids, + None if ignore_weights else sp_weights, + combiner=combiner, + ) self.assertEqual(embedding_sum.dtype, d_dtype) - tf_embedding_sum = embedding_sum.eval() + tf_embedding_sum = embedding_sum.numpy() if context.executing_eagerly( + ) else embedding_sum.eval() np_embedding_sum, np_weight_sum, np_weight_sq_sum = embedding_result( grouped_params, @@ -873,6 +901,8 @@ def test_embedding_lookup_sparse(self): self.assertAllClose(np_embedding_sum, tf_embedding_sum, rtol, atol) def test_embedding_lookup_sparse_shape_checking(self): + if context.executing_eagerly(): + self.skipTest("Skip eager test") with self.session(use_gpu=test_util.is_gpu_available(), config=default_config): embed_dim = 4 @@ -880,7 +910,7 @@ def test_embedding_lookup_sparse_shape_checking(self): shape=[100, embed_dim], use_resource=False) embedding_weights_de = _random_weights(embed_dim=4) - sparse_ids, _ = ids_and_weights_3d(embed_dim=embed_dim) + sparse_ids, _ = _ids_and_weights_3d(embed_dim=embed_dim) embedding_lookup_base = embedding_ops.embedding_lookup_sparse( embedding_weights_nn, sparse_ids, None) @@ -890,15 +920,30 @@ def test_embedding_lookup_sparse_shape_checking(self): embedding_lookup_test.get_shape().as_list()) -@test_util.deprecated_graph_mode_only -class SafeEmbeddingLookupSparseTest(test.TestCase): - - def test_safe_embedding_lookup_sparse_return_zero_vector(self): - with self.cached_session(use_gpu=test_util.is_gpu_available(), - config=default_config): +@test_util.run_all_in_graph_and_eager_modes +class SafeEmbeddingLookupSparseTest(test.TestCase, parameterized.TestCase): + + def _get_ids_and_weights_3d(self, valid_ids): + embedding_weights = _random_weights() + sparse_ids, sparse_weights = _ids_and_weights_3d() + + # init + embedding_weights_values = embedding_weights.lookup(valid_ids) + embedding_weights_values = 
embedding_weights_values.numpy( + ) if context.executing_eagerly() else embedding_weights_values.eval() + self.evaluate(embedding_weights.upsert(valid_ids, embedding_weights_values)) + if context.executing_eagerly(): + embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights) + return embedding_weights, embedding_weights_values, sparse_ids, sparse_weights + + @parameterized.parameters(itertools.product([True, False])) + def test_safe_embedding_lookup_sparse_return_zero_vector(self, ragged=False): + with self.session(use_gpu=test_util.is_gpu_available(), + config=default_config): dim = 4 embedding_weights = _random_weights(embed_dim=dim) - sparse_ids, sparse_weights = ids_and_weights_2d(embed_dim=dim) + sparse_ids, sparse_weights = _ids_and_weights_2d(embed_dim=dim, + ragged=ragged) valid_ids = np.array([ 0, 1, @@ -907,13 +952,24 @@ def test_safe_embedding_lookup_sparse_return_zero_vector(self): ]) # init - embedding_weights_values = embedding_weights.lookup(valid_ids).eval() + weights = embedding_weights.lookup(valid_ids) + embedding_weights_values = weights.numpy() if context.executing_eagerly( + ) else weights.eval() self.evaluate( embedding_weights.upsert(valid_ids, embedding_weights_values)) # check - embedding_lookup_result = de.safe_embedding_lookup_sparse( - embedding_weights, sparse_ids, sparse_weights).eval() + if context.executing_eagerly(): + embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights) + if ragged: + embedding_lookup_result = safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, sparse_weights) + else: + embedding_lookup_result = de.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, sparse_weights) + embedding_lookup_result = embedding_lookup_result.numpy( + ) if context.executing_eagerly() else embedding_lookup_result.eval() + self.assertAllClose( embedding_lookup_result, [ @@ -927,23 +983,35 @@ def test_safe_embedding_lookup_sparse_return_zero_vector(self): ], ) - def test_safe_embedding_lookup_sparse_return_special_vector(self): + @parameterized.parameters(itertools.product([True, False])) + def test_safe_embedding_lookup_sparse_return_special_vector( + self, ragged=False): with self.session(use_gpu=test_util.is_gpu_available(), config=default_config): dim = 4 embedding_weights = _random_weights(embed_dim=dim) - sparse_ids, sparse_weights = ids_and_weights_2d(embed_dim=dim) + sparse_ids, sparse_weights = _ids_and_weights_2d(embed_dim=dim, + ragged=ragged) valid_ids = np.array([0, 1, 2, 3, -1]) # init - embedding_weights_values = embedding_weights.lookup(valid_ids).eval() + weights = embedding_weights.lookup(valid_ids) + embedding_weights_values = weights.numpy() if context.executing_eagerly( + ) else weights.eval() self.evaluate( embedding_weights.upsert(valid_ids, embedding_weights_values)) # check - embedding_lookup_result = de.safe_embedding_lookup_sparse( - embedding_weights, sparse_ids, sparse_weights, default_id=3).eval() - + if context.executing_eagerly(): + embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights) + if ragged: + embedding_lookup_result = safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, sparse_weights, default_id=3) + else: + embedding_lookup_result = de.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, sparse_weights, default_id=3) + embedding_lookup_result = embedding_lookup_result.numpy( + ) if context.executing_eagerly() else embedding_lookup_result.eval() self.assertAllClose( embedding_lookup_result, [ @@ -957,21 +1025,33 @@ def 
test_safe_embedding_lookup_sparse_return_special_vector(self): ], ) - def test_safe_embedding_lookup_sparse_no_weights(self): + @parameterized.parameters(itertools.product([True, False])) + def test_safe_embedding_lookup_sparse_no_weights(self, ragged=False): with self.session(use_gpu=test_util.is_gpu_available(), config=default_config): dim = 4 embedding_weights = _random_weights(embed_dim=dim) - sparse_ids, sparse_weights = ids_and_weights_2d(embed_dim=dim) + sparse_ids, sparse_weights = _ids_and_weights_2d(embed_dim=dim, + ragged=ragged) valid_ids = np.array([0, 1, 2, -1]) # init - embedding_weights_values = embedding_weights.lookup(valid_ids).eval() + weights = embedding_weights.lookup(valid_ids) + embedding_weights_values = weights.numpy() if context.executing_eagerly( + ) else weights.eval() self.evaluate( embedding_weights.upsert(valid_ids, embedding_weights_values)) - embedding_lookup_result = de.safe_embedding_lookup_sparse( - embedding_weights, sparse_ids, None).eval() + if context.executing_eagerly(): + embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights) + if ragged: + embedding_lookup_result = safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, None) + else: + embedding_lookup_result = de.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, None) + embedding_lookup_result = embedding_lookup_result.numpy( + ) if context.executing_eagerly() else embedding_lookup_result.eval() self.assertAllClose( embedding_lookup_result, @@ -985,21 +1065,33 @@ def test_safe_embedding_lookup_sparse_no_weights(self): ], ) - def test_safe_embedding_lookup_sparse_partitioned(self): + @parameterized.parameters(itertools.product([True, False])) + def test_safe_embedding_lookup_sparse_partitioned(self, ragged=False): with self.session(use_gpu=test_util.is_gpu_available(), config=default_config): dim = 4 embedding_weights = _random_weights(embed_dim=dim, num_shards=3) - sparse_ids, sparse_weights = ids_and_weights_2d(embed_dim=dim) + sparse_ids, sparse_weights = _ids_and_weights_2d(embed_dim=dim, + ragged=ragged) valid_ids = np.array([0, 1, 2, -1]) # init - embedding_weights_values = embedding_weights.lookup(valid_ids).eval() + weights = embedding_weights.lookup(valid_ids) + embedding_weights_values = weights.numpy() if context.executing_eagerly( + ) else weights.eval() self.evaluate( embedding_weights.upsert(valid_ids, embedding_weights_values)) - embedding_lookup_result = de.safe_embedding_lookup_sparse( - embedding_weights, sparse_ids, None).eval() + if context.executing_eagerly(): + embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights) + if ragged: + embedding_lookup_result = safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, None) + else: + embedding_lookup_result = de.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, None) + embedding_lookup_result = embedding_lookup_result.numpy( + ) if context.executing_eagerly() else embedding_lookup_result.eval() self.assertAllClose( embedding_lookup_result, @@ -1013,44 +1105,58 @@ def test_safe_embedding_lookup_sparse_partitioned(self): ], ) - def test_safe_embedding_lookup_sparse_inconsistent_ids_type(self): + @parameterized.parameters(itertools.product([True, False])) + def test_safe_embedding_lookup_sparse_inconsistent_ids_type( + self, ragged=False): with self.session(use_gpu=test_util.is_gpu_available(), config=default_config): def fn(): embedding_weights = _random_weights(num_shards=3, key_dtype=dtypes.int32) - sparse_ids, sparse_weights = ids_and_weights_2d() - 
de.safe_embedding_lookup_sparse(embedding_weights, sparse_ids, - sparse_weights) + sparse_ids, sparse_weights = _ids_and_weights_2d(ragged=ragged) + if context.executing_eagerly(): + embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights) + if ragged: + safe_embedding_lookup_sparse(embedding_weights, sparse_ids, + sparse_weights) + else: + de.safe_embedding_lookup_sparse(embedding_weights, sparse_ids, + sparse_weights) self.assertRaises(TypeError, fn) - def test_safe_embedding_lookup_sparse_inconsistent_weights_type(self): + @parameterized.parameters(itertools.product([True, False])) + def test_safe_embedding_lookup_sparse_inconsistent_weights_type( + self, ragged=False): with self.session(use_gpu=test_util.is_gpu_available(), config=default_config): def fn(): embedding_weights = _random_weights(num_shards=3, key_dtype=dtypes.half) - sparse_ids, sparse_weights = ids_and_weights_2d() - de.safe_embedding_lookup_sparse(embedding_weights, sparse_ids, - sparse_weights) + sparse_ids, sparse_weights = _ids_and_weights_2d(ragged=ragged) + if context.executing_eagerly(): + embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights) + if ragged: + safe_embedding_lookup_sparse(embedding_weights, sparse_ids, + sparse_weights) + else: + de.safe_embedding_lookup_sparse(embedding_weights, sparse_ids, + sparse_weights) self.assertRaises(TypeError, fn) def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self): with self.session(use_gpu=test_util.is_gpu_available(), config=default_config): - embedding_weights = _random_weights() - sparse_ids, sparse_weights = ids_and_weights_3d() valid_ids = np.array([0, 1, 2, -1]) - # init - embedding_weights_values = embedding_weights.lookup(valid_ids).eval() - self.evaluate( - embedding_weights.upsert(valid_ids, embedding_weights_values)) + embedding_weights, embedding_weights_values, sparse_ids, sparse_weights = self._get_ids_and_weights_3d( + valid_ids) embedding_lookup_result = de.safe_embedding_lookup_sparse( - embedding_weights, sparse_ids, sparse_weights).eval() + embedding_weights, sparse_ids, sparse_weights) + embedding_lookup_result = embedding_lookup_result.numpy( + ) if context.executing_eagerly() else embedding_lookup_result.eval() self.assertAllClose( embedding_lookup_result, @@ -1069,18 +1175,12 @@ def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self): def test_safe_embedding_lookup_sparse_3d_return_special_vector(self): with self.session(use_gpu=test_util.is_gpu_available(), config=default_config): - embedding_weights = _random_weights() - sparse_ids, sparse_weights = ids_and_weights_3d() - valid_ids = np.array([0, 1, 2, 3, -1]) - - # init - embedding_weights_values = embedding_weights.lookup(valid_ids).eval() - self.evaluate( - embedding_weights.upsert(valid_ids, embedding_weights_values)) - + embedding_weights, embedding_weights_values, sparse_ids, sparse_weights = self._get_ids_and_weights_3d( + np.array([0, 1, 2, 3, -1])) embedding_lookup_result = de.safe_embedding_lookup_sparse( - embedding_weights, sparse_ids, sparse_weights, default_id=3).eval() - + embedding_weights, sparse_ids, sparse_weights, default_id=3) + embedding_lookup_result = embedding_lookup_result.numpy( + ) if context.executing_eagerly() else embedding_lookup_result.eval() self.assertAllClose( embedding_lookup_result, [ @@ -1102,17 +1202,13 @@ def test_safe_embedding_lookup_sparse_3d_return_special_vector(self): def test_safe_embedding_lookup_sparse_3d_no_weights(self): with self.session(use_gpu=test_util.is_gpu_available(), 
config=default_config): - embedding_weights = _random_weights() - sparse_ids, _ = ids_and_weights_3d() valid_ids = np.array([0, 1, 2, -1]) - # init - embedding_weights_values = embedding_weights.lookup(valid_ids).eval() - self.evaluate( - embedding_weights.upsert(valid_ids, embedding_weights_values)) - + embedding_weights, embedding_weights_values, sparse_ids, _ = self._get_ids_and_weights_3d( + valid_ids) embedding_lookup_result = de.safe_embedding_lookup_sparse( - embedding_weights, sparse_ids, None).eval() - + embedding_weights, sparse_ids, None) + embedding_lookup_result = embedding_lookup_result.numpy( + ) if context.executing_eagerly() else embedding_lookup_result.eval() self.assertAllClose( embedding_lookup_result, [ @@ -1135,16 +1231,21 @@ def test_safe_embedding_lookup_sparse_3d_partitioned(self): with self.session(use_gpu=test_util.is_gpu_available(), config=default_config): embedding_weights = _random_weights(num_shards=3) - sparse_ids, _ = ids_and_weights_3d() + sparse_ids, _ = _ids_and_weights_3d() valid_ids = np.array([0, 1, 2, -1]) # init - embedding_weights_values = embedding_weights.lookup(valid_ids).eval() + embedding_weights_values = embedding_weights.lookup(valid_ids) + embedding_weights_values = embedding_weights_values.numpy( + ) if context.executing_eagerly() else embedding_weights_values.eval() self.evaluate( embedding_weights.upsert(valid_ids, embedding_weights_values)) - + if context.executing_eagerly(): + embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights) embedding_lookup_result = de.safe_embedding_lookup_sparse( - embedding_weights, sparse_ids, None).eval() + embedding_weights, sparse_ids, None) + embedding_lookup_result = embedding_lookup_result.numpy( + ) if context.executing_eagerly() else embedding_lookup_result.eval() self.assertAllClose( embedding_lookup_result, @@ -1214,10 +1315,15 @@ def test_safe_embedding_lookup_sparse_with_initializer(self): constant_op.constant(ids, dtypes.int64), constant_op.constant(dense_shape, dtypes.int64), ) + if context.executing_eagerly(): + embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights) + vals_op = de.safe_embedding_lookup_sparse(embedding_weights, sparse_ids, None, - combiner="mean").eval() + combiner="mean") + vals_op = vals_op.numpy() if context.executing_eagerly( + ) else vals_op.eval() mean = self.evaluate(math_ops.reduce_mean(vals_op)) stddev = self.evaluate(math_ops.reduce_std(vals_op)) @@ -1228,6 +1334,8 @@ def test_safe_embedding_lookup_sparse_with_initializer(self): self.assertAllClose(target_stddev, stddev, rtol, atol) def test_safe_embedding_lookup_sparse_shape_checking(self): + if context.executing_eagerly(): + self.skipTest("Skip eager test") with self.session(use_gpu=test_util.is_gpu_available(), config=default_config): embed_dim = 4 @@ -1235,7 +1343,7 @@ def test_safe_embedding_lookup_sparse_shape_checking(self): shape=[100, embed_dim], use_resource=False) embedding_weights_de = _random_weights(embed_dim=4) - sparse_ids, _ = ids_and_weights_3d(embed_dim=embed_dim) + sparse_ids, _ = _ids_and_weights_3d(embed_dim=embed_dim) embedding_lookup_base = embedding_ops.safe_embedding_lookup_sparse( embedding_weights_nn, sparse_ids, None) diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/ragged_embedding_ops_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/ragged_embedding_ops_test.py new file mode 100644 index 000000000..23e8a155e --- /dev/null +++ 
b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/ragged_embedding_ops_test.py @@ -0,0 +1,27 @@ +import tensorflow as tf +import unittest + +from tensorflow_recommenders_addons.dynamic_embedding.python.ops.ragged_embedding_ops import _fill_empty_rows + + +class TestFillEmptyRows(unittest.TestCase): + + def test_fill_empty_rows(self): + test_ragged_tensor = tf.ragged.constant([[1, 2, 3], [], [4], [], [5, 6]], + dtype=tf.int32) + default_id = 0 + + filled_ragged_tensor, is_row_empty = _fill_empty_rows( + test_ragged_tensor, default_id) + + expected_filled = tf.ragged.constant([[1, 2, 3], [0], [4], [0], [5, 6]], + dtype=tf.int32) + expected_empty = tf.constant([False, True, False, True, False]) + + self.assertTrue( + tf.reduce_all(filled_ragged_tensor.to_tensor() == + expected_filled.to_tensor()).numpy(), + "Filled tensors do not match") + self.assertTrue( + tf.reduce_all(is_row_empty == expected_empty).numpy(), + "Empty row flags do not match") diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/shadow_embedding_ops_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/shadow_embedding_ops_test.py index 766e29df1..562398c48 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/shadow_embedding_ops_test.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/shadow_embedding_ops_test.py @@ -154,6 +154,29 @@ def test_lookup(self): emb = self.evaluate(de.shadow_ops.embedding_lookup(shadow_var, ext_ids)) self.assertAllEqual(exp_values, emb) + def test_safe_embedding_lookup_sparse(self): + if not context.executing_eagerly(): + self.skipTest('Only test in eager mode.') + with self.session(use_gpu=test_util.is_gpu_available(), + config=default_config): + var, shadow_var = _get_sparse_variable('tk049', dim=2) + self.evaluate(variables.global_variables_initializer()) + ids = constant_op.constant([2, 5], dtype=dtypes.int64) + values = array_ops.ones((2, 2), dtype=np.float32) + self.evaluate( + var.upsert(ids, ops.convert_to_tensor(values, dtype=dtypes.float32))) + + sp_ids = constant_op.constant([[0, 2], [1, 5]], dtype=dtypes.int64) + sp_weights = constant_op.constant([2, 5], dtype=dtypes.int64) + dense_shape = constant_op.constant([2, 6], dtype=dtypes.int64) + sparse_tensor = tf.sparse.SparseTensor(indices=sp_ids, + values=sp_weights, + dense_shape=dense_shape) + + emb = self.evaluate( + de.safe_embedding_lookup_sparse(shadow_var, sparse_tensor)) + self.assertAllEqual(emb, values) + def test_update_with_optimizer_v1(self): if not context.executing_eagerly(): self.skipTest('Only test when eagerly.') @@ -410,6 +433,30 @@ def test_embedding_lookup(self): constant_op.constant([[2.2, 2.2], [3.3, 3.3], [0.1, 0.1]], dtype=dtypes.float32)) + def test_embedding_lookup_unique(self): + if not context.executing_eagerly(): + self.skipTest('Only test in eager mode.') + + params = de.get_variable('pn012', dim=2, initializer=0.1) + params.upsert( + constant_op.constant([1, 2, 3], dtype=dtypes.int64), + constant_op.constant([[1., 1.], [2., 2.], [3., 3.]], + dtype=dtypes.float32)) + shadow = de.shadow_ops.ShadowVariable(params) + # [[2, 3], [4, 5, 1]] + ragged_ids = tf.RaggedTensor.from_row_splits(values=tf.constant( + [2, 3, 4, 5, 1], dtype=tf.int64), + row_splits=[0, 2, 5]) + val_ragged = de.shadow_ops.embedding_lookup_unique(shadow, ragged_ids, 2, + True) + expected_output = tf.RaggedTensor.from_row_splits(values=[[2., 2.], + [3., 3.], + [0.1, 0.1], + [0.1, 0.1], + [1., 1.]], + 
row_splits=[0, 2, 5]) + self.assertAllEqual(val_ragged, expected_output) + def test_get_size(self): if not context.executing_eagerly(): self.skipTest('Only test in eager mode.') diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_ops.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_ops.py index eedfa3aa9..8a4c7b68a 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_ops.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_ops.py @@ -61,7 +61,7 @@ from tensorflow.python.trackable import data_structures except: from tensorflow.python.training.tracking import data_structures -from tensorflow.python.util import compat +from tensorflow.python.util import compat, dispatch from tensorflow.python.util.tf_export import tf_export from tensorflow.python.keras.utils import tf_utils @@ -773,7 +773,7 @@ def embedding_lookup_sparse( Args: params: A single `dynamic_embedding.Variable` instance representing - the complete embedding tensor. + the complete embedding tensor or a `ShadowVariable` instance. sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size and M is arbitrary. sp_weights: either a `SparseTensor` of float / double weights, or `None` to @@ -855,15 +855,22 @@ def embedding_lookup_sparse( ids = sp_ids.values ids, idx = array_ops.unique(ids) - - embeddings, trainable_ = embedding_lookup( - params, - ids, - name=name + '/embedding_lookup', - partition_strategy=partition_strategy, - max_norm=max_norm, - return_trainable=True, - ) + if isinstance(params, de.shadow_ops.ShadowVariable): + embeddings = de.shadow_ops.embedding_lookup( + params, + ids, + name=name + '/embedding_lookup', + ) + trainable_ = params + else: + embeddings, trainable_ = embedding_lookup( + params, + ids, + name=name + '/embedding_lookup', + partition_strategy=partition_strategy, + max_norm=max_norm, + return_trainable=True, + ) if embeddings.dtype in (dtypes.float16, dtypes.bfloat16): embeddings = math_ops.cast(embeddings, dtypes.float32) if not ignore_weights: @@ -928,6 +935,24 @@ def embedding_lookup_sparse( return (embeddings, trainable_) if return_trainable else embeddings +def verify_embedding_weights(embedding_weights, + sparse_ids, + sparse_weights=None): + if embedding_weights is None: + raise ValueError("Missing embedding_weights %s." % embedding_weights) + + if embedding_weights.key_dtype != sparse_ids.dtype: + raise TypeError( + "embedding_weights.key_dtype should be same with sparse_ids.dtype: " + "{} vs. {}".format(embedding_weights.key_dtype, sparse_ids.dtype)) + + weights_dtype = sparse_weights.dtype if sparse_weights is not None else None + if weights_dtype and embedding_weights.value_dtype != weights_dtype: + raise TypeError( + "embedding_weights.value_dtype should be same with sparse_weights.dtype" + ": {} vs. {}".format(embedding_weights.value_dtype, weights_dtype)) + + def safe_embedding_lookup_sparse( embedding_weights, sparse_ids, @@ -953,7 +978,7 @@ def safe_embedding_lookup_sparse( Args: embedding_weights: A single `dynamic_embedding.Variable` instance - representing the complete embedding tensor. + representing the complete embedding tensor or a single `ShadowVariable` instance. sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the ids. `d_0` is typically batch size. 
sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing @@ -980,22 +1005,15 @@ def safe_embedding_lookup_sparse( Raises: ValueError: if `embedding_weights` is empty. """ - if embedding_weights is None: - raise ValueError("Missing embedding_weights %s." % embedding_weights) - - if embedding_weights.key_dtype != sparse_ids.dtype: - raise TypeError( - "embedding_weights.key_dtype should be same with sparse_ids.dtype: " - "{} vs. {}".format(embedding_weights.key_dtype, sparse_ids.dtype)) - - weights_dtype = sparse_weights.dtype if sparse_weights is not None else None - if weights_dtype and embedding_weights.value_dtype != weights_dtype: - raise TypeError( - "embedding_weights.value_dtype should be same with sparse_weights.dtype" - ": {} vs. {}".format(embedding_weights.value_dtype, weights_dtype)) + if isinstance(embedding_weights, de.shadow_ops.ShadowVariable): + verify_embedding_weights(embedding_weights.params, sparse_ids, + sparse_weights) + else: + verify_embedding_weights(embedding_weights, sparse_ids, sparse_weights) scope = variable_scope.get_variable_scope() full_name = scope.name + "/" + name if scope.name else name + with ops.name_scope(full_name + "/"): # Reshape higher-rank sparse ids and weights to linear segment ids. original_shape = sparse_ids.dense_shape diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/ragged_embedding_ops.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/ragged_embedding_ops.py new file mode 100644 index 000000000..4ecf76f29 --- /dev/null +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/ragged_embedding_ops.py @@ -0,0 +1,509 @@ +import tensorflow as tf +from tensorflow.python.framework import dtypes, ops +from tensorflow.python.ops import resource_variable_ops, array_ops, math_ops, gen_ragged_array_ops, gen_math_ops +from tensorflow.python.ops.bincount_ops import validate_dense_weights +from tensorflow.python.ops.ragged import ragged_tensor, ragged_array_ops +from tensorflow_recommenders_addons.dynamic_embedding.python.ops.shadow_embedding_ops import ShadowVariable + +from tensorflow_recommenders_addons import dynamic_embedding as de +from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_ops import verify_embedding_weights + + +def _bincount(arr, + weights=None, + minlength=None, + maxlength=None, + dtype=dtypes.int32, + name=None, + axis=None, + binary_output=False): + + name = "bincount" if name is None else name + with ops.name_scope(name): + arr = tf.convert_to_tensor(arr, name="arr") + if weights is not None: + weights = tf.convert_to_tensor(weights, name="weights") + + if weights is not None and binary_output: + raise ValueError("Arguments `binary_output` and `weights` are mutually " + "exclusive. Please specify only one.") + + if not arr.dtype.is_integer: + arr = math_ops.cast(arr, dtypes.int32) + if axis is None: + axis = 0 + + if axis not in [0, -1]: + raise ValueError(f"Unsupported value for argument axis={axis}. 
Only 0 and" + " -1 are currently supported.") + + array_is_nonempty = array_ops.size(arr) > 0 + output_size = math_ops.cast(array_is_nonempty, + arr.dtype) * (math_ops.reduce_max(arr) + 1) + if minlength is not None: + minlength = ops.convert_to_tensor(minlength, + name="minlength", + dtype=arr.dtype) + output_size = gen_math_ops.maximum(minlength, output_size) + if maxlength is not None: + maxlength = ops.convert_to_tensor(maxlength, + name="maxlength", + dtype=arr.dtype) + output_size = gen_math_ops.minimum(maxlength, output_size) + + if axis == 0: + if weights is not None: + weights = array_ops.reshape(weights, [-1]) + arr = array_ops.reshape(arr, [-1]) + + weights = validate_dense_weights(arr, weights, dtype) + return gen_math_ops.dense_bincount(input=arr, + size=output_size, + weights=weights, + binary_output=binary_output) + + +# # for compatibility with tf 2.11 +def _ragged_fill_empty_rows(value_rowids, values, nrows, default_value): + # Convert default_value to the correct dtype + default_value = tf.convert_to_tensor(default_value, dtype=values.dtype) + + # Determine the total number of rows and the maximum row index in value_rowids + max_row_index = tf.reduce_max(value_rowids) + total_rows = tf.maximum(nrows, max_row_index + 1) + + # Create a tensor of row lengths + row_lengths = _bincount(value_rowids, + minlength=total_rows, + maxlength=total_rows, + dtype=value_rowids.dtype) + + # Identify empty rows + empty_row_indicator = tf.equal(row_lengths, 0) + + # Generate default values for empty rows + num_empty_rows = tf.reduce_sum(tf.cast(empty_row_indicator, tf.int32)) + default_values = tf.fill([num_empty_rows], default_value) + + # Create new value_rowids for empty rows + empty_rows = tf.where(empty_row_indicator) + new_value_rowids = tf.repeat(empty_rows, repeats=1) + + # Concatenate original and default values and row ids + final_values = tf.concat([values, default_values], axis=0) + final_value_rowids = tf.concat([value_rowids, new_value_rowids], axis=0) + + # Sort by rowids to maintain ragged tensor structure + sorted_indices = tf.argsort(final_value_rowids) + sorted_values = tf.gather(final_values, sorted_indices) + sorted_value_rowids = tf.gather(final_value_rowids, sorted_indices) + + return sorted_value_rowids, sorted_values, empty_row_indicator + + +# # for compatibility with tf 2.11 +def _fill_empty_rows(ragged_input, default_value, name=None): + try: + # if ragged_array_ops.fill_empty_rows is available, use it + return ragged_array_ops.fill_empty_rows(ragged_input, + default_value, + name=name) + except AttributeError: + if not isinstance(ragged_input, tf.RaggedTensor): + raise TypeError("ragged_input must be a RaggedTensor, got %s" % + type(ragged_input)) + default_value_tensor = tf.convert_to_tensor(default_value, + dtype=ragged_input.dtype) + + output_value_rowids, output_values, empty_row_indicator = _ragged_fill_empty_rows( + ragged_input.value_rowids(), ragged_input.values, ragged_input.nrows(), + default_value_tensor) + + ragged_ordered_output = tf.RaggedTensor.from_value_rowids( + values=output_values, + value_rowids=output_value_rowids, + nrows=ragged_input.nrows(), + validate=False) + return ragged_ordered_output, empty_row_indicator + + +def _embedding_lookup_sparse_impl( + params, + segment_ids, + sp_weights, + ids, + combiner, + ignore_weights, + name, +): + """Implementation of sparse embedding aggregation.""" + need_sparse_segment_gradient = False + # Ensure we can query the devices below. 
+ segment_ids = ops.convert_to_tensor(segment_ids, name="segment_ids") + + ids, idx = array_ops.unique(ids) + if isinstance(params, de.shadow_ops.ShadowVariable): + embeddings = de.shadow_ops.embedding_lookup(params, ids) + else: + embeddings = de.embedding_lookup(params, ids) + + if not ignore_weights: + if segment_ids.dtype != dtypes.int32: + segment_ids = math_ops.cast(segment_ids, dtypes.int32) + + weights = sp_weights.values + embeddings = array_ops.gather(embeddings, idx) + + original_dtype = embeddings.dtype + if embeddings.dtype in (dtypes.float16, dtypes.bfloat16): + # Cast low-precision embeddings to float32 during the computation to + # avoid numerical issues. + embeddings = math_ops.cast(embeddings, dtypes.float32) + if weights.dtype != embeddings.dtype: + weights = math_ops.cast(weights, embeddings.dtype) + + # Reshape weights to allow broadcast + ones_shape = array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0) + ones = array_ops.ones(ones_shape, dtype=dtypes.int32) + bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones], 0) + + orig_weights_shape = weights.get_shape() + weights = array_ops.reshape(weights, bcast_weights_shape) + + # Set the weight shape, since after reshaping to bcast_weights_shape, + # the shape becomes None. + if embeddings.get_shape().ndims is not None: + weights.set_shape( + orig_weights_shape.concatenate( + [1 for _ in range(embeddings.get_shape().ndims - 1)])) + + embeddings *= weights + + if combiner == "sum": + embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name) + elif combiner == "mean": + embeddings = math_ops.segment_sum(embeddings, segment_ids) + weight_sum = math_ops.segment_sum(weights, segment_ids) + embeddings = math_ops.div_no_nan(embeddings, weight_sum, name=name) + elif combiner == "sqrtn": + embeddings = math_ops.segment_sum(embeddings, segment_ids) + weights_squared = math_ops.pow(weights, 2) + weight_sum = math_ops.segment_sum(weights_squared, segment_ids) + weight_sum_sqrt = math_ops.sqrt(weight_sum) + embeddings = math_ops.div_no_nan(embeddings, weight_sum_sqrt, name=name) + else: + assert False, "Unrecognized combiner" + if embeddings.dtype != original_dtype: + embeddings = math_ops.cast(embeddings, original_dtype) + else: + if segment_ids.dtype not in (dtypes.int32, dtypes.int64): + segment_ids = math_ops.cast(segment_ids, dtypes.int32) + assert idx is not None + if combiner == "sum": + embeddings = math_ops.sparse_segment_sum( + embeddings, + idx, + segment_ids, + name=name, + sparse_gradient=need_sparse_segment_gradient, + ) + elif combiner == "mean": + embeddings = math_ops.sparse_segment_mean( + embeddings, + idx, + segment_ids, + name=name, + sparse_gradient=need_sparse_segment_gradient, + ) + elif combiner == "sqrtn": + embeddings = math_ops.sparse_segment_sqrt_n( + embeddings, + idx, + segment_ids, + name=name, + sparse_gradient=need_sparse_segment_gradient, + ) + else: + assert False, "Unrecognized combiner" + + return embeddings + + +def embedding_lookup_sparse( + params: ShadowVariable, + sp_ids: ragged_tensor.Ragged, + sp_weights, + partition_strategy=None, # no used + name="embedding_lookup_sparse", + combiner="mean", + max_norm=None, # no used + return_trainable=False, # no used +): + """Looks up embeddings for the given ids and weights from a list of tensors. + + This op assumes that there is at least one id for each row in the dense tensor + represented by sp_ids (i.e. there are no rows with empty features), and that + all the indices of sp_ids are in canonical row-major order. 
+ + `sp_ids` and `sp_weights` (if not None) are `RaggedTensor`s with rank of 2. + Embeddings are always aggregated along the last dimension. + + It also assumes that all id values lie in the range [0, p0), where p0 + is the sum of the size of params along dimension 0. + + Args: + params: A single tensor representing the complete embedding tensor, or a + list tensors all of same shape except for the first dimension, + representing sharded embedding tensors. Alternatively, a + `PartitionedVariable`, created by partitioning along dimension 0. Each + element must be appropriately sized for the given `partition_strategy`. + sp_ids: `RaggedTensor` with rank 2. The rank is not verified for performance + reasons. + sparse_weights: `RaggedTensor` of same type and shape as `sparse_ids`, + containing float / double weights corresponding to `sparse_ids`, or `None` + if all weights are assumed to be 1.0. + partition_strategy: A string specifying the partitioning strategy, relevant + if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default + is `"mod"`. See `tf.nn.embedding_lookup` for more details. + name: Optional name for the op. + combiner: A string specifying the reduction op. Currently "mean", "sqrtn" + and "sum" are supported. "sum" computes the weighted sum of the embedding + results for each row. "mean" is the weighted sum divided by the total + weight. "sqrtn" is the weighted sum divided by the square root of the sum + of the squares of the weights. Defaults to `mean`. + max_norm: If not `None`, each embedding is clipped if its l2-norm is larger + than this value, before combining. + allow_fast_lookup: An optional boolean specifying whether to allow + simplified embedding lookups when `params` is a single tensor and + `max_norm` is `None`. Setting this flag to `True` during training can + cause the use of dense gradients with increased memory footprint. + + Returns: + A dense tensor representing the combined embeddings for the + sparse ids. For each row in the dense tensor represented by `sp_ids`, the op + looks up the embeddings for all ids in that row, multiplies them by the + corresponding weight, and combines these embeddings as specified. + + In other words, if + + `shape(combined params) = [p0, p1, ..., pm]` + + and + + `shape(sp_ids) = shape(sp_weights) = [d0, d1]` + + then + + `shape(output) = [d0, p1, ..., pm]`. + + For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are + + ```python + [0, 0]: id 1, weight 2.0 + [0, 1]: id 3, weight 0.5 + [1, 0]: id 0, weight 1.0 + [2, 3]: id 1, weight 3.0 + ``` + + with `combiner`="mean", then the output will be a 3x20 matrix where + + ```python + output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5) + output[1, :] = (params[0, :] * 1.0) / 1.0 + output[2, :] = (params[1, :] * 3.0) / 3.0 + ``` + + Raises: + TypeError: If `sp_weights` is neither `None` nor of the same type as + `sp_ids`. + ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}. 
+ """ + rt_ids = sp_ids + rt_weights = sp_weights + if combiner is None: + combiner = "mean" + if combiner not in ("mean", "sqrtn", "sum"): + raise ValueError( + f"combiner must be one of 'mean', 'sqrtn' or 'sum', got {combiner}") + + ignore_weights = rt_weights is None + if not ignore_weights: + if not isinstance(rt_weights, ragged_tensor.RaggedTensor): + raise TypeError(f"sp_ids must be of the same type as sp_weights, " + f"received {{type(sp_ids).__name__!r}} for sp_ids and " + f"{{type(sp_weights).__name__!r}} for sp_weights.") + rt_ids.values.get_shape().assert_is_compatible_with( + rt_weights.values.get_shape()) + rt_ids.get_shape().assert_is_compatible_with(rt_weights.get_shape()) + # + with ops.name_scope(name, "embedding_lookup_sparse") as name: + segment_ids = rt_ids.value_rowids() + ids = rt_ids.flat_values + return _embedding_lookup_sparse_impl( + params, + segment_ids, + sp_weights, + ids, + combiner, + ignore_weights, + name, + ) + + +def safe_embedding_lookup_sparse( + embedding_weights, + sparse_ids: ragged_tensor.Ragged, + sparse_weights=None, + combiner="mean", + default_id=None, + name=None, +): + """Lookup embedding results, accounting for invalid IDs and empty features. + + The partitioned embedding in `embedding_weights` must all be the same shape + except for the first dimension. The first dimension is allowed to vary as the + vocabulary size is not necessarily a multiple of `P`. `embedding_weights` + may be a `PartitionedVariable` as returned by using + `tf.compat.v1.get_variable()` with a + partitioner. + + Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs + with non-positive weight. For an entry with no features, the embedding vector + for `default_id` is returned, or the 0-vector if `default_id` is not supplied. + + The ids and weights may be multi-dimensional `SparseTensor`s or + `RaggedTensor`s with rank of 2. For `SpareTensor`s with left-aligned non-zero + entries which can be described as `RaggedTensor`s, use of `RaggedTensor`s can + yield higher performance. Embeddings are always aggregated along the last + dimension. + + Args: + embedding_weights: A single tensor representing the complete embedding + tensor, or a list tensors all of same shape except for the first + dimension, representing sharded embedding tensors. Alternatively, a + `PartitionedVariable`, created by partitioning along dimension 0. Each + element must be appropriately sized for the given `partition_strategy`. + sp_ids: `RaggedTensor` with rank 2. The rank is not verified for performance + reasons. + sparse_weights: `RaggedTensor` of same type and shape as `sparse_ids`, + containing float weights corresponding to `sparse_ids`, or `None` if all + weights are assumed to be 1.0. + combiner: A string specifying how to combine embedding results for each + entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the + default. + default_id: The id to use for an entry with no features. + name: A name for this operation (optional). + partition_strategy: A string specifying the partitioning strategy. Currently + `"div"` and `"mod"` are supported. Default is `"div"`. + max_norm: If not `None`, all embeddings are l2-normalized to max_norm before + combining. + allow_fast_lookup: An optional boolean specifying whether to allow + simplified embedding lookups when `params` is a single tensor and + `max_norm` is `None`. Setting this flag to `True` during training can + cause the use of dense gradients with increased memory footprint. 
+ + Returns: + A dense tensor representing the combined embeddings for the + sparse ids. For each row in the dense tensor represented by `sp_ids`, the op + looks up the embeddings for all ids in that row, multiplies them by the + corresponding weight, and combines these embeddings as specified. + + In other words, if + + `shape(combined embedding_weights) = [p0, p1, ..., pm]` + + and + + `shape(sparse_ids) = shape(sparse_weights) = [d0, d1, ..., dn]` + + then + + `shape(output) = [d0, d1, ... dn-1, p1, ..., pm]`. + + For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are + + ```python + [0, 0]: id 1, weight 2.0 + [0, 1]: id 3, weight 0.5 + [1, 0]: id -1, weight 1.0 + [2, 3]: id 1, weight 3.0 + ``` + + `default_id` is 0. + + with `combiner`="mean", then the output will be a 3x20 matrix where + + ```python + output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5) + output[1, :] = (params[0, :] * 1.0) / 1.0 + output[2, :] = (params[1, :] * 3.0) / 3.0 + ``` + + Raises: + ValueError: if `embedding_weights` is empty. + """ + ragged_ids = sparse_ids + ragged_weights = sparse_weights + if isinstance(embedding_weights, de.shadow_ops.ShadowVariable): + verify_embedding_weights(embedding_weights.params, sparse_ids, + sparse_weights) + else: + verify_embedding_weights(embedding_weights, sparse_ids, sparse_weights) + with ops.name_scope(name, "embedding_lookup", + [ragged_ids, ragged_weights]) as scope: + + if combiner != "sum": + ragged_ids, ragged_weights = _prune_invalid_weights_ragged( + ragged_ids, ragged_weights) + ragged_ids, is_row_empty = _fill_empty_rows(ragged_ids, default_id or 0) + if ragged_weights is not None: + ragged_weights, _ = _fill_empty_rows(ragged_weights, 1.0) + + result = embedding_lookup_sparse( + embedding_weights, + ragged_ids, + ragged_weights, + combiner=combiner, + name=None if default_id is None else scope, + ) + + if default_id is None: + # Broadcast is_row_empty to the same shape as embedding_lookup_result, + # for use in Select. + is_row_empty = array_ops.tile( + array_ops.reshape(is_row_empty, [-1, 1]), + tf.stack([1, array_ops.shape(result)[1]]), + ) + + result = array_ops.where(is_row_empty, + array_ops.zeros_like(result), + result, + name=scope) + + return result + + +def _prune_invalid_weights_ragged(ids, weights): + """Prune invalid weights (< 0) from the input ids and weights.""" + if weights is not None: + is_weights_valid = math_ops.greater(weights.values, 0) + nrows = ids.nrows() + # TODO(philipphack): Consider calling ragged_array_ops.boolean_mask once the + # resulting performance is comparable to array_ops.boolean_mask. Currently, + # ragged_array_ops.boolean_mask constructs the returned RaggedTensor by + # calling its from_row_splits method which does not set value_row_ids and + # requires it to be computed on demand. 
+ pruned_values = array_ops.boolean_mask_v2(ids.values, is_weights_valid) + pruned_value_rowids = array_ops.boolean_mask_v2(ids.value_rowids(), + is_weights_valid) + ids = ragged_tensor.RaggedTensor.from_value_rowids(pruned_values, + pruned_value_rowids, + nrows=nrows, + validate=False) + + pruned_weights_values = array_ops.boolean_mask_v2(weights.values, + is_weights_valid) + weights = ragged_tensor.RaggedTensor.from_value_rowids( + pruned_weights_values, pruned_value_rowids, nrows=nrows, validate=False) + + return ids, weights diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/shadow_embedding_ops.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/shadow_embedding_ops.py index 047507506..b5e1e48f2 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/shadow_embedding_ops.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/shadow_embedding_ops.py @@ -33,7 +33,7 @@ and modular style development, like keras. """ -import functools +import tensorflow as tf from tensorflow_recommenders_addons import dynamic_embedding as de from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_ops import DEResourceVariable @@ -44,7 +44,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import resource_variable_ops try: # tf version >= 2.10.0 @@ -53,7 +52,6 @@ from tensorflow.python.training.tracking import base as trackable from tensorflow.python.distribute import distribute_utils -from tensorflow.python.distribute import values_util as distribute_values_util class ShadowVariable(de.TrainableWrapper): @@ -222,7 +220,7 @@ def _gather_saveables_for_checkpoint(self): def embedding_lookup( - shadow, + shadow: ShadowVariable, ids, partition_strategy=None, # pylint: disable=unused-argument name=None, @@ -259,6 +257,55 @@ def embedding_lookup( with ops.colocate_with(None, ignore_existing=True): if de.ModelMode.CURRENT_SETTING == de.ModelMode.TRAIN: with ops.control_dependencies([shadow_._reset_ids(ids)]): - return shadow_.read_value(do_prefetch=True) + result = shadow_.read_value(do_prefetch=True) else: - return shadow_.params.lookup(ids) + result = shadow_.params.lookup(ids) + + return result + + +def embedding_lookup_unique( + shadow, + ids, + embedding_size, + with_unique=True, + name=None, +): + """ + unify version of embedding_lookup. It handles ragged tensor, unique and shape. No by-product will + be introduced in this call. So it can be decorated by `tf.function`. + + Args: + shadow: A ShadowVariable object. + ids: A tensor with any shape as same dtype of params.key_dtype. + embedding_size: The size of embedding, used in shape the output + with_unique: If True, it will use unique ids to lookup embedding. + name: A name for the operation. + + Returns: + A tensor with shape [shape of ids] + [embedding_size], + containing the values from the params tensor(s) for keys in ids. 
+ """ + is_ragged = isinstance(ids, tf.RaggedTensor) + + if is_ragged: + original_structure = ids + ids = ids.flat_values + else: + ids = tf.convert_to_tensor(ids) + input_shape = tf.shape(ids) + embeddings_shape = tf.concat([input_shape, [embedding_size]], 0) + ids_flat = tf.reshape(ids, (-1,)) + if with_unique: + with ops.name_scope(name, "EmbeddingWithUnique"): + unique_ids, idx = tf.unique(ids_flat) + unique_embeddings = embedding_lookup(shadow, unique_ids) + embeddings_flat = tf.gather(unique_embeddings, idx) + else: + embeddings_flat = embedding_lookup(shadow, ids_flat) + embeddings = tf.reshape(embeddings_flat, embeddings_shape) + + if is_ragged: + embeddings = tf.RaggedTensor.from_row_lengths( + embeddings, original_structure.row_lengths()) + return embeddings