From 9b9f394501760788cf4668f5eee4c95130cbfdbf Mon Sep 17 00:00:00 2001 From: Jia He Date: Thu, 21 Sep 2023 18:24:41 +0800 Subject: [PATCH] [feat] Make saving model more easier when using HvdAllToAllEmbedding by adding save function overwriting patch in tf_save_restore_patch.py. Also fix some import bug in tf_save_restore_patch.py. Also adding a save and restore test for HvdAllToAllEmbeeding. --- .../movielens-1m-keras-with-horovod.py | 548 +++++++++++++++--- .../python/keras/callbacks.py | 12 +- .../kernel_tests/horovod_sync_train_test.py | 33 ++ .../python/ops/tf_save_restore_patch.py | 119 +++- 4 files changed, 629 insertions(+), 83 deletions(-) diff --git a/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py b/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py index 2385ad51d..167e5d8a5 100644 --- a/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py +++ b/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py @@ -1,4 +1,5 @@ import os +import shutil import tensorflow as tf import tensorflow_datasets as tfds @@ -40,35 +41,234 @@ FLAGS = flags.FLAGS input_spec = { - 'user_id': tf.TensorSpec(shape=[ - None, - ], dtype=tf.int64, name='user_id'), - 'movie_id': tf.TensorSpec(shape=[ - None, - ], dtype=tf.int64, name='movie_id') + 'user_id': + tf.TensorSpec(shape=[ + None, + 1, + ], dtype=tf.int64, name='user_id'), + 'user_gender': + tf.TensorSpec(shape=[ + None, + 1, + ], dtype=tf.int32, name='user_gender'), + 'user_occupation_label': + tf.TensorSpec(shape=[ + None, + 1, + ], + dtype=tf.int32, + name='user_occupation_label'), + 'raw_user_age': + tf.TensorSpec(shape=[ + None, + 1, + ], dtype=tf.int32, name='raw_user_age'), + 'movie_id': + tf.TensorSpec(shape=[ + None, + 1, + ], dtype=tf.int64, name='movie_id'), + 'movie_genres': + tf.TensorSpec(shape=[ + None, + 1, + ], dtype=tf.int32, name='movie_genres'), + 'timestamp': + tf.TensorSpec(shape=[ + None, + 1, + ], dtype=tf.int32, name='timestamp') } +feature_info_spec = { + 'movie_id': { + 'code': 101, + 'dtype': tf.int64, + 'dim': 1, + 'ptype': 'sparse_cpu', + 'input_tensor': None, + 'pretreated_tensor': None + }, + 'movie_genres': { + 'code': 102, + 'dtype': tf.int32, + 'dim': 1, + 'ptype': 'normal_gpu', + 'input_tensor': None, + 'pretreated_tensor': None, + }, + 'user_id': { + 'code': 103, + 'dtype': tf.int64, + 'dim': 1, + 'ptype': 'sparse_cpu', + 'input_tensor': None, + 'pretreated_tensor': None, + }, + 'user_gender': { + 'code': 104, + 'dtype': tf.int32, + 'dim': 1, + 'ptype': 'normal_gpu', + 'input_tensor': None, + 'pretreated_tensor': None, + }, + 'user_occupation_label': { + 'code': 105, + 'dtype': tf.int32, + 'dim': 1, + 'ptype': 'normal_gpu', + 'input_tensor': None, + 'pretreated_tensor': None, + }, + 'raw_user_age': { + 'code': 106, + 'dtype': tf.int32, + 'dim': 1, + 'ptype': 'normal_gpu', + 'input_tensor': None, + 'pretreated_tensor': None, + 'boundaries': [i for i in range(0, 100, 10)], + }, + 'timestamp': { + 'code': 107, + 'dtype': tf.int32, + 'dim': 1, + 'ptype': 'normal_gpu', + 'input_tensor': None, + 'pretreated_tensor': None, + } +} -# Construct input function -def input_fn(): - ds = tfds.load("movielens/1m-ratings", - split="train", - data_dir="/dataset", - download=False) - ids = ds.map( - lambda x: { - "movie_id": tf.strings.to_number(x["movie_id"], tf.int64), - "movie_genres": tf.cast(x["movie_genres"][0], tf.int32), - "user_id": tf.strings.to_number(x["user_id"], 
tf.int64), - "user_gender": tf.cast(x["user_gender"], tf.int32), - }) - ratings = ds.map(lambda x: {"user_rating": x["user_rating"]}) - dataset = tf.data.Dataset.zip((ids, ratings)) - shuffled = dataset.shuffle(1_000_000, - seed=2021, - reshuffle_each_iteration=False) - dataset = shuffled.repeat(1).batch(4096) - return dataset + +# Auxiliary function of GPU hash table combined query, recording which input is a vector feature embedding to be marked as a special treatment (usually an average) after embedding layer. +def embedding_inputs_concat(input_tensors, input_dims): + tmp_sum = 0 + input_split_dims = [] + input_is_sequence_feature = [] + for tensors, dim in zip(input_tensors, input_dims): + if tensors.get_shape().ndims != 2: + raise ("Please make sure dimension size of all input tensors is 2!") + if dim == 1: + tmp_sum = tmp_sum + 1 + elif dim > 1: + if tmp_sum > 0: + input_split_dims.append(tmp_sum) + input_is_sequence_feature.append(False) + input_split_dims.append(dim) + input_is_sequence_feature.append(True) + tmp_sum = 0 + else: + raise ("dim must >= 1, which is {}".format(dim)) + if tmp_sum > 0: + input_split_dims.append(tmp_sum) + input_is_sequence_feature.append(False) + input_tensors_concat = tf.keras.layers.Concatenate(axis=1)(input_tensors) + return (input_tensors_concat, input_split_dims, input_is_sequence_feature) + + +# After get the results of table combined query, we need to extract the vector features separately by split operator for a special treatment (usually an average). +def embedding_out_split(embedding_out_concat, input_split_dims): + embedding_out = list() + embedding_out.extend( + tf.split(embedding_out_concat, input_split_dims, + axis=1)) # (feature_combin_num, (batch, dim, emb_size)) + assert (len(input_split_dims) == len(embedding_out)) + return embedding_out + + +class ChannelEmbeddingLayers(): + + def __init__(self, + name='', + dense_embedding_size=1, + sparse_embedding_size=1, + embedding_initializer=tf.keras.initializers.Zeros(), + mpi_size=1, + mpi_rank=0): + + self.gpu_device = ["GPU:0"] + self.cpu_device = ["CPU:0"] + + # The saver parameter of kv_creator saves the K-V in the hash table into a separate KV file. + self.kv_creator = de.CuckooHashTableCreator( + saver=de.FileSystemSaver(proc_size=mpi_size, proc_rank=mpi_rank)) + + self.dense_embedding_layer = de.keras.layers.HvdAllToAllEmbedding( + mpi_size=mpi_size, + embedding_size=dense_embedding_size, + key_dtype=tf.int32, + value_dtype=tf.float32, + initializer=embedding_initializer, + devices=self.gpu_device, + name=name + '_DenseUnifiedEmbeddingLayer', + bp_v2=True, + init_capacity=4096000, + kv_creator=self.kv_creator) + + self.sparse_embedding_layer = de.keras.layers.HvdAllToAllEmbedding( + mpi_size=mpi_size, + embedding_size=sparse_embedding_size, + key_dtype=tf.int64, + value_dtype=tf.float32, + initializer=embedding_initializer, + devices=self.cpu_device, + name=name + '_SparseUnifiedEmbeddingLayer', + init_capacity=4096000, + kv_creator=self.kv_creator) + + self.dnn = tf.keras.layers.Dense( + 128, + activation='relu', + kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1), + bias_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1)) + + def __call__(self, features_info): + dense_inputs = [] + dense_input_dims = [] + sparse_inputs = [] + sparse_input_dims = [] + for fea_name, fea_info in features_info.items(): + # The features of GPU table and CPU table to be combined and queried are processed separately. 
+ if fea_info['ptype'] == 'normal_gpu': + dense_inputs.append(fea_info['pretreated_tensor']) + dense_input_dims.append(fea_info['dim']) + elif fea_info['ptype'] == 'sparse_cpu': + sparse_inputs.append(fea_info['pretreated_tensor']) + sparse_input_dims.append(fea_info['dim']) + else: + ptype = fea_info['ptype'] + raise NotImplementedError(f'Not support ptype {ptype}.') + # The GPU table combined query starts + dense_input_tensors_concat, dense_input_split_dims, dense_input_is_sequence_feature = \ + embedding_inputs_concat(dense_inputs, dense_input_dims) + dense_emb_concat = self.dense_embedding_layer(dense_input_tensors_concat) + # The CPU table combined query starts + sparse_input_tensors_concat, sparse_input_split_dims, sparse_input_is_sequence_feature = \ + embedding_inputs_concat(sparse_inputs, sparse_input_dims) + sparse_emb_concat = self.sparse_embedding_layer(sparse_input_tensors_concat) + # Slice the combined query result + dense_emb_outs = embedding_out_split(dense_emb_concat, + dense_input_split_dims) + sparse_emb_outs = embedding_out_split(sparse_emb_concat, + sparse_input_split_dims) + # Process the results of the combined query after slicing. + embedding_outs = [] + input_is_sequence_feature = dense_input_is_sequence_feature + sparse_input_is_sequence_feature + for i, embedding in enumerate(dense_emb_outs + sparse_emb_outs): + if input_is_sequence_feature[i] == True: + # Deal with the embedding from vector features. + embedding_vec = tf.math.reduce_mean( + embedding, axis=1, + keepdims=True) # (feature_combin_num, (batch, x, emb_size)) + else: + embedding_vec = embedding + embedding_vec = tf.keras.layers.Flatten()(embedding_vec) + embedding_outs.append(embedding_vec) + # Final embedding result. + embeddings_concat = tf.keras.layers.Concatenate(axis=1)(embedding_outs) + + return self.dnn(embeddings_concat) class DualChannelsDeepModel(tf.keras.Model): @@ -77,26 +277,37 @@ def __init__(self, user_embedding_size=1, movie_embedding_size=1, embedding_initializer=None, - is_training=True): - - if not is_training: + is_training=True, + mpi_size=1, + mpi_rank=0): + + if is_training: + de.enable_train_mode() + if embedding_initializer is None: + embedding_initializer = tf.keras.initializers.VarianceScaling() + else: de.enable_inference_mode() + if embedding_initializer is None: + embedding_initializer = tf.keras.initializers.Zeros() super(DualChannelsDeepModel, self).__init__() self.user_embedding_size = user_embedding_size self.movie_embedding_size = movie_embedding_size - if embedding_initializer is None: - embedding_initializer = tf.keras.initializers.Zeros() - - self.user_embedding = de.keras.layers.SquashedEmbedding( - user_embedding_size, - initializer=embedding_initializer, - name='user_embedding') - self.movie_embedding = de.keras.layers.SquashedEmbedding( - movie_embedding_size, - initializer=embedding_initializer, - name='movie_embedding') + self.user_embedding = ChannelEmbeddingLayers( + name='user', + dense_embedding_size=user_embedding_size, + user_embedding_size=user_embedding_size * 2, + embedding_initializer=embedding_initializer, + mpi_size=mpi_size, + mpi_rank=mpi_rank) + self.movie_embedding = ChannelEmbeddingLayers( + name='movie', + dense_embedding_size=movie_embedding_size, + user_embedding_size=movie_embedding_size * 2, + embedding_initializer=embedding_initializer, + mpi_size=mpi_size, + mpi_rank=mpi_rank) self.dnn1 = tf.keras.layers.Dense( 64, @@ -121,10 +332,44 @@ def __init__(self, @tf.function def call(self, features): - user_id = 
tf.reshape(features['user_id'], (-1, 1)) - movie_id = tf.reshape(features['movie_id'], (-1, 1)) - user_latent = self.user_embedding(user_id) - movie_latent = self.movie_embedding(movie_id) + # Construct input layers + for fea_name in features.keys(): + fea_info = feature_info_spec[fea_name] + input_tensor = tf.keras.layers.Input(shape=(fea_info['dim'],), + dtype=fea_info['dtype'], + name=fea_name) + fea_info['input_tensor'] = input_tensor + if fea_info.__contains__('boundaries'): + input_tensor = tf.raw_ops.Bucketize(input=input_tensor, + boundaries=fea_info['boundaries']) + # To prepare for GPU table combined queries, use a prefix to distinguish different features in a table. + if fea_info['ptype'] == 'normal_gpu': + if fea_info['dtype'] == tf.int64: + input_tensor_prefix_code = int(fea_info['code']) << 17 + elif fea_info['dtype'] == tf.int32: + input_tensor_prefix_code = int(fea_info['code']) << 14 + else: + input_tensor_prefix_code = None + if input_tensor_prefix_code is not None: + # input_tensor = tf.bitwise.bitwise_xor(input_tensor, input_tensor_prefix_code) + # xor operation can be replaced with addition operation to facilitate subsequent optimization of TRT and OpenVino. + input_tensor = tf.add(input_tensor, input_tensor_prefix_code) + fea_info['pretreated_tensor'] = input_tensor + + user_fea = ['user_id', 'user_gender', 'user_occupation_label'] + user_fea_info = { + key: value + for key, value in feature_info_spec.items() + if key in user_fea + } + user_latent = self.user_embedding(user_fea_info) + movie_fea = ['movie_id', 'movie_genres', 'user_occupation_label'] + movie_fea_info = { + key: value + for key, value in feature_info_spec.items() + if key in movie_fea + } + movie_latent = self.movie_embedding(movie_fea_info) latent = tf.concat([user_latent, movie_latent], axis=1) x = self.dnn1(latent) @@ -137,26 +382,161 @@ def call(self, features): def get_dataset(batch_size=1): - dataset = tfds.load('movielens/1m-ratings', split='train') - features = dataset.map( + ds = tfds.load("movielens/1m-ratings", + split="train", + data_dir="/dataset", + download=False) + features = ds.map( lambda x: { - "movie_id": tf.strings.to_number(x["movie_id"], tf.int64), - "user_id": tf.strings.to_number(x["user_id"], tf.int64), + "movie_id": + tf.strings.to_number(x["movie_id"], tf.int64), + "movie_genres": + tf.cast(x["movie_genres"][0], tf.int32), + "user_id": + tf.strings.to_number(x["user_id"], tf.int64), + "user_gender": + tf.cast(x["user_gender"], tf.int32), + "user_occupation_label": + tf.cast(x["user_occupation_label"], tf.int32), + "raw_user_age": + tf.cast(x["raw_user_age"], tf.int32), + "timestamp": + tf.cast(x["timestamp"] - 880000000, tf.int32), }) - ratings = dataset.map( - lambda x: tf.one_hot(tf.cast(x['user_rating'] - 1, dtype=tf.int64), 5)) - dataset = dataset.zip((features, ratings)) - dataset = dataset.shuffle(4096, reshuffle_each_iteration=False) - if batch_size > 1: - dataset = dataset.batch(batch_size) + ratings = ds.map(lambda x: {"user_rating": x["user_rating"]}) + dataset = tf.data.Dataset.zip((features, ratings)) + shuffled = dataset.shuffle(1_000_000, + seed=2021, + reshuffle_each_iteration=False) + dataset = shuffled.repeat(1).batch(batch_size).prefetch(tf.data.AUTOTUNE) + # Only GPU:0 since TF is set to be visible to GPU:X + dataset = dataset.apply( + tf.data.experimental.prefetch_to_device('GPU:0', buffer_size=2)) return dataset +def export_to_savedmodel(model, savedmodel_dir): + save_options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA']) + + if not 
os.path.exists(savedmodel_dir): + os.mkdir(savedmodel_dir) + + ########################## What really happened ########################## + # # Calling the TF save API for all ranks causes file conflicts, so KV files other than rank0 need to be saved by calling the underlying API separately. + # if hvd.rank() == 0: + # tf.keras.models.save_model(model, + # savedmodel_dir, + # overwrite=True, + # include_optimizer=True, + # save_traces=True, + # options=save_options) + # else: + # de_dir = os.path.join(savedmodel_dir, "variables", "TFRADynamicEmbedding") + # for layer in model.layers: + # if hasattr(layer, "params"): + # # Save embedding parameters + # layer.params.save_to_file_system(dirpath=de_dir, + # proc_size=hvd.size(), + # proc_rank=hvd.rank()) + # # Save the optimizer parameters + # opt_de_vars = layer.optimizer_vars.as_list() if hasattr( + # layer.optimizer_vars, "as_list") else layer.optimizer_vars + # for opt_de_var in opt_de_vars: + # opt_de_var.save_to_file_system(dirpath=de_dir, + # proc_size=hvd.size(), + # proc_rank=hvd.rank()) + + # TFRA modify the Keras save function with a monkey patch. + # !!!! Run save_model function in all rank !!!! + tf.keras.models.save_model(model, + savedmodel_dir, + overwrite=True, + include_optimizer=True, + save_traces=True, + options=save_options) + + +def export_for_serving(model, export_dir): + save_options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA']) + + if not os.path.exists(export_dir): + os.mkdir(export_dir) + + def save_spec(): + if hasattr(model, 'save_spec'): + # tf version >= 2.6 + return model.save_spec() + else: + arg_specs = list() + kwarg_specs = dict() + for i in model.inputs: + arg_specs.append(i.type_spec) + return [arg_specs], kwarg_specs + + @tf.function + def serve(*args, **kwargs): + return model(*args, **kwargs) + + arg_specs, kwarg_specs = save_spec() + + ########################## What really happened ########################## + # if hvd.rank() == 0: + # # Remember to remove optimizer parameters when ready to serve. + # tf.keras.models.save_model( + # model, + # export_dir, + # overwrite=True, + # include_optimizer=False, + # options=save_options, + # signatures={ + # 'serving_default': + # serve.get_concrete_function(*arg_specs, **kwarg_specs) + # }, + # ) + # else: + # de_dir = os.path.join(export_dir, "variables", "TFRADynamicEmbedding") + # for layer in model.layers: + # if hasattr(layer, "params"): + # layer.params.save_to_file_system(dirpath=de_dir, + # proc_size=hvd.size(), + # proc_rank=hvd.rank()) + + # TFRA modify the Keras save function with a monkey patch. + # !!!! Run save_model function in all rank !!!! + tf.keras.models.save_model( + model, + export_dir, + overwrite=True, + include_optimizer=False, + options=save_options, + signatures={ + 'serving_default': + serve.get_concrete_function(*arg_specs, **kwarg_specs) + }, + ) + + if hvd.rank() == 0: + # Modify the inference graph to a stand-alone version + from tensorflow.python.saved_model import save as tf_save + tf.keras.backend.clear_session() + de.enable_inference_mode() + export_model = DualChannelsDeepModel(FLAGS.embedding_size, + FLAGS.embedding_size, + tf.keras.initializers.Zeros(), + hvd.size(), hvd.rank()) + # The save_and_return_nodes function is used to overwrite the saved_model.pb file generated by the save_model function and rewrite the inference graph. 
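+    # experimental_skip_checkpoint=True should make this call regenerate only
+    # saved_model.pb, leaving the variables/ directory (including the
+    # TFRADynamicEmbedding KV files written above by save_model on every rank)
+    # untouched.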
+ tf_save.save_and_return_nodes(obj=export_model, + export_dir=export_dir, + options=save_options, + experimental_skip_checkpoint=True) + + def train(): dataset = get_dataset(batch_size=32) model = DualChannelsDeepModel(FLAGS.embedding_size, FLAGS.embedding_size, - tf.keras.initializers.RandomNormal(0.0, 0.5)) + tf.keras.initializers.RandomNormal(0.0, 0.5), + hvd.size(), hvd.rank()) optimizer = tf.keras.optimizers.Adam(1E-3) optimizer = de.DynamicEmbeddingOptimizer(optimizer) @@ -170,34 +550,46 @@ def train(): if os.path.exists(FLAGS.model_dir): model.load_weights(FLAGS.model_dir) - model.fit(dataset, epochs=FLAGS.epochs, steps_per_epoch=FLAGS.steps_per_epoch) - + tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=FLAGS.model_dir) save_options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA']) - model.save(FLAGS.model_dir, options=save_options) + # horovod callback is used to broadcast the value generated by initializer of rank0. + hvd_opt_init_callback = de.keras.callbacks.DEHvdBroadcastGlobalVariablesCallback( + root_rank=0) + callbacks_list = [hvd_opt_init_callback, ckpt_callback] + # The log class callback only takes effect in rank0 for convenience + if hvd.rank() == 0: + callbacks_list.extend([tensorboard_callback]) + # If there are callbacks such as evaluation metrics that call model calculations, take effect on all ranks. + # callbacks_list.extend([my_auc_callback]) + + model.fit(dataset, + callbacks=callbacks_list, + epochs=FLAGS.epochs, + steps_per_epoch=FLAGS.steps_per_epoch, + verbose=1 if hvd.rank() == 0 else 0) + + export_to_savedmodel(model, FLAGS.model_dir) + export_for_serving(model, FLAGS.export_dir) def export(): - model = DualChannelsDeepModel(FLAGS.embedding_size, FLAGS.embedding_size, - tf.keras.initializers.Zeros(), False) - model.load_weights(FLAGS.model_dir) - - # Build input spec with dummy data. If the model is built with explicit - # input specs, then no need of dummy data. - dummy_data = { - 'user_id': tf.zeros((16,), dtype=tf.int64), - 'movie_id': tf.zeros([ - 16, - ], dtype=tf.int64) - } - model(dummy_data) - + de.enable_inference_mode() + if not os.path.exists(FLAGS.export_dir): + shutil.copytree(FLAGS.model_dir, FLAGS.export_dir) + export_model = DualChannelsDeepModel(FLAGS.embedding_size, + FLAGS.embedding_size, + tf.keras.initializers.RandomNormal( + 0.0, 0.5), + mpi_size=1, + mpi_rank=0) save_options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA']) - tf.keras.models.save_model( - model, - FLAGS.export_dir, - options=save_options, - include_optimizer=False, - signatures=model.call.get_concrete_function(input_spec)) + # Modify the inference graph to a stand-alone version + from tensorflow.python.saved_model import save as tf_save + # The save_and_return_nodes function is used to overwrite the saved_model.pb file generated by the save_model function and rewrite the inference graph. 
+ tf_save.save_and_return_nodes(obj=export_model, + export_dir=FLAGS.export_dir, + options=save_options, + experimental_skip_checkpoint=True) def test(): diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py index 569980e2f..c85ec85f6 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py @@ -23,7 +23,9 @@ from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.deprecation import deprecated +from tensorflow_recommenders_addons.dynamic_embedding.python.keras.layers import HvdAllToAllEmbedding from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_ops import TrainableWrapper, DEResourceVariable from tensorflow_recommenders_addons.utils.check_platform import is_macos, is_arm64 @@ -106,6 +108,10 @@ def __init__(self, root_rank=0, device='', local_variables=None): self.register_local_var(var) +@deprecated( + None, "\n!!!! Using this callback will cause a save twice error. !!!!\n" + "The callbacks.ModelCheckpoint for HvdAllToAllEmbedding has been deprecated, use original ModelCheckpoint instead.\n" + "!!!! Using this callback will cause a save twice error. !!!!\n") class DEHvdModelCheckpoint(callbacks.ModelCheckpoint): def __init__(self, *args, **kwargs): @@ -124,8 +130,12 @@ def _save_de_model(self, filepath): else: de_dir = os.path.join(filepath, "variables", "TFRADynamicEmbedding") for layer in self.model.layers: - if hasattr(layer, "params"): + if hasattr(layer, "params") and isinstance(layer, HvdAllToAllEmbedding): # save Dynamic Embedding Parameters + logging.warning( + "!!!! Using this callback will cause a save twice error. !!!!\n" + "The callbacks.ModelCheckpoint for HvdAllToAllEmbedding has been deprecated, use original ModelCheckpoint instead.\n" + "!!!! Using this callback will cause a save twice error. 
!!!!\n") layer.params.save_to_file_system(dirpath=de_dir, proc_size=hvd.size(), proc_rank=hvd.rank()) diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py index 0b8b2714d..340b24d9a 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py @@ -20,6 +20,8 @@ import itertools import os +import shutil + import tensorflow as tf from tensorflow_recommenders_addons import dynamic_embedding as de @@ -259,6 +261,8 @@ def common_all_to_all_embedding_trainable_v2(self, base_opt, test_opt, name): base_opt = de.DynamicEmbeddingOptimizer(base_opt, synchronous=True) test_opt = hvd.DistributedOptimizer(test_opt) init = tf.keras.initializers.Zeros() + kv_creator = de.CuckooHashTableCreator( + saver=de.FileSystemSaver(proc_size=hvd.size(), proc_rank=hvd.rank())) batch_size = 8 start = 0 for dtype, run_step, dim in itertools.product([dtypes.float32], [10], [10]): @@ -276,6 +280,7 @@ def common_all_to_all_embedding_trainable_v2(self, base_opt, test_opt, name): embedding_size=dim, initializer=init, bp_v2=False, + kv_creator=kv_creator, name='all2all_emb') test_model = get_emb_sequential_model(tf.keras.layers.Embedding, test_opt, @@ -300,6 +305,34 @@ def common_all_to_all_embedding_trainable_v2(self, base_opt, test_opt, name): msg="Cond:{},{},{}".format(dtype, run_step, dim), ) + a2aemb_size = base_model.layers[0].params.size() + save_dir = "/tmp/hvd_save_restore" + str( + hvd.size()) + str(run_step) + str( + dim) # All ranks should share same save directory + save_options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA']) + if hvd.rank() == 0: + if os.path.exists(save_dir): + shutil.rmtree(save_dir) + hvd.broadcast(tensor=tf.constant(1), + root_rank=0) # Sync for avoiding files conflict + base_model.save(save_dir, options=save_options) + del base_model + new_base_model = get_emb_sequential_model( + de.keras.layers.HvdAllToAllEmbedding, + base_opt, + embedding_size=dim, + initializer=init, + bp_v2=False, + kv_creator=kv_creator, + name='all2all_emb') + hvd.broadcast(tensor=tf.constant(1), + root_rank=0) # Sync for avoiding files conflict + new_base_model.load_weights(save_dir + '/variables/variables') + new_a2aemb_size = new_base_model.layers[0].params.size() + self.assertEqual(a2aemb_size, new_a2aemb_size) + hvd.broadcast(tensor=tf.constant(1), + root_rank=0) # Sync for avoiding files conflict + if __name__ == "__main__": test.main() diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py index 12abf30bf..c23a4024c 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py @@ -15,6 +15,7 @@ # lint-as: python3 """patch on tensorflow""" +import functools import os.path import re @@ -22,6 +23,10 @@ from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_variable \ import load_de_variable_from_file_system +try: + from keras.saving.saved_model import save as keras_saved_model_save +except: + keras_saved_model_save = None from tensorflow.core.protobuf import saver_pb2 from tensorflow.python.client import session from 
tensorflow.python.eager import context @@ -29,6 +34,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.keras.saving.saved_model import save as tf_saved_model_save from tensorflow.python.keras.utils import tf_utils from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops @@ -36,8 +42,10 @@ from tensorflow.python.ops import io_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops +from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging from tensorflow.python.training import saver +from tensorflow.python.training import training_util try: # tf version >= 2.10.0 from tensorflow.python.checkpoint import checkpoint_management from tensorflow.python.checkpoint import checkpoint_options @@ -49,6 +57,10 @@ from tensorflow.python.util import compat from tensorflow.python.util import nest +tf_original_save_func = tf_saved_model_save.save +if keras_saved_model_save is not None: + keras_original_save_func = keras_saved_model_save.save + de_fs_saveable_class_names = [ '_DynamicEmbeddingVariabelFileSystemSaveable', ] @@ -502,12 +514,12 @@ def restore(self, sess, save_path): # 1. The checkpoint would not be loaded successfully as is. Try to parse # it as an object-based checkpoint. try: - names_to_keys = object_graph_key_mapping(save_path) + names_to_keys = saver.object_graph_key_mapping(save_path) except errors.NotFoundError: # 2. This is not an object-based checkpoint, which likely means there # is a graph mismatch. Re-raise the original error with # a helpful message (b/110263146) - raise _wrap_restore_error_with_msg( + raise saver._wrap_restore_error_with_msg( err, "a Variable name or other graph key that is missing") # This is an object-based checkpoint. We'll print a warning and then do @@ -517,7 +529,7 @@ def restore(self, sess, save_path): "may be somewhat fragile, and will re-build the Saver. Instead, " "consider loading object-based checkpoints using " "tf.train.Checkpoint().") - self._object_restore_saver = saver_from_object_based_checkpoint( + self._object_restore_saver = saver.saver_from_object_based_checkpoint( checkpoint_path=save_path, var_list=self._var_list, builder=self._builder, @@ -527,10 +539,104 @@ def restore(self, sess, save_path): except errors.InvalidArgumentError as err: # There is a mismatch between the graph and the checkpoint being loaded. # We add a more reasonable error message here to help users (b/110263146) - raise _wrap_restore_error_with_msg( + raise saver._wrap_restore_error_with_msg( err, "a mismatch between the current graph and the graph") +def _de_keras_save_func(original_save_func, + model, + filepath, + overwrite, + include_optimizer, + signatures=None, + options=None, + save_traces=True, + *args, + **kwargs): + """Overwrite TF Keras save function + Calling the TF save API for all ranks causes file conflicts, + so KV files other than rank0 need to be saved by calling the underlying API separately. + This is a convenience function for saving HvdAllToAllEmbedding to KV files in different rank. + + Args: + original_save_func: A handle for original save function. It could be from Keras or Tensorflow. + model: Keras model instance to be saved. + filepath: String path to save the model. + overwrite: whether to overwrite the existing filepath. + include_optimizer: If True, save the model's optimizer state. 
+ signatures: Signatures to save with the SavedModel. Applicable to the 'tf' + format only. Please see the `signatures` argument in `tf.saved_model.save` + for details. + options: (only applies to SavedModel format) `tf.saved_model.SaveOptions` + object that specifies options for saving to SavedModel. + save_traces: (only applies to SavedModel format) When enabled, the + SavedModel will store the function traces for each layer. This + can be disabled, so that only the configs of each layer are stored. + Defaults to `True`. Disabling this will decrease serialization time + and reduce file size, but it requires that all custom layers/models + implement a `get_config()` method. + + Raises: + ValueError: if the model's inputs have not been defined. + """ + try: + import horovod.tensorflow as hvd + hvd.rank() + except: + hvd = None + + call_original_save_func = functools.partial( + original_save_func, + model=model, + filepath=filepath, + overwrite=overwrite, + include_optimizer=include_optimizer, + signatures=signatures, + options=options, + save_traces=save_traces, + *args, + **kwargs) + + def _traverse_emb_layers_and_save(hvd_rank): + de_dir = os.path.join(filepath, "variables", "TFRADynamicEmbedding") + for layer in model.layers: + if hasattr(layer, "params") and isinstance( + layer, de.keras.layers.HvdAllToAllEmbedding): + if layer.params._saveable_object_creator is None: + if hvd_rank == 0: + tf_logging.warning( + "Please use FileSystemSaver when use HvdAllToAllEmbedding. " + "It will allow TFRA load KV files when Embedding tensor parallel. " + f"The embedding shards at each horovod rank are now temporarily stored in {de_dir}" + ) + else: + if not isinstance(layer.params.kv_creator.saver, de.FileSystemSaver): + # This function only serves FileSystemSaver. + continue + if hvd_rank == 0: + # FileSystemSaver works well at rank 0. + continue + # Save embedding parameters + layer.params.save_to_file_system(dirpath=de_dir, + proc_size=hvd.size(), + proc_rank=hvd.rank()) + # Save the optimizer parameters + if include_optimizer is True: + opt_de_vars = layer.optimizer_vars.as_list() if hasattr( + layer.optimizer_vars, "as_list") else layer.optimizer_vars + for opt_de_var in opt_de_vars: + opt_de_var.save_to_file_system(dirpath=de_dir, + proc_size=hvd.size(), + proc_rank=hvd.rank()) + + if hvd is None: + call_original_save_func() + else: + if hvd.rank() == 0: + call_original_save_func() + _traverse_emb_layers_and_save(hvd.rank()) + + def patch_on_tf_save_restore(): try: from tensorflow.python.saved_model.registration.registration import register_checkpoint_saver @@ -545,3 +651,8 @@ def patch_on_tf_save_restore(): except: functional_saver._SingleDeviceSaver = _DynamicEmbeddingSingleDeviceSaver saver.Saver = _DynamicEmbeddingSaver + tf_saved_model_save.save = functools.partial(_de_keras_save_func, + tf_original_save_func) + if keras_saved_model_save is not None: + keras_saved_model_save.save = functools.partial(_de_keras_save_func, + keras_original_save_func)
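
Usage sketch (illustrative; the model body and paths below are placeholders, layer arguments are taken from the demo and test in this patch): with the save patch above in effect, every Horovod rank calls the standard Keras save API; rank 0 writes the full SavedModel and the other ranks write only their TFRADynamicEmbedding KV shards.

    import horovod.tensorflow as hvd
    import tensorflow as tf
    from tensorflow_recommenders_addons import dynamic_embedding as de

    hvd.init()

    # FileSystemSaver is what allows each rank to write its own KV shard.
    kv_creator = de.CuckooHashTableCreator(
        saver=de.FileSystemSaver(proc_size=hvd.size(), proc_rank=hvd.rank()))

    model = tf.keras.Sequential([
        de.keras.layers.HvdAllToAllEmbedding(
            mpi_size=hvd.size(),
            embedding_size=16,
            key_dtype=tf.int64,
            value_dtype=tf.float32,
            initializer=tf.keras.initializers.Zeros(),
            kv_creator=kv_creator,
            name='all2all_emb'),
        tf.keras.layers.Dense(1),
    ])
    # ... compile and fit as usual ...

    # Call save on ALL ranks: the patched save function lets rank 0 run the
    # original Keras save, while the other ranks save their embedding (and
    # optimizer) KV files through save_to_file_system.
    save_options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA'])
    model.save('/tmp/hvd_a2a_model', options=save_options)

    # Restore, as in the new test case: rebuild the model, then load the
    # checkpoint written under variables/.
    model.load_weights('/tmp/hvd_a2a_model/variables/variables')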