From 8d68d791b6d5284100d92cfa97ae3b871dc64024 Mon Sep 17 00:00:00 2001 From: MoFHeka Date: Thu, 14 Dec 2023 10:11:37 +0800 Subject: [PATCH] [fix] Embedding call didn't return a single worker call function when hvd.size() is 0. --- .../dynamic_embedding/python/keras/layers/embedding.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py index ce7875149..47252ec6c 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py @@ -34,6 +34,7 @@ from tensorflow.python.framework import ops from tensorflow.python.eager import tape from tensorflow.python.ops.variables import VariableAggregation +from tensorflow.python.platform import tf_logging try: # The data_structures has been moved to the new package in tf 2.11 from tensorflow.python.trackable import data_structures except: @@ -41,6 +42,7 @@ from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_ops import DistributedVariableWrapper, TrainableWrapperDistributedPolicy from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_variable import make_partition +from tensorflow_recommenders_addons.dynamic_embedding.python.ops.tf_save_restore_patch import de_fs_saveable_class_names def _choose_reduce_method(combiner, sparse=False, segmented=False): @@ -245,6 +247,7 @@ def __init__(self, TrainableWrapperDistributedPolicy(VariableAggregation.NONE)) else: self.shadow = self.shadow_impl.as_list()[0] + self.params._created_in_class = self # To facilitate access to the primitive class through params super(Embedding, self).__init__(name=name, trainable=trainable, dtype=value_dtype) @@ -550,7 +553,11 @@ def __init__(self, else: self._mpi_size = mpi_size super(HvdAllToAllEmbedding, self).__init__(*args, **kwargs) - self.params._created_in_class = self + if type(self.params.saveable).__name__ not in de_fs_saveable_class_names: + tf_logging.warning( + "Please use FileSystemSaver in KVCreator when use HvdAllToAllEmbedding. " + "It will allow TFRA save and restore KV files when Embedding tensor parallel in distributed training. " + ) def __relocate_dense_feature__(self, ids, batch_size=None): """