[fix] FileSystem saver didn't restore parameters properly when users create their Keras model with lazy building. Also now fully supports using CheckpointManager.

MoFHeka committed Dec 15, 2023
1 parent 10f596c commit 7531a6a
Showing 4 changed files with 368 additions and 145 deletions.
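
In short, the commit makes the following end-to-end pattern work. A minimal single-process sketch (no Horovod); the model class, layer sizes, and paths below are illustrative, not part of the commit:

# Hypothetical sketch: checkpointing a lazily built Keras model that holds a
# dynamic-embedding table, via tf.train.CheckpointManager + de.train.DECheckpoint.
import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

class LazyModel(tf.keras.Model):

  def __init__(self):
    super().__init__()
    # Layers are declared here; their variables are built lazily on first call.
    self.emb = de.keras.layers.Embedding(embedding_size=8, name='emb')
    self.out = tf.keras.layers.Dense(1, 'sigmoid')

  def call(self, x):
    return self.out(self.emb(x))

model = LazyModel()
opt = de.DynamicEmbeddingOptimizer(tf.keras.optimizers.Adam(1.0))
model.compile(optimizer=opt, loss='mean_absolute_error')
ckpt = de.train.DECheckpoint(model=model, optimizer=opt)
manager = tf.train.CheckpointManager(ckpt, '/tmp/de_ckpt', max_to_keep=1)

x = tf.reshape(tf.range(0, 32, dtype=tf.int64), [32, 1])
y = tf.random.uniform(shape=[32, 1])
model.fit(x, y, verbose=0)
manager.save()  # KV files land in a TFRADynamicEmbedding-<N> directory
manager.restore_or_initialize()  # with this fix, lazily built DE parameters restore too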
@@ -20,6 +20,7 @@

import itertools
import os
import numpy as np
import shutil

import tensorflow as tf
@@ -34,6 +35,7 @@
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adam
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import monitored_session
from tensorflow.python.training.optimizer import Optimizer as tf1_opt
from tensorflow.python.training import training_util
@@ -97,6 +99,8 @@ def test_all_to_all_embedding_trainable(self):
self.common_all_to_all_embedding_trainable_v2(keras_base_opt,
keras_test_opt,
name="keras_adam")
self.common_lazy_build_model_with_checkpoint_management_v2(
name="keras_adam_lazy_build")

def common_minimize_trainable_v1(self, base_opt, test_opt, name):
# TODO(rhdong): Recover the testing, if the horovod import error is fixed on macOS+TF2.7+.
@@ -326,20 +330,23 @@ def common_all_to_all_embedding_trainable_v2(self, base_opt, test_opt, name):
de.keras.models.de_hvd_save_model(base_model,
save_dir,
options=save_options)
ckpt = de.train.DEHvdCheckpoint(base_model)
ckpt = de.train.DECheckpoint(
my_model=base_model) # Test custom model key "my_model"
ckpt.save(save_dir + '/ckpt/test')
tf.keras.backend.clear_session()
del base_model
del base_opt
tf.keras.backend.clear_session()
new_opt = de.DynamicEmbeddingOptimizer(Adam(1.1), synchronous=True)
new_base_model = get_emb_sequential_model(
de.keras.layers.HvdAllToAllEmbedding,
base_opt,
new_opt,
dense_init='ones',
embedding_size=dim,
initializer=init,
bp_v2=False,
kv_creator=kv_creator,
name='all2all_emb')
ckpt = de.train.DEHvdCheckpoint(my_model=new_base_model)
ckpt = de.train.DECheckpoint(my_model=new_base_model)
    hvd.join()  # Sync to avoid file conflicts
ckpt.restore(tf.train.latest_checkpoint(save_dir + '/ckpt/'))
new_a2aemb_size = new_base_model.layers[0].params.size()
@@ -351,6 +358,165 @@ def common_all_to_all_embedding_trainable_v2(self, base_opt, test_opt, name):
self.assertEqual(a2aemb_size, new_a2aemb_size)
    hvd.join()  # Sync to avoid file conflicts

def common_lazy_build_model_with_checkpoint_management_v2(self, name):
# TODO(rhdong): Recover the testing, if the horovod import error is fixed on macOS+TF2.7+.
try:
import horovod.tensorflow as hvd
except (NotFoundError):
self.skipTest(
"Skip the test for horovod import error with Tensorflow-2.7.0 on MacOS-12."
)

tf.config.set_soft_device_placement(True)

hvd.init()

    # These cases need at least 2 GPUs, when available.
logical_devices = tf.config.list_logical_devices('GPU')
_device = "GPU" if len(logical_devices) >= hvd.size() else "CPU"
_device_id = hvd.local_rank(
) if _device == "GPU" and len(logical_devices) >= 2 else 0

if _device == "GPU":
os.environ["CUDA_VISIBLE_DEVICES"] = str(_device_id)

dim = 8

class NoCompileModel(tf.keras.models.Model):

def __init__(self, init, dynamic=False):
super().__init__(dynamic=dynamic)
kv_creator = de.CuckooHashTableCreator(saver=de.FileSystemSaver(
proc_size=hvd.size(), proc_rank=hvd.rank()))
self.emb = de.keras.layers.HvdAllToAllEmbedding(embedding_size=dim,
devices=['/GPU:0'],
initializer=0,
kv_creator=kv_creator,
name=name)
self.l1 = tf.keras.layers.Dense(8, 'relu', kernel_initializer=init)
self.l2 = tf.keras.layers.Dense(1, 'sigmoid', kernel_initializer=init)

def build(self, input_shape):
self.emb.build(input_shape)
self.l1.build(input_shape + dim)
self.l2.build(input_shape + 8)

def call(self, x):
out = self.emb(x)
out = self.l1(out)
return self.l2(out)

def check_TFRADynamicEmbedding_directory(save_dir,
save_it,
should_be_exist=True):
hvd_size = hvd.size()
if hvd_size <= 1:
hvd_size = 1
for tag in ['keys', 'values']:
for rank in range(hvd_size):
self.assertTrue(not (os.path.exists(
save_dir +
f'/TFRADynamicEmbedding-{save_it}/{name}-parameter_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
) ^ should_be_exist))
self.assertTrue(not (os.path.exists(
save_dir +
f'/TFRADynamicEmbedding-{save_it}/{name}-parameter_DynamicEmbedding_keras_adam_lazy_build-shadow_m_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
) ^ should_be_exist))
# f'/TFRADynamicEmbedding-{save_it}/{name}-parameter_no_compile_model_DynamicEmbedding_keras_adam_lazy_build-shadow_m_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
self.assertTrue(not (os.path.exists(
save_dir +
f'/TFRADynamicEmbedding-{save_it}/{name}-parameter_DynamicEmbedding_keras_adam_lazy_build-shadow_v_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
) ^ should_be_exist))
# f'/TFRADynamicEmbedding-{save_it}/{name}-parameter_no_compile_model_DynamicEmbedding_keras_adam_lazy_build-shadow_v_mht_1of1_rank{rank}_size{hvd_size}-{tag}'

with tf.device("/{}:{}".format(_device, _device_id)):
x = tf.reshape(tf.range(0, 32, dtype=tf.int64), [32, 1])
y = tf.random.uniform(shape=[32, 1])

save_dir = self.get_temp_dir()

model = NoCompileModel('ones')
base_opt = Adam(1.0)
base_opt = de.DynamicEmbeddingOptimizer(base_opt, synchronous=True)
ckpt = de.train.DECheckpoint(model=model, optimizer=base_opt)
model.compile(optimizer=base_opt, loss='mean_absolute_error')
manager = checkpoint_management.CheckpointManager(ckpt,
save_dir,
max_to_keep=1)
model.fit(x, y, verbose=0)
manager.save()
if hvd.rank() == 0:
check_TFRADynamicEmbedding_directory(save_dir,
save_it=1,
should_be_exist=True)
for l in model.layers:
if name in l.name:
l.params.upsert(x * 10, tf.random.uniform(shape=[32, 1, dim]))
emb_size = l.params.size()
emb_keys, emb_values = l.params.export()
break
for v in base_opt.variables():
if name in v.name:
v.params.upsert(x * 10, tf.random.uniform(shape=[32, 1, dim]))
opt_size = v.params.size()
          opt_keys, opt_values = v.params.export()
break
manager.save()
if hvd.rank() == 0:
check_TFRADynamicEmbedding_directory(save_dir,
save_it=2,
should_be_exist=True)
      # CheckpointManager deletes stale checkpoints after its write function returns, but DE KV
      # checkpoints are saved and deleted inside the write function. So there is always one more
      # DE KV TFRADynamicEmbedding directory than there are TF checkpoint files.
manager.save()
if hvd.rank() == 0:
        check_TFRADynamicEmbedding_directory(
            save_dir, save_it=1, should_be_exist=False
        )  # Check that the stale TFRADynamicEmbedding directory was deleted properly.
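      # Illustration (an assumption drawn from the checks above, with max_to_keep=1):
      #   after save #2: TF keeps only ckpt-2, DE keeps TFRADynamicEmbedding-1 and -2
      #   after save #3: TF keeps only ckpt-3, DE keeps TFRADynamicEmbedding-2 and -3,
      #                  and TFRADynamicEmbedding-1 is finally removed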

del base_opt
del model
del ckpt
tf.keras.backend.clear_session()
tf.compat.v1.reset_default_graph()

new_model = NoCompileModel('zeros')
new_opt = Adam(1.1)
new_opt = de.DynamicEmbeddingOptimizer(new_opt, synchronous=True)
new_ckpt = de.train.DECheckpoint(model=new_model, optimizer=new_opt)
manager = checkpoint_management.CheckpointManager(new_ckpt,
save_dir,
max_to_keep=1)
manager.restore_or_initialize()
new_model.compile(optimizer=new_opt, loss='mean_absolute_error')
      new_model(x)  # Build variables
try:
new_opt._create_all_weights(new_model.variables)
except:
        # TODO(MoFHeka): may raise ValueError: Cannot convert a partially known TensorShape <unknown> to a Tensor.
pass
for l in new_model.layers:
if name in l.name:
new_emb_size = l.params.size()
new_emb_keys, new_emb_values = l.params.export()
break
for v in new_opt.variables():
if name in v.name:
new_opt_size = v.params.size()
          new_opt_keys, new_opt_values = v.params.export()
break

self.assertEqual(emb_size, new_emb_size)
self.assertEqual(opt_size, new_opt_size)
self.assertAllEqual(np.sort(emb_keys, axis=0),
np.sort(new_emb_keys, axis=0))
self.assertAllClose(np.sort(emb_values, axis=0),
np.sort(new_emb_values, axis=0))
self.assertAllEqual(np.sort(opt_keys, axis=0),
np.sort(new_opt_keys, axis=0))
self.assertAllClose(np.sort(opt_values, axis=0),
np.sort(new_opt_values, axis=0))


if __name__ == "__main__":
test.main()
@@ -560,16 +560,16 @@ def size(self):

model_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
save_ckpt_dir = os.path.join(model_dir, 'model')
restore_ckpt_path = os.path.join(model_dir, 'model-1')

options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA'])
ckpt = tf.train.Checkpoint(module)
ckpt = de.train.DECheckpoint(module)
ckpt.save(save_ckpt_dir)
shadow_value = module.shadow.read_value(False)
self.assertAllEqual(shadow_value.shape, (0, 2)) # clear when saving

new_module = TestModule()
new_ckpt = tf.train.Checkpoint(new_module)
new_ckpt = de.train.DECheckpoint(new_module)
restore_ckpt_path = tf.train.latest_checkpoint(model_dir)
new_ckpt.read(restore_ckpt_path)
self.assertEqual(new_module.size(), 3)
expected_values = module(keys)
@@ -640,18 +640,18 @@ def size(self):

model_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
save_ckpt_dir = os.path.join(model_dir, 'model')
restore_ckpt_path = os.path.join(model_dir, 'model-1')

options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA'])
ckpt = tf.train.Checkpoint(module)
ckpt = de.train.DECheckpoint(module)
ckpt.save(save_ckpt_dir)
shadow_value = module.shadow.read_value(False)
self.assertAllEqual(shadow_value.shape, (0, 1)) # clear when saving

tf.keras.backend.clear_session()
del module, ckpt
new_module = TestNewModule(table_devices_)
new_ckpt = tf.train.Checkpoint(new_module)
new_ckpt = de.train.DECheckpoint(new_module)
restore_ckpt_path = tf.train.latest_checkpoint(model_dir)
new_ckpt.read(restore_ckpt_path)
self.assertEqual(new_module.size(), test_size)
expected_values = new_module(keys)
@@ -663,7 +663,7 @@ def size(self):
shard_num = 5
table_devices_ = table_device * shard_num
new_module = TestNewModule(table_devices_)
new_ckpt = tf.train.Checkpoint(new_module)
new_ckpt = de.train.DECheckpoint(new_module)
new_ckpt.read(restore_ckpt_path)
self.assertEqual(new_module.size(), test_size)
expected_values = new_module(keys)
@@ -675,7 +675,7 @@ def size(self):
shard_num = 2
table_devices_ = table_device * shard_num
new_module = TestNewModule(table_devices_)
new_ckpt = tf.train.Checkpoint(new_module)
new_ckpt = de.train.DECheckpoint(new_module)
new_ckpt.read(restore_ckpt_path)
self.assertEqual(new_module.size(), test_size)
expected_values = new_module(keys)
@@ -687,7 +687,7 @@ def size(self):
shard_num = 1
table_devices_ = table_device * shard_num
new_module = TestNewModule(table_devices_)
new_ckpt = tf.train.Checkpoint(new_module)
new_ckpt = de.train.DECheckpoint(new_module)
new_ckpt.read(restore_ckpt_path)
self.assertEqual(new_module.size(), test_size)
expected_values = new_module(keys)
@@ -1,2 +1,2 @@
from tensorflow_recommenders_addons.dynamic_embedding.python.train.saver import DEHvdSaver
from tensorflow_recommenders_addons.dynamic_embedding.python.train.checkpoint import DEHvdCheckpoint
from tensorflow_recommenders_addons.dynamic_embedding.python.train.checkpoint import DECheckpoint
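
For downstream code, the wrapper is now exported as DECheckpoint instead of DEHvdCheckpoint. A short usage sketch (assuming `model` is any Keras model or tf.Module that contains DE tables; the paths and the custom key are illustrative):

import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

ckpt = de.train.DECheckpoint(my_model=model)  # custom keys work as in tf.train.Checkpoint
ckpt.save('/tmp/ckpt/test')
ckpt.restore(tf.train.latest_checkpoint('/tmp/ckpt'))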