diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_ops_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_ops_test.py
index a777551b3..f1280dba6 100644
--- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_ops_test.py
+++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_ops_test.py
@@ -751,7 +751,6 @@ def test_embedding_lookup_unique(self):
     np.testing.assert_almost_equal(embedded_np, embedded_de)
 
 
-@test_util.run_all_in_graph_and_eager_modes
 class EmbeddingLookupSparseTest(test.TestCase, parameterized.TestCase):
 
   def _random_ids_and_weights(self,
@@ -801,6 +800,7 @@ def _group_by_batch_entry(self, vals, vals_per_batch_entry):
       index += num_val
     return grouped_vals
 
+  @test_util.run_all_in_graph_and_eager_modes
   @parameterized.parameters(itertools.product([True, False]))
   def test_embedding_lookup_sparse(self, ragged):
     var_id = 0
@@ -863,8 +863,6 @@ def test_embedding_lookup_sparse(self, ragged):
           ) else random_init.eval()
           grouped_params = self._group_by_batch_entry(np_params,
                                                       vals_per_batch_entry)
-          if context.executing_eagerly():
-            params = de.shadow_ops.ShadowVariable(params)
           if ragged:
             embedding_sum = embedding_lookup_sparse(
                 params,
@@ -900,6 +898,7 @@ def test_embedding_lookup_sparse(self, ragged):
             atol = rtol
           self.assertAllClose(np_embedding_sum, tf_embedding_sum, rtol, atol)
 
+  @test_util.run_all_in_graph_and_eager_modes
   def test_embedding_lookup_sparse_shape_checking(self):
     if context.executing_eagerly():
       self.skipTest("Skip eager test")
@@ -920,7 +919,6 @@ def test_embedding_lookup_sparse_shape_checking(self):
                           embedding_lookup_test.get_shape().as_list())
 
 
-@test_util.run_all_in_graph_and_eager_modes
class SafeEmbeddingLookupSparseTest(test.TestCase, parameterized.TestCase):
 
   def _get_ids_and_weights_3d(self, valid_ids):
@@ -932,10 +930,9 @@ def _get_ids_and_weights_3d(self, valid_ids):
     embedding_weights_values = embedding_weights_values.numpy(
     ) if context.executing_eagerly() else embedding_weights_values.eval()
     self.evaluate(embedding_weights.upsert(valid_ids, embedding_weights_values))
-    if context.executing_eagerly():
-      embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights)
     return embedding_weights, embedding_weights_values, sparse_ids, sparse_weights
 
+  @test_util.run_all_in_graph_and_eager_modes
   @parameterized.parameters(itertools.product([True, False]))
   def test_safe_embedding_lookup_sparse_return_zero_vector(self, ragged=False):
     with self.session(use_gpu=test_util.is_gpu_available(),
@@ -958,9 +955,6 @@ def test_safe_embedding_lookup_sparse_return_zero_vector(self, ragged=False):
       self.evaluate(
           embedding_weights.upsert(valid_ids, embedding_weights_values))
 
-      # check
-      if context.executing_eagerly():
-        embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights)
       if ragged:
         embedding_lookup_result = safe_embedding_lookup_sparse(
             embedding_weights, sparse_ids, sparse_weights)
@@ -983,6 +977,7 @@ def test_safe_embedding_lookup_sparse_return_zero_vector(self, ragged=False):
         ],
     )
 
+  @test_util.run_all_in_graph_and_eager_modes
   @parameterized.parameters(itertools.product([True, False]))
   def test_safe_embedding_lookup_sparse_return_special_vector(
       self, ragged=False):
@@ -1000,10 +995,6 @@ def test_safe_embedding_lookup_sparse_return_special_vector(
       ) else weights.eval()
       self.evaluate(
           embedding_weights.upsert(valid_ids, embedding_weights_values))
-
-      # check
-      if context.executing_eagerly():
-        embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights)
       if ragged:
         embedding_lookup_result = safe_embedding_lookup_sparse(
             embedding_weights, sparse_ids, sparse_weights, default_id=3)
@@ -1025,6 +1016,7 @@ def test_safe_embedding_lookup_sparse_return_special_vector(
         ],
     )
 
+  @test_util.run_all_in_graph_and_eager_modes
   @parameterized.parameters(itertools.product([True, False]))
   def test_safe_embedding_lookup_sparse_no_weights(self, ragged=False):
     with self.session(use_gpu=test_util.is_gpu_available(),
@@ -1041,9 +1033,6 @@ def test_safe_embedding_lookup_sparse_no_weights(self, ragged=False):
       ) else weights.eval()
       self.evaluate(
           embedding_weights.upsert(valid_ids, embedding_weights_values))
-
-      if context.executing_eagerly():
-        embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights)
       if ragged:
         embedding_lookup_result = safe_embedding_lookup_sparse(
             embedding_weights, sparse_ids, None)
@@ -1065,6 +1054,7 @@ def test_safe_embedding_lookup_sparse_no_weights(self, ragged=False):
         ],
     )
 
+  @test_util.run_all_in_graph_and_eager_modes
   @parameterized.parameters(itertools.product([True, False]))
   def test_safe_embedding_lookup_sparse_partitioned(self, ragged=False):
     with self.session(use_gpu=test_util.is_gpu_available(),
@@ -1081,9 +1071,6 @@ def test_safe_embedding_lookup_sparse_partitioned(self, ragged=False):
       ) else weights.eval()
       self.evaluate(
           embedding_weights.upsert(valid_ids, embedding_weights_values))
-
-      if context.executing_eagerly():
-        embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights)
       if ragged:
         embedding_lookup_result = safe_embedding_lookup_sparse(
             embedding_weights, sparse_ids, None)
@@ -1105,6 +1092,7 @@ def test_safe_embedding_lookup_sparse_partitioned(self, ragged=False):
         ],
     )
 
+  @test_util.run_all_in_graph_and_eager_modes
   @parameterized.parameters(itertools.product([True, False]))
   def test_safe_embedding_lookup_sparse_inconsistent_ids_type(
       self, ragged=False):
@@ -1115,8 +1103,6 @@ def fn():
       embedding_weights = _random_weights(num_shards=3, key_dtype=dtypes.int32)
       sparse_ids, sparse_weights = _ids_and_weights_2d(ragged=ragged)
 
-      if context.executing_eagerly():
-        embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights)
       if ragged:
         safe_embedding_lookup_sparse(embedding_weights, sparse_ids,
                                      sparse_weights)
@@ -1126,6 +1112,7 @@ def fn():
 
     self.assertRaises(TypeError, fn)
 
+  @test_util.run_all_in_graph_and_eager_modes
   @parameterized.parameters(itertools.product([True, False]))
   def test_safe_embedding_lookup_sparse_inconsistent_weights_type(
       self, ragged=False):
@@ -1135,8 +1122,6 @@ def test_safe_embedding_lookup_sparse_inconsistent_weights_type(
     def fn():
       embedding_weights = _random_weights(num_shards=3, key_dtype=dtypes.half)
       sparse_ids, sparse_weights = _ids_and_weights_2d(ragged=ragged)
-      if context.executing_eagerly():
-        embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights)
       if ragged:
         safe_embedding_lookup_sparse(embedding_weights, sparse_ids,
                                      sparse_weights)
@@ -1146,6 +1131,7 @@ def fn():
 
     self.assertRaises(TypeError, fn)
 
+  @test_util.run_all_in_graph_and_eager_modes
   def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
     with self.session(use_gpu=test_util.is_gpu_available(),
                       config=default_config):
@@ -1172,6 +1158,7 @@ def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
         ],
     )
 
+  @test_util.run_all_in_graph_and_eager_modes
   def test_safe_embedding_lookup_sparse_3d_return_special_vector(self):
     with self.session(use_gpu=test_util.is_gpu_available(),
                       config=default_config):
@@ -1199,6 +1186,7 @@ def test_safe_embedding_lookup_sparse_3d_return_special_vector(self):
         ],
     )
 
+  @test_util.run_all_in_graph_and_eager_modes
   def test_safe_embedding_lookup_sparse_3d_no_weights(self):
     with self.session(use_gpu=test_util.is_gpu_available(),
                       config=default_config):
@@ -1227,6 +1215,7 @@ def test_safe_embedding_lookup_sparse_3d_no_weights(self):
         ],
     )
 
+  @test_util.run_all_in_graph_and_eager_modes
   def test_safe_embedding_lookup_sparse_3d_partitioned(self):
     with self.session(use_gpu=test_util.is_gpu_available(),
                       config=default_config):
@@ -1240,8 +1229,6 @@ def test_safe_embedding_lookup_sparse_3d_partitioned(self):
       ) if context.executing_eagerly() else embedding_weights_values.eval()
       self.evaluate(
           embedding_weights.upsert(valid_ids, embedding_weights_values))
-      if context.executing_eagerly():
-        embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights)
       embedding_lookup_result = de.safe_embedding_lookup_sparse(
           embedding_weights, sparse_ids, None)
       embedding_lookup_result = embedding_lookup_result.numpy(
@@ -1265,6 +1252,7 @@ def test_safe_embedding_lookup_sparse_3d_partitioned(self):
         ],
     )
 
+  @test_util.run_all_in_graph_and_eager_modes
   def test_safe_embedding_lookup_sparse_with_initializer(self):
     id = 0
     embed_dim = 8
@@ -1315,9 +1303,6 @@ def test_safe_embedding_lookup_sparse_with_initializer(self):
           constant_op.constant(ids, dtypes.int64),
           constant_op.constant(dense_shape, dtypes.int64),
       )
-      if context.executing_eagerly():
-        embedding_weights = de.shadow_ops.ShadowVariable(embedding_weights)
-
       vals_op = de.safe_embedding_lookup_sparse(embedding_weights,
                                                 sparse_ids,
                                                 None,
@@ -1333,6 +1318,7 @@ def test_safe_embedding_lookup_sparse_with_initializer(self):
       self.assertAllClose(target_mean, mean, rtol, atol)
       self.assertAllClose(target_stddev, stddev, rtol, atol)
 
+  @test_util.run_all_in_graph_and_eager_modes
   def test_safe_embedding_lookup_sparse_shape_checking(self):
     if context.executing_eagerly():
       self.skipTest("Skip eager test")
@@ -1354,6 +1340,7 @@ def test_safe_embedding_lookup_sparse_shape_checking(self):
       self.assertAllEqual(embedding_lookup_base.get_shape(),
                           embedding_lookup_test.get_shape())
 
+  @test_util.run_all_in_graph_and_eager_modes
   def test_dynamic_embedding_variable_clear(self):
     with self.session(use_gpu=test_util.is_gpu_available(),
                       config=default_config):
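Note on the test changes above (annotation, not part of the patch): the run_all_in_graph_and_eager_modes decorator moves from the two test classes onto each test method, and the eager-only de.shadow_ops.ShadowVariable wrapping is dropped, so a de.Variable is now passed directly to the sparse lookup ops in both modes. Below is a minimal sketch of the calling pattern the updated tests exercise; the table name, dim, and ids are illustrative, not taken from the patch:

    import tensorflow as tf
    import tensorflow_recommenders_addons as tfra

    de = tfra.dynamic_embedding

    # A dynamic embedding table: int64 keys -> float32 vectors of dim 4.
    embedding_weights = de.get_variable(
        name="illustrative_table",
        key_dtype=tf.int64,
        value_dtype=tf.float32,
        dim=4,
        initializer=tf.keras.initializers.Zeros())

    # Batch of 3 rows; row 2 has no ids and should come back as zeros.
    sparse_ids = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                                 values=tf.constant([1, 2, 3], dtype=tf.int64),
                                 dense_shape=[3, 2])

    # With this patch, no de.shadow_ops.ShadowVariable(...) wrapping is
    # needed under eager execution before calling the lookup.
    result = de.safe_embedding_lookup_sparse(embedding_weights, sparse_ids,
                                             None)
    print(result.shape)  # (3, 4)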
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_ops.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_ops.py
index 8a4c7b68a..bcab3009a 100644
--- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_ops.py
+++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_ops.py
@@ -17,6 +17,7 @@
 Dynamic Embedding is designed for Large-scale Sparse Weights Training.
 See [Sparse Domain Isolation](https://github.com/tensorflow/community/pull/237)
 """
+from tensorflow.python.ops.variables import VariableAggregation
 
 from tensorflow_recommenders_addons import dynamic_embedding as de
 from tensorflow_recommenders_addons.utils.resource_loader import get_tf_version_triple
@@ -62,7 +63,6 @@
 except:
   from tensorflow.python.training.tracking import data_structures
 from tensorflow.python.util import compat, dispatch
-from tensorflow.python.util.tf_export import tf_export
 from tensorflow.python.keras.utils import tf_utils
 
 try:  # tf version >= 2.14.0
@@ -649,7 +649,7 @@ def _create_or_get_trainable(trainable_name):
 
   with ops.colocate_with(ids, ignore_existing=True):
     if distribute_ctx.has_strategy():
-      trainable_ = _distribute_trainable_store.get(name, None)
+      trainable_ = params._distribute_trainable_store.get(name, None)
       if trainable_ is None:
         strategy_devices = distribute_ctx.get_strategy(
         ).extended.worker_devices
@@ -752,7 +752,7 @@ def embedding_lookup_unique(params,
 
 
 def embedding_lookup_sparse(
-    params,
+    embedding_weights,
     sp_ids,
     sp_weights,
     partition_strategy=None,  # no used
@@ -772,7 +772,7 @@
     is the sum of the size of params along dimension 0.
 
   Args:
-    params: A single `dynamic_embedding.Variable` instance representing
+    embedding_weights: A single `dynamic_embedding.Variable` instance representing
       the complete embedding tensor or a `ShadowVariable` instance.
     sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size
       and M is arbitrary.
@@ -855,16 +855,16 @@
     ids = sp_ids.values
     ids, idx = array_ops.unique(ids)
 
-    if isinstance(params, de.shadow_ops.ShadowVariable):
+    if isinstance(embedding_weights, de.shadow_ops.ShadowVariable):
       embeddings = de.shadow_ops.embedding_lookup(
-          params,
+          embedding_weights,
          ids,
          name=name + '/embedding_lookup',
      )
-      trainable_ = params
+      trainable_ = embedding_weights
    else:
      embeddings, trainable_ = embedding_lookup(
-          params,
+          embedding_weights,
          ids,
          name=name + '/embedding_lookup',
          partition_strategy=partition_strategy,
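Note on dynamic_embedding_ops.py (annotation, not part of the patch): the first parameter of de.embedding_lookup_sparse is renamed from params to embedding_weights, matching both the docstring and safe_embedding_lookup_sparse. Positional callers are unaffected, but any caller passing params= by keyword would need updating. A brief usage sketch under the renamed signature; the table name, shapes, and values are illustrative:

    import tensorflow as tf
    import tensorflow_recommenders_addons as tfra

    de = tfra.dynamic_embedding

    embedding_weights = de.get_variable(name="illustrative_sparse_table",
                                        key_dtype=tf.int64,
                                        value_dtype=tf.float32,
                                        dim=8)

    # N x M SparseTensor of int64 ids with matching per-id weights.
    sp_ids = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                             values=tf.constant([10, 20, 30], dtype=tf.int64),
                             dense_shape=[2, 2])
    sp_weights = tf.SparseTensor(indices=sp_ids.indices,
                                 values=tf.constant([0.5, 1.5, 1.0]),
                                 dense_shape=sp_ids.dense_shape)

    # One combined embedding per batch row: weighted sum of looked-up rows.
    combined = de.embedding_lookup_sparse(embedding_weights, sp_ids,
                                          sp_weights, combiner="sum")
    print(combined.shape)  # (2, 8)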
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/ragged_embedding_ops.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/ragged_embedding_ops.py
index 2c6764c35..ed92d1349 100644
--- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/ragged_embedding_ops.py
+++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/ragged_embedding_ops.py
@@ -1,4 +1,5 @@
 import tensorflow as tf
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes, ops
 from tensorflow.python.ops import resource_variable_ops, array_ops, math_ops, gen_ragged_array_ops, gen_math_ops
 from tensorflow.python.ops.bincount_ops import validate_dense_weights
@@ -19,7 +20,7 @@ def _bincount(arr,
               binary_output=False):
 
   name = "bincount" if name is None else name
-  with ops.name_scope(name):
+  with tf.name_scope(name):
     arr = tf.convert_to_tensor(arr, name="arr")
     if weights is not None:
       weights = tf.convert_to_tensor(weights, name="weights")
@@ -144,7 +145,10 @@ def _embedding_lookup_sparse_impl(
   if isinstance(params, de.shadow_ops.ShadowVariable):
     embeddings = de.shadow_ops.embedding_lookup(params, ids)
   else:
-    embeddings = de.embedding_lookup(params, ids)
+    if context.executing_eagerly():
+      embeddings = de.embedding_lookup(params, ids, name=name)
+    else:
+      embeddings = de.embedding_lookup(params, ids)
 
   if not ignore_weights:
     if segment_ids.dtype != dtypes.int32:
@@ -314,8 +318,8 @@ def embedding_lookup_sparse(
     rt_ids.values.get_shape().assert_is_compatible_with(
         rt_weights.values.get_shape())
     rt_ids.get_shape().assert_is_compatible_with(rt_weights.get_shape())
-  #
-  with ops.name_scope(name, "embedding_lookup_sparse") as name:
+
+  with tf.name_scope(name or "embedding_lookup_sparse") as name:
     segment_ids = rt_ids.value_rowids()
     ids = rt_ids.flat_values
     return _embedding_lookup_sparse_impl(
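Note on ragged_embedding_ops.py (annotation, not part of the patch): _embedding_lookup_sparse_impl now forwards the name argument to de.embedding_lookup when executing eagerly, and the outer scope becomes tf.name_scope(name or "embedding_lookup_sparse"), which handles a None name uniformly in graph and eager modes. A sketch of the ragged lookup these ops serve, calling the module touched here and assuming its signature mirrors the dense overload; names and values are illustrative:

    import tensorflow as tf
    import tensorflow_recommenders_addons as tfra
    from tensorflow_recommenders_addons.dynamic_embedding.python.ops import (
        ragged_embedding_ops)

    de = tfra.dynamic_embedding

    embedding_weights = de.get_variable(name="illustrative_ragged_table",
                                        key_dtype=tf.int64,
                                        value_dtype=tf.float32,
                                        dim=4)

    # Ragged batch: row 0 holds two ids, row 1 holds one.
    rt_ids = tf.ragged.constant([[1, 2], [3]], dtype=tf.int64)

    # Each row reduces to a single embedding; sp_weights=None means
    # unweighted combining.
    emb = ragged_embedding_ops.embedding_lookup_sparse(embedding_weights,
                                                       rt_ids, None,
                                                       combiner="mean")
    print(emb.shape)  # (2, 4)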