
[fix] Embedding call didn't return a single worker call function when hvd.size() is 0.
MoFHeka committed Dec 14, 2023
1 parent 754e2e8 commit 8d68d79
Showing 1 changed file with 8 additions and 1 deletion.
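
Per the commit title, HvdAllToAllEmbedding's call path previously failed to fall back to a plain single-worker function when Horovod reported no peers. A hypothetical sketch of the intended gating, not TFRA's actual implementation; the helper names below are illustrative:

  # Sketch only: choose the lookup path from the Horovod world size.
  import horovod.tensorflow as hvd

  def call(self, ids):
    # The all-to-all exchange is only meaningful with multiple workers;
    # otherwise use the plain single-worker lookup.
    if not hvd.is_initialized() or hvd.size() <= 1:
      return self._single_worker_call(ids)  # illustrative helper name
    return self._all_to_all_call(ids)  # illustrative helper name
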
@@ -34,13 +34,15 @@
from tensorflow.python.framework import ops
from tensorflow.python.eager import tape
from tensorflow.python.ops.variables import VariableAggregation
from tensorflow.python.platform import tf_logging
try:  # The data_structures module moved to a new package in TF 2.11
  from tensorflow.python.trackable import data_structures
except ImportError:
  from tensorflow.python.training.tracking import data_structures

from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_ops import DistributedVariableWrapper, TrainableWrapperDistributedPolicy
from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_variable import make_partition
from tensorflow_recommenders_addons.dynamic_embedding.python.ops.tf_save_restore_patch import de_fs_saveable_class_names


def _choose_reduce_method(combiner, sparse=False, segmented=False):
@@ -245,6 +247,7 @@ def __init__(self,
          TrainableWrapperDistributedPolicy(VariableAggregation.NONE))
    else:
      self.shadow = self.shadow_impl.as_list()[0]
    self.params._created_in_class = self  # Expose the creating layer class through params
    super(Embedding, self).__init__(name=name,
                                    trainable=trainable,
                                    dtype=value_dtype)
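
The `_created_in_class` back-reference lets code that only holds the dynamic-embedding variable (`params`) recover the Keras layer that created it. A hedged illustration, assuming TFRA's Keras `Embedding` layer and its `params` attribute (the layer name is illustrative):

  import tensorflow_recommenders_addons as tfra

  layer = tfra.dynamic_embedding.keras.layers.Embedding(embedding_size=8,
                                                        name="user_emb")
  params = layer.params  # the underlying dynamic-embedding variable
  assert params._created_in_class is layer  # set in Embedding.__init__ above
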
@@ -550,7 +553,11 @@ def __init__(self,
    else:
      self._mpi_size = mpi_size
    super(HvdAllToAllEmbedding, self).__init__(*args, **kwargs)
    self.params._created_in_class = self
    if type(self.params.saveable).__name__ not in de_fs_saveable_class_names:
      tf_logging.warning(
          "Please use FileSystemSaver in KVCreator when using HvdAllToAllEmbedding. "
          "It allows TFRA to save and restore KV files when the Embedding is "
          "tensor-parallel in distributed training.")

  def __relocate_dense_feature__(self, ids, batch_size=None):
    """
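The new warning steers users toward a KVCreator backed by FileSystemSaver, which is what the save/restore patch recognizes through de_fs_saveable_class_names. A hedged usage sketch, assuming TFRA's CuckooHashTableCreator and FileSystemSaver signatures (the layer name is illustrative):

  import horovod.tensorflow as hvd
  import tensorflow_recommenders_addons as tfra

  de = tfra.dynamic_embedding
  emb = de.keras.layers.HvdAllToAllEmbedding(
      embedding_size=16,
      kv_creator=de.CuckooHashTableCreator(
          saver=de.FileSystemSaver(proc_size=hvd.size(), proc_rank=hvd.rank())),
      name="item_emb")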
