diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py index d83e4c2df..f6f974170 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py @@ -20,6 +20,7 @@ import itertools import os +import numpy as np import shutil import tensorflow as tf @@ -34,6 +35,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import adam +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import monitored_session from tensorflow.python.training.optimizer import Optimizer as tf1_opt from tensorflow.python.training import training_util @@ -97,6 +99,8 @@ def test_all_to_all_embedding_trainable(self): self.common_all_to_all_embedding_trainable_v2(keras_base_opt, keras_test_opt, name="keras_adam") + self.common_lazy_build_model_with_checkpoint_management_v2( + name="keras_adam_lazy_build") def common_minimize_trainable_v1(self, base_opt, test_opt, name): # TODO(rhdong): Recover the testing, if the horovod import error is fixed on macOS+TF2.7+. @@ -326,20 +330,23 @@ def common_all_to_all_embedding_trainable_v2(self, base_opt, test_opt, name): de.keras.models.de_hvd_save_model(base_model, save_dir, options=save_options) - ckpt = de.train.DEHvdCheckpoint(base_model) + ckpt = de.train.DECheckpoint( + my_model=base_model) # Test custom model key "my_model" ckpt.save(save_dir + '/ckpt/test') - tf.keras.backend.clear_session() del base_model + del base_opt + tf.keras.backend.clear_session() + new_opt = de.DynamicEmbeddingOptimizer(Adam(1.1), synchronous=True) new_base_model = get_emb_sequential_model( de.keras.layers.HvdAllToAllEmbedding, - base_opt, + new_opt, dense_init='ones', embedding_size=dim, initializer=init, bp_v2=False, kv_creator=kv_creator, name='all2all_emb') - ckpt = de.train.DEHvdCheckpoint(my_model=new_base_model) + ckpt = de.train.DECheckpoint(my_model=new_base_model) hvd.join() # Sync for avoiding files conflict ckpt.restore(tf.train.latest_checkpoint(save_dir + '/ckpt/')) new_a2aemb_size = new_base_model.layers[0].params.size() @@ -351,6 +358,165 @@ def common_all_to_all_embedding_trainable_v2(self, base_opt, test_opt, name): self.assertEqual(a2aemb_size, new_a2aemb_size) hvd.join() # Sync for avoiding files conflict + def common_lazy_build_model_with_checkpoint_management_v2(self, name): + # TODO(rhdong): Recover the testing, if the horovod import error is fixed on macOS+TF2.7+. + try: + import horovod.tensorflow as hvd + except (NotFoundError): + self.skipTest( + "Skip the test for horovod import error with Tensorflow-2.7.0 on MacOS-12." + ) + + tf.config.set_soft_device_placement(True) + + hvd.init() + + # These cases need 2 GPUs at least if available. + logical_devices = tf.config.list_logical_devices('GPU') + _device = "GPU" if len(logical_devices) >= hvd.size() else "CPU" + _device_id = hvd.local_rank( + ) if _device == "GPU" and len(logical_devices) >= 2 else 0 + + if _device == "GPU": + os.environ["CUDA_VISIBLE_DEVICES"] = str(_device_id) + + dim = 8 + + class NoCompileModel(tf.keras.models.Model): + + def __init__(self, init, dynamic=False): + super().__init__(dynamic=dynamic) + kv_creator = de.CuckooHashTableCreator(saver=de.FileSystemSaver( + proc_size=hvd.size(), proc_rank=hvd.rank())) + self.emb = de.keras.layers.HvdAllToAllEmbedding(embedding_size=dim, + devices=['/GPU:0'], + initializer=0, + kv_creator=kv_creator, + name=name) + self.l1 = tf.keras.layers.Dense(8, 'relu', kernel_initializer=init) + self.l2 = tf.keras.layers.Dense(1, 'sigmoid', kernel_initializer=init) + + def build(self, input_shape): + self.emb.build(input_shape) + self.l1.build(input_shape + dim) + self.l2.build(input_shape + 8) + + def call(self, x): + out = self.emb(x) + out = self.l1(out) + return self.l2(out) + + def check_TFRADynamicEmbedding_directory(save_dir, + save_it, + should_be_exist=True): + hvd_size = hvd.size() + if hvd_size <= 1: + hvd_size = 1 + for tag in ['keys', 'values']: + for rank in range(hvd_size): + self.assertTrue(not (os.path.exists( + save_dir + + f'/TFRADynamicEmbedding-{save_it}/{name}-parameter_mht_1of1_rank{rank}_size{hvd_size}-{tag}' + ) ^ should_be_exist)) + self.assertTrue(not (os.path.exists( + save_dir + + f'/TFRADynamicEmbedding-{save_it}/{name}-parameter_DynamicEmbedding_keras_adam_lazy_build-shadow_m_mht_1of1_rank{rank}_size{hvd_size}-{tag}' + ) ^ should_be_exist)) + # f'/TFRADynamicEmbedding-{save_it}/{name}-parameter_no_compile_model_DynamicEmbedding_keras_adam_lazy_build-shadow_m_mht_1of1_rank{rank}_size{hvd_size}-{tag}' + self.assertTrue(not (os.path.exists( + save_dir + + f'/TFRADynamicEmbedding-{save_it}/{name}-parameter_DynamicEmbedding_keras_adam_lazy_build-shadow_v_mht_1of1_rank{rank}_size{hvd_size}-{tag}' + ) ^ should_be_exist)) + # f'/TFRADynamicEmbedding-{save_it}/{name}-parameter_no_compile_model_DynamicEmbedding_keras_adam_lazy_build-shadow_v_mht_1of1_rank{rank}_size{hvd_size}-{tag}' + + with tf.device("/{}:{}".format(_device, _device_id)): + x = tf.reshape(tf.range(0, 32, dtype=tf.int64), [32, 1]) + y = tf.random.uniform(shape=[32, 1]) + + save_dir = self.get_temp_dir() + + model = NoCompileModel('ones') + base_opt = Adam(1.0) + base_opt = de.DynamicEmbeddingOptimizer(base_opt, synchronous=True) + ckpt = de.train.DECheckpoint(model=model, optimizer=base_opt) + model.compile(optimizer=base_opt, loss='mean_absolute_error') + manager = checkpoint_management.CheckpointManager(ckpt, + save_dir, + max_to_keep=1) + model.fit(x, y, verbose=0) + manager.save() + if hvd.rank() == 0: + check_TFRADynamicEmbedding_directory(save_dir, + save_it=1, + should_be_exist=True) + for l in model.layers: + if name in l.name: + l.params.upsert(x * 10, tf.random.uniform(shape=[32, 1, dim])) + emb_size = l.params.size() + emb_keys, emb_values = l.params.export() + break + for v in base_opt.variables(): + if name in v.name: + v.params.upsert(x * 10, tf.random.uniform(shape=[32, 1, dim])) + opt_size = v.params.size() + opt_keys, opt_values = l.params.export() + break + manager.save() + if hvd.rank() == 0: + check_TFRADynamicEmbedding_directory(save_dir, + save_it=2, + should_be_exist=True) + # CheckpointManager delete checkpoint after the write functuon, but DE KV checkpoint saving and deleting inside the write functuon. + # So DE KV checkpoint TFRADynamicEmbedding directory will be always one more than TF checkpoint file. + manager.save() + if hvd.rank() == 0: + check_TFRADynamicEmbedding_directory( + save_dir, save_it=1, should_be_exist=False + ) # Check delete TFRADynamicEmbedding directory properly. + + del base_opt + del model + del ckpt + tf.keras.backend.clear_session() + tf.compat.v1.reset_default_graph() + + new_model = NoCompileModel('zeros') + new_opt = Adam(1.1) + new_opt = de.DynamicEmbeddingOptimizer(new_opt, synchronous=True) + new_ckpt = de.train.DECheckpoint(model=new_model, optimizer=new_opt) + manager = checkpoint_management.CheckpointManager(new_ckpt, + save_dir, + max_to_keep=1) + manager.restore_or_initialize() + new_model.compile(optimizer=new_opt, loss='mean_absolute_error') + new_model(x) # Build vairiables + try: + new_opt._create_all_weights(new_model.variables) + except: + #TODO(MoFHejia) raise ValueError: Cannot convert a partially known TensorShape to a Tensor. + pass + for l in new_model.layers: + if name in l.name: + new_emb_size = l.params.size() + new_emb_keys, new_emb_values = l.params.export() + break + for v in new_opt.variables(): + if name in v.name: + new_opt_size = v.params.size() + new_opt_keys, new_opt_values = l.params.export() + break + + self.assertEqual(emb_size, new_emb_size) + self.assertEqual(opt_size, new_opt_size) + self.assertAllEqual(np.sort(emb_keys, axis=0), + np.sort(new_emb_keys, axis=0)) + self.assertAllClose(np.sort(emb_values, axis=0), + np.sort(new_emb_values, axis=0)) + self.assertAllEqual(np.sort(opt_keys, axis=0), + np.sort(new_opt_keys, axis=0)) + self.assertAllClose(np.sort(opt_values, axis=0), + np.sort(new_opt_values, axis=0)) + if __name__ == "__main__": test.main() diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/shadow_embedding_ops_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/shadow_embedding_ops_test.py index ce17eb627..f980fc570 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/shadow_embedding_ops_test.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/shadow_embedding_ops_test.py @@ -560,16 +560,16 @@ def size(self): model_dir = tempfile.mkdtemp(prefix=self.get_temp_dir()) save_ckpt_dir = os.path.join(model_dir, 'model') - restore_ckpt_path = os.path.join(model_dir, 'model-1') options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA']) - ckpt = tf.train.Checkpoint(module) + ckpt = de.train.DECheckpoint(module) ckpt.save(save_ckpt_dir) shadow_value = module.shadow.read_value(False) self.assertAllEqual(shadow_value.shape, (0, 2)) # clear when saving new_module = TestModule() - new_ckpt = tf.train.Checkpoint(new_module) + new_ckpt = de.train.DECheckpoint(new_module) + restore_ckpt_path = tf.train.latest_checkpoint(model_dir) new_ckpt.read(restore_ckpt_path) self.assertEqual(new_module.size(), 3) expected_values = module(keys) @@ -640,10 +640,9 @@ def size(self): model_dir = tempfile.mkdtemp(prefix=self.get_temp_dir()) save_ckpt_dir = os.path.join(model_dir, 'model') - restore_ckpt_path = os.path.join(model_dir, 'model-1') options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA']) - ckpt = tf.train.Checkpoint(module) + ckpt = de.train.DECheckpoint(module) ckpt.save(save_ckpt_dir) shadow_value = module.shadow.read_value(False) self.assertAllEqual(shadow_value.shape, (0, 1)) # clear when saving @@ -651,7 +650,8 @@ def size(self): tf.keras.backend.clear_session() del module, ckpt new_module = TestNewModule(table_devices_) - new_ckpt = tf.train.Checkpoint(new_module) + new_ckpt = de.train.DECheckpoint(new_module) + restore_ckpt_path = tf.train.latest_checkpoint(model_dir) new_ckpt.read(restore_ckpt_path) self.assertEqual(new_module.size(), test_size) expected_values = new_module(keys) @@ -663,7 +663,7 @@ def size(self): shard_num = 5 table_devices_ = table_device * shard_num new_module = TestNewModule(table_devices_) - new_ckpt = tf.train.Checkpoint(new_module) + new_ckpt = de.train.DECheckpoint(new_module) new_ckpt.read(restore_ckpt_path) self.assertEqual(new_module.size(), test_size) expected_values = new_module(keys) @@ -675,7 +675,7 @@ def size(self): shard_num = 2 table_devices_ = table_device * shard_num new_module = TestNewModule(table_devices_) - new_ckpt = tf.train.Checkpoint(new_module) + new_ckpt = de.train.DECheckpoint(new_module) new_ckpt.read(restore_ckpt_path) self.assertEqual(new_module.size(), test_size) expected_values = new_module(keys) @@ -687,7 +687,7 @@ def size(self): shard_num = 1 table_devices_ = table_device * shard_num new_module = TestNewModule(table_devices_) - new_ckpt = tf.train.Checkpoint(new_module) + new_ckpt = de.train.DECheckpoint(new_module) new_ckpt.read(restore_ckpt_path) self.assertEqual(new_module.size(), test_size) expected_values = new_module(keys) diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/train/__init__.py b/tensorflow_recommenders_addons/dynamic_embedding/python/train/__init__.py index 2e679d230..ad7f4620b 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/train/__init__.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/train/__init__.py @@ -1,2 +1,2 @@ from tensorflow_recommenders_addons.dynamic_embedding.python.train.saver import DEHvdSaver -from tensorflow_recommenders_addons.dynamic_embedding.python.train.checkpoint import DEHvdCheckpoint +from tensorflow_recommenders_addons.dynamic_embedding.python.train.checkpoint import DECheckpoint diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/train/checkpoint.py b/tensorflow_recommenders_addons/dynamic_embedding/python/train/checkpoint.py index df49302a3..7092ff24c 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/train/checkpoint.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/train/checkpoint.py @@ -20,17 +20,22 @@ from tensorflow_recommenders_addons import dynamic_embedding as de from tensorflow_recommenders_addons.dynamic_embedding.python.keras.layers import HvdAllToAllEmbedding from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_ops import TrainableWrapper, DEResourceVariable +from tensorflow_recommenders_addons.dynamic_embedding.python.ops.tf_save_restore_patch import de_fs_saveable_class_names, de_fs_sub_saveable_class_names +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op -try: +from tensorflow.python.framework import ops +try: # tf version >= 2.10.0 from tensorflow.python.checkpoint.checkpoint import Checkpoint + from tensorflow.python.checkpoint import restore as ckpt_base except: from tensorflow.python.training.tracking.util import Checkpoint + from tensorflow.python.training.tracking import base as ckpt_base from tensorflow.python.lib.io import file_io from tensorflow.python.platform import tf_logging -class DEHvdCheckpoint(Checkpoint): +class DECheckpoint(Checkpoint): """Overwrite tf.train.Saver class Calling the TF save API for all ranks causes file conflicts, so KV files other than rank0 need to be saved by calling the underlying API separately. @@ -65,103 +70,41 @@ def __init__(self, root=None, **kwargs): except: self._hvd = None + self._de_need_opt = False self._tmp_var_key_set = set({}) - for k, _ in sorted(kwargs.items(), key=lambda item: item[0]): + for k, v in sorted(kwargs.items(), key=lambda item: item[0]): + v_strname = str(v).lower() + if "optimizer" in v_strname: + self._de_need_opt = True self._tmp_var_key_set.add(k) - super(DEHvdCheckpoint, self).__init__(root, **kwargs) - - def _get_de_variable_folder_dir(self, - save_path: str, - global_step: str = None): - save_path_parent = os.path.dirname(save_path) - if global_step is not None: - de_variable_folder_dir = os.path.join( - save_path_parent, "TFRADynamicEmbedding-{}".format(global_step)) - else: - de_variable_folder_dir = os.path.join(save_path_parent, - "TFRADynamicEmbedding") - return de_variable_folder_dir - - def _delete_redundant_de_dir(self, ckpt_index_list: list): - if not len(ckpt_index_list) > 0: - return - save_path_parent = os.path.dirname(ckpt_index_list[0]) - de_dir_pattern = os.path.join(save_path_parent, "TFRADynamicEmbedding-*") - found_de_dir_set = set(file_io.get_matching_files(de_dir_pattern)) - keep_de_dir_set = set([]) - for file_path in ckpt_index_list: - global_step = file_path.split('.index')[-2].split('-')[-1] - de_dir = os.path.join(save_path_parent, - "TFRADynamicEmbedding-{}".format(global_step)) - keep_de_dir_set.add(de_dir) - delete_de_dir_set = found_de_dir_set - keep_de_dir_set - for de_dir in delete_de_dir_set: - if file_io.is_directory(de_dir): - file_io.delete_recursively(de_dir) + patch_tf_checkpoint() + super(DECheckpoint, self).__init__(root, **kwargs) def _de_var_fs_save_funtion(self, de_var, de_dir: str): - a2a_emb = de_var._created_in_class - hvd_size = 1 if self._hvd is None else self._hvd.size() - hvd_rank = 0 if self._hvd is None else self._hvd.rank() - if issubclass(a2a_emb.__class__, HvdAllToAllEmbedding): - if de_var._saveable_object_creator is None: - tf_logging.warning( - "Please use FileSystemSaver when use HvdAllToAllEmbedding. " - "It will allow TFRA load KV files when Embedding tensor parallel. " - f"The embedding shards at each horovod rank are now temporarily stored in {de_dir}" - ) - else: - # save Dynamic Embedding Parameters - de_var.save_to_file_system(dirpath=de_dir, - proc_size=hvd_size, - proc_rank=hvd_rank) - # save optimizer parameters of Dynamic Embedding - de_opt_vars = a2a_emb.optimizer_vars.as_list() if hasattr( - a2a_emb.optimizer_vars, "as_list") else a2a_emb.optimizer_vars - for de_opt_var in de_opt_vars: - de_opt_var.save_to_file_system(dirpath=de_dir, - proc_size=hvd_size, - proc_rank=hvd_rank) - - def _de_var_fs_restore_funtion(self, de_var, de_dir: str): - a2a_emb = de_var._created_in_class hvd_size = 1 if self._hvd is None else self._hvd.size() hvd_rank = 0 if self._hvd is None else self._hvd.rank() - if issubclass(a2a_emb.__class__, HvdAllToAllEmbedding): - if de_var._saveable_object_creator is None: - tf_logging.warning( - "Please use FileSystemSaver when use HvdAllToAllEmbedding. " - "It will allow TFRA load KV files when Embedding tensor parallel. " - f"The embedding shards at each horovod rank are now temporarily stored in {de_dir}" - ) - else: - # restore Dynamic Embedding Parameters - de_var.load_from_file_system_with_restore_function(dirpath=de_dir, - proc_size=hvd_size, - proc_rank=hvd_rank) - # restore optimizer parameters of Dynamic Embedding - de_opt_vars = a2a_emb.optimizer_vars.as_list() if hasattr( - a2a_emb.optimizer_vars, "as_list") else a2a_emb.optimizer_vars - for de_opt_var in de_opt_vars: - de_opt_var.load_from_file_system_with_restore_function( - dirpath=de_dir, proc_size=hvd_size, proc_rank=hvd_rank) + de_var.save_to_file_system(dirpath=de_dir, + proc_size=hvd_size, + proc_rank=hvd_rank) def _de_handle_root_and_var_with_func(self, de_dir: str, func): - def _filter_de_hvd_a2a_tw(var): + def _filter_de_tw(var): if not hasattr(var, "params") or not isinstance(var, TrainableWrapper): return False - if not hasattr(var.params, "_created_in_class"): + if not hasattr(var.params, "saveable"): + return False + if type(var.params.saveable).__name__ not in de_fs_saveable_class_names: return False return True def _handle_model_or_variable(obj): - if _filter_de_hvd_a2a_tw(obj): + if _filter_de_tw(obj): func(var.params, de_dir) if hasattr(obj, 'variables'): _iter = obj.variables() if callable(obj.variables) else obj.variables for var in _iter: - if _filter_de_hvd_a2a_tw(var): + if _filter_de_tw(var): func(var.params, de_dir) if hasattr(self, 'root'): @@ -171,33 +114,40 @@ def _handle_model_or_variable(obj): obj_var = getattr(self, obj_key) _handle_model_or_variable(obj_var) - def _de_hvd_write_fs_func(self, file_prefix, tf_write_func): - - def _get_de_dir_from_file_path(file_path): - file_prefix_split = file_path.split('-') - file_prefix_pattern = ''.join(file_prefix_split[0:-1]) - global_step = file_prefix_split[-1] - if not global_step.isdigit(): - global_step = None - de_dir = self._get_de_variable_folder_dir(file_path, global_step) - return file_prefix_pattern, global_step, de_dir - - def _rank0_delete_files_and_return_de_dir(file_path): - file_prefix_pattern, global_step, de_dir = _get_de_dir_from_file_path( - file_path) - if global_step is not None: - ckpt_index_list = file_io.get_matching_files(file_prefix_pattern + - '-*.index') - self._delete_redundant_de_dir( - ckpt_index_list - ) # Compatible with automatic sweep function of checkpointmanager - return de_dir + def _redirect_new_de_dir(self, de_dir): + use_session = (not context.executing_eagerly() + and not ops.inside_function()) + if use_session: + if self._object_graph_feed_tensor is None: + with ops.device("/cpu:0"): + self._object_graph_feed_tensor = constant_op.constant( + "", dtype=dtypes.string) + object_graph_tensor = self._object_graph_feed_tensor + else: + object_graph_tensor = None + try: + if hasattr(self._saver, "_gather_saveables"): + #TODO: _gather_saveables return nothing when restore + named_saveable_objects, _, _, _ = self._saver._gather_saveables( + object_graph_tensor=object_graph_tensor) + elif hasattr(self._saver._graph_view, "serialize_object_graph"): + named_saveable_objects, _, _ = self._saver._graph_view.serialize_object_graph( + ) + except: + raise ( + "Can't find _gather_saveables or _graph_view.serialize_object_graph function at self._saver! " + "Unsupport TrackableSaver version!") + for saveable in named_saveable_objects: + if type(saveable).__name__ in de_fs_sub_saveable_class_names: + if hasattr(saveable, '_saver_config'): + saveable._saver_config.save_path = de_dir + def _de_hvd_write_fs_func(self, file_prefix, tf_write_func): + _, _, de_dir = _get_de_dir_from_file_path(file_prefix) + self._redirect_new_de_dir(de_dir) if self._hvd is None: file_path = tf_write_func() de_dir = _rank0_delete_files_and_return_de_dir(file_path) - self._de_handle_root_and_var_with_func(de_dir=de_dir, - func=self._de_var_fs_save_funtion) else: file_path = '' if self._hvd.rank() == 0: @@ -205,21 +155,23 @@ def _rank0_delete_files_and_return_de_dir(file_path): self._hvd.broadcast_object(file_path, root_rank=0, name='de_hvd_broadcast_file_path') - de_dir = _rank0_delete_files_and_return_de_dir(file_path) - self._hvd.join() # Sync for avoiding files conflict - self._de_handle_root_and_var_with_func( - de_dir=de_dir, func=self._de_var_fs_save_funtion) + _, _, tf_return_de_dir = _get_de_dir_from_file_path(file_path) + if tf_return_de_dir != de_dir: + self._de_handle_root_and_var_with_func( + de_dir=tf_return_de_dir, func=self._de_var_fs_save_funtion) self._hvd.join( ) # Sync for avoiding files conflict and rank finish early + de_dir = _rank0_delete_files_and_return_de_dir(file_path) + self._hvd.join() # Sync for avoiding files conflict else: file_path = self._hvd.broadcast_object( None, root_rank=0, name='de_hvd_broadcast_file_path') _, _, de_dir = _get_de_dir_from_file_path(file_path) - self._hvd.join() # Sync for avoiding files conflict self._de_handle_root_and_var_with_func( de_dir=de_dir, func=self._de_var_fs_save_funtion) self._hvd.join( ) # Sync for avoiding files conflict and rank finish early + self._hvd.join() # Sync for avoiding files conflict return file_path def _write(self, file_prefix, options=None, *args, **kwargs): @@ -238,10 +190,10 @@ def _write(self, file_prefix, options=None, *args, **kwargs): """ def tf_write_func_impl(): - return super(DEHvdCheckpoint, self)._write(file_prefix=file_prefix, - options=options, - *args, - **kwargs) + return super(DECheckpoint, self)._write(file_prefix=file_prefix, + options=options, + *args, + **kwargs) return self._de_hvd_write_fs_func(file_prefix=file_prefix, tf_write_func=tf_write_func_impl) @@ -258,16 +210,16 @@ def write(self, file_prefix, options=None, *args, **kwargs): """ def tf_write_func_impl(): - if hasattr(super(DEHvdCheckpoint, self), '_write'): - return super(DEHvdCheckpoint, self)._write(file_prefix=file_prefix, - options=options, - *args, - **kwargs) + if hasattr(super(DECheckpoint, self), '_write'): + return super(DECheckpoint, self)._write(file_prefix=file_prefix, + options=options, + *args, + **kwargs) else: - return super(DEHvdCheckpoint, self).write(file_prefix=file_prefix, - options=options, - *args, - **kwargs) + return super(DECheckpoint, self).write(file_prefix=file_prefix, + options=options, + *args, + **kwargs) return self._de_hvd_write_fs_func(file_prefix=file_prefix, tf_write_func=tf_write_func_impl) @@ -301,7 +253,7 @@ def restore(self, save_path, options=None, *args, **kwargs): global_step = save_path_split[-1] if not global_step.isdigit(): global_step = None - de_dir = self._get_de_variable_folder_dir(save_path, global_step) + de_dir = _get_de_variable_folder_dir(save_path, global_step) impl_save_path = save_path if 'TFRADynamicEmbedding' in save_path: @@ -316,7 +268,7 @@ def restore(self, save_path, options=None, *args, **kwargs): else: corresponding_ckpt_index = file_io.get_matching_files( os.path.join(os.path.dirname(save_path), '*.index')) - de_dir = self._get_de_variable_folder_dir( + de_dir = _get_de_variable_folder_dir( save_path, (corresponding_ckpt_index[0].split('-')[-1].split('.index')[0])) if len(corresponding_ckpt_index) > 0: @@ -326,16 +278,121 @@ def restore(self, save_path, options=None, *args, **kwargs): f'Arg save_path {save_path} is illegal or not existing. Now using index {impl_save_path}' ) - result = super(DEHvdCheckpoint, self).restore(save_path=impl_save_path, - options=options, - *args, - **kwargs) - if os.path.exists(de_dir): - self._de_handle_root_and_var_with_func( - de_dir=de_dir, func=self._de_var_fs_restore_funtion) - else: + self._redirect_new_de_dir(de_dir) + result = super(DECheckpoint, self).restore(save_path=impl_save_path, + options=options, + *args, + **kwargs) + + if not os.path.exists(de_dir): + tf_logging.warning( + f'TFRADynamicEmbedding directory {de_dir} is not existing.') + if self._hvd is not None: + self._hvd.join() # Sync for avoiding files conflict + return result + + def read(self, save_path, options=None, *args, **kwargs): + save_path_split = save_path.split('-') + save_path_pattern = ''.join(save_path_split[0:-1]) + global_step = save_path_split[-1] + if not global_step.isdigit(): + global_step = None + de_dir = _get_de_variable_folder_dir(save_path, global_step) + + self._redirect_new_de_dir(de_dir) + result = super(DECheckpoint, self).read(save_path=save_path, + options=options, + *args, + **kwargs) + + if not os.path.exists(de_dir): tf_logging.warning( f'TFRADynamicEmbedding directory {de_dir} is not existing.') if self._hvd is not None: self._hvd.join() # Sync for avoiding files conflict return result + + +def _get_de_variable_folder_dir(save_path: str, global_step: str = None): + save_path_parent = os.path.dirname(save_path) + if global_step is not None: + de_variable_folder_dir = os.path.join( + save_path_parent, "TFRADynamicEmbedding-{}".format(global_step)) + else: + de_variable_folder_dir = os.path.join(save_path_parent, + "TFRADynamicEmbedding") + return de_variable_folder_dir + + +def _delete_redundant_de_dir(ckpt_index_list: list): + if not len(ckpt_index_list) > 0: + return + save_path_parent = os.path.dirname(ckpt_index_list[0]) + de_dir_pattern = os.path.join(save_path_parent, "TFRADynamicEmbedding-*") + found_de_dir_set = set(file_io.get_matching_files(de_dir_pattern)) + keep_de_dir_set = set([]) + for file_path in ckpt_index_list: + global_step = file_path.split('.index')[-2].split('-')[-1] + de_dir = os.path.join(save_path_parent, + "TFRADynamicEmbedding-{}".format(global_step)) + keep_de_dir_set.add(de_dir) + delete_de_dir_set = found_de_dir_set - keep_de_dir_set + for de_dir in delete_de_dir_set: + if file_io.is_directory(de_dir): + file_io.delete_recursively(de_dir) + + +def _get_de_dir_from_file_path(file_path): + file_prefix_split = file_path.split('-') + file_prefix_pattern = ''.join(file_prefix_split[0:-1]) + global_step = file_prefix_split[-1] + if not global_step.isdigit(): + global_step = None + de_dir = _get_de_variable_folder_dir(file_path, global_step) + return file_prefix_pattern, global_step, de_dir + + +def _rank0_delete_files_and_return_de_dir(file_path): + file_prefix_pattern, global_step, de_dir = _get_de_dir_from_file_path( + file_path) + if global_step is not None: + ckpt_index_list = file_io.get_matching_files(file_prefix_pattern + + '-*.index') + _delete_redundant_de_dir( + ckpt_index_list + ) # Compatible with automatic sweep function of checkpointmanager + return de_dir + + +def patch_tf_checkpoint(): + ckpt_base.CheckpointPosition = DECheckpointPosition + + +class DECheckpointPosition(ckpt_base.CheckpointPosition): + + def _redirect_new_de_dir(self, named_saveables, de_dir): + for saveable in named_saveables.values(): + if type(saveable).__name__ in de_fs_sub_saveable_class_names: + if hasattr(saveable, '_saver_config'): + saveable._saver_config.save_path = de_dir + + def _single_restoration_from_checkpoint_position(self, checkpoint_position, + visit_queue): + restore_ops, tensor_saveables, python_saveables = \ + super(DECheckpointPosition, self)._single_restoration_from_checkpoint_position( + checkpoint_position, visit_queue + ) + _, _, de_dir = _get_de_dir_from_file_path(self._checkpoint.save_path_string) + self._redirect_new_de_dir(named_saveables, de_dir) + return restore_ops, tensor_saveables, python_saveables + + def gather_ops_or_named_saveables(self): + result_tuple = super(DECheckpointPosition, + self).gather_ops_or_named_saveables() + named_saveables = result_tuple[1] + registered_savers = None + if len(result_tuple) == 4: + registered_savers = result_tuple[3] + _, _, de_dir = _get_de_dir_from_file_path(self._checkpoint.save_path_string) + self._redirect_new_de_dir(named_saveables, de_dir) + return result_tuple