Skip to content

Commit

Permalink
[fix] Rename DECheckpoint and de_save_model API to make them look more like TF style.
Browse files Browse the repository at this point in the history
  • Loading branch information
MoFHeka committed Jan 30, 2024
1 parent 525f3a3 commit c689540
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -466,12 +466,12 @@ def export_to_savedmodel(model, savedmodel_dir):

# TFRA modify the Keras save function with a patch.
# !!!! Run save_model function in all rank !!!!
de.keras.models.de_save_model(model,
savedmodel_dir,
overwrite=True,
include_optimizer=True,
save_traces=True,
options=save_options)
de.keras.models.save_model(model,
savedmodel_dir,
overwrite=True,
include_optimizer=True,
save_traces=True,
options=save_options)


def export_for_serving(model, export_dir):
Expand Down Expand Up @@ -521,7 +521,7 @@ def serve(*args, **kwargs):

# TFRA modify the Keras save function with a patch.
# !!!! Run save_model function in all rank !!!!
de.keras.models.de_save_model(
de.keras.models.save_model(
model,
export_dir,
overwrite=True,
Expand Down Expand Up @@ -572,7 +572,7 @@ def train():
# horovod callback is used to broadcast the value generated by initializer of rank0.
hvd_opt_init_callback = de.keras.callbacks.DEHvdBroadcastGlobalVariablesCallback(
root_rank=0)
ckpt_callback = de.keras.callbacks.DEHvdModelCheckpoint(
ckpt_callback = de.keras.callbacks.ModelCheckpoint(
filepath=FLAGS.model_dir + '/weights_epoch{epoch:03d}_loss{loss:.4f}',
options=save_options)
callbacks_list = [hvd_opt_init_callback, ckpt_callback]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,11 @@ In addition, we also provide parameter initialization and save callback related

[`dynamic_embedding.keras.callbacks.DEHvdBroadcastGlobalVariablesCallback`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py)

[`dynamic_embedding.keras.callbacks.DEHvdModelCheckpoint`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py)
[`dynamic_embedding.keras.callbacks.ModelCheckpoint`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py)

[`dynamic_embedding.keras.models.de_save_model`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py)
[`dynamic_embedding.keras.models.save_model`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py)

[`dynamic_embedding.train.DEHvdModelCheckpoint`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/train/checkpoint.py)
[`dynamic_embedding.train.ModelCheckpoint`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/train/checkpoint.py)

You could inherit the `HvdAllToAllEmbedding` class to implement a custom embedding
layer with other fixed shape output.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from tensorflow_recommenders_addons.dynamic_embedding.python.keras import layers
from tensorflow_recommenders_addons.dynamic_embedding.python.keras import callbacks
from tensorflow_recommenders_addons.dynamic_embedding.python.keras import models
from tensorflow_recommenders_addons.dynamic_embedding.python.keras import models

setattr(models, 'save_model', models.de_save_model)
setattr(callbacks, 'ModelCheckpoint', callbacks.DEHvdModelCheckpoint)
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,7 @@ def common_all_to_all_embedding_trainable_v2(self, base_opt, test_opt, name):
shutil.rmtree(save_dir)
hvd.join() # Sync for avoiding files conflict
# base_model.save(save_dir, options=save_options)
de.keras.models.de_save_model(base_model,
save_dir,
options=save_options)
de.keras.models.save_model(base_model, save_dir, options=save_options)
ckpt = de.train.DECheckpoint(
my_model=base_model) # Test custom model key "my_model"
ckpt.save(save_dir + '/ckpt/test')
Expand Down Expand Up @@ -540,7 +538,7 @@ def check_TFRADynamicEmbedding_directory(save_dir,
np.sort(new_de_opt_compared[opt_v_name][2], axis=0))

extra_save_dir = self.get_temp_dir() + '/extra_save_dir'
de.keras.models.de_save_model(new_model, extra_save_dir)
de.keras.models.save_model(new_model, extra_save_dir)
if hvd.rank() == 0:
check_TFRADynamicEmbedding_directory(extra_save_dir)
del new_opt
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from tensorflow_recommenders_addons.dynamic_embedding.python.train.saver import DEHvdSaver
from tensorflow_recommenders_addons.dynamic_embedding.python.train.checkpoint import DECheckpoint

Checkpoint = DECheckpoint
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,16 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
try: # tf version >= 2.10.0
from tensorflow.python.checkpoint.checkpoint import Checkpoint
from tensorflow.python.checkpoint.checkpoint import Checkpoint as TFCheckpoint
from tensorflow.python.checkpoint import restore as ckpt_base
except:
from tensorflow.python.training.tracking.util import Checkpoint
from tensorflow.python.training.tracking.util import Checkpoint as TFCheckpoint
from tensorflow.python.training.tracking import base as ckpt_base
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import tf_logging


class DECheckpoint(Checkpoint):
class DECheckpoint(TFCheckpoint):
"""Overwrite tf.train.Saver class
Calling the TF save API for all ranks causes file conflicts,
so KV files other than rank0 need to be saved by calling the underlying API separately.
Expand Down

0 comments on commit c689540

Please sign in to comment.