Skip to content

Commit

Permalink
[fix] Fix device of table shard being lost in the SavedModel pb file when saving a model with the Keras API.
Browse files Browse the repository at this point in the history
  • Loading branch information
MoFHeka committed Nov 1, 2022
1 parent f640888 commit 77500c7
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from __future__ import division
from __future__ import print_function

import os
import numpy as np
import itertools
import tensorflow as tf
Expand All @@ -27,6 +28,7 @@
from tensorflow_recommenders_addons import dynamic_embedding as de

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf.saved_model_pb2 import SavedModel
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
Expand Down Expand Up @@ -72,6 +74,8 @@ def test_create(self):
key_dtypes = [dtypes.int64]

value_dtypes = [dtypes.float32, dtypes.float64]
if test_util.is_gpu_available():
value_dtypes = [dtypes.float32]
initializers = [
tf.keras.initializers.RandomNormal(),
tf.keras.initializers.RandomUniform()
Expand Down Expand Up @@ -174,6 +178,60 @@ def test_backward_bp_v2(self):
model.fit(x, y, verbose=0)
self.assertAllEqual(emb_layer.params.size(), start)

def test_keras_save_load_weights(self):
  """Round-trip a Keras model containing a dynamic-embedding layer.

  Verifies two things after `model.save` + `load_weights`:
    (a) the embedding table contents survive the round trip, and
    (b) the table-shard op recorded in the exported `saved_model.pb`
        keeps the device the table was created on (the bug this commit
        fixes: the shard device was previously dropped during export).
  """
  if not context.executing_eagerly():
    self.skipTest('Only test in eager mode')
  save_dir = os.path.join(self.get_temp_dir(), "save_restore")
  save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

  def model_fn(table_device):
    # Build a one-feature functional model whose embedding table is
    # pinned to `table_device`.
    input_tensor = tf.keras.layers.Input(shape=(1,), dtype=tf.int64)
    embedding_out = de.keras.layers.BasicEmbedding(
        embedding_size=1,
        key_dtype=tf.int64,
        value_dtype=tf.float32,
        initializer=tf.keras.initializers.RandomNormal(),
        devices=table_device,
        name='test_keras_save_restore',
    )(input_tensor)
    model = tf.keras.Model(inputs=input_tensor, outputs=embedding_out)
    optimizer = tf.keras.optimizers.Adam(learning_rate=1E-4, amsgrad=False)
    optimizer = de.DynamicEmbeddingOptimizer(optimizer)
    model.compile(optimizer=optimizer)
    return model

  table_device_ = ['/device:CPU:0']
  if test_util.is_gpu_available():
    table_device_ = ['/device:GPU:0']
  model = model_fn(table_device_)
  params_ = model.get_layer('test_keras_save_restore').params
  params_.upsert(
      constant_op.constant([0, 1], dtypes.int64),
      constant_op.constant([[12.0], [24.0]], dtypes.float32),
  )
  # TFRA custom ops must be whitelisted for SavedModel export.
  options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA'])
  model.save(save_path, options=options)
  tf.keras.backend.clear_session()
  del model
  model = model_fn(table_device_)
  model.load_weights(save_path)
  params_ = model.get_layer('test_keras_save_restore').params
  size = params_.size()
  self.assertEqual(2, size)
  [keys, values] = params_.export()
  self.assertAllEqual([0, 1], keys)
  self.assertAllEqual([[12.0], [24.0]], values)

  # Check the table device was serialized correctly into saved_model.pb.
  graph_path = os.path.join(save_path, 'saved_model.pb')
  sm = SavedModel()
  with open(graph_path, 'rb') as f:
    sm.ParseFromString(f.read())
  # Fix: previously the device assertion only ran when the node name
  # matched, so a renamed/missing shard node let the test pass vacuously.
  # Require that the shard node is actually present.
  found = False
  for mg in sm.meta_graphs:
    for node in mg.graph_def.node:
      if node.name == 'test_keras_save_restore-parameter_mht_1of1':
        found = True
        self.assertEqual(table_device_[0], node.device)
  self.assertTrue(
      found, 'table shard node not found in saved_model.pb graph_def')


@test_util.run_all_in_graph_and_eager_modes
class SquashedEmbeddingLayerTest(test.TestCase):
Expand Down Expand Up @@ -218,6 +276,8 @@ def test_create(self):
key_dtypes = [dtypes.int64]

value_dtypes = [dtypes.float32, dtypes.float64]
if test_util.is_gpu_available():
value_dtypes = [dtypes.float32]
initializers = [
tf.keras.initializers.RandomNormal(),
tf.keras.initializers.RandomUniform()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
checkpoint=True,
init_size=0,
config=None,
device='',
):
"""Creates an empty `CuckooHashTable` object.
Expand Down Expand Up @@ -89,6 +90,7 @@ def __init__(
self._checkpoint = checkpoint
self._key_dtype = key_dtype
self._value_dtype = value_dtype
self._device = device
self._init_size = init_size
self._name = name

Expand Down Expand Up @@ -123,15 +125,16 @@ def _create_resource(self):
# explicitly specified.
use_node_name_sharing = self._checkpoint and self._shared_name is None

table_ref = cuckoo_ops.tfra_cuckoo_hash_table_of_tensors(
shared_name=self._shared_name,
use_node_name_sharing=use_node_name_sharing,
key_dtype=self._key_dtype,
value_dtype=self._value_dtype,
value_shape=self._default_value.get_shape(),
init_size=self._init_size,
name=self._name,
)
with ops.device(self._device):
table_ref = cuckoo_ops.tfra_cuckoo_hash_table_of_tensors(
shared_name=self._shared_name,
use_node_name_sharing=use_node_name_sharing,
key_dtype=self._key_dtype,
value_dtype=self._value_dtype,
value_shape=self._default_value.get_shape(),
init_size=self._init_size,
name=self._name,
)

if context.executing_eagerly():
self._table_name = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def create(self,
name=None,
checkpoint=None,
init_size=None,
config=None):
config=None,
device=None):

raise NotImplementedError('create function must be implemented')

Expand All @@ -74,6 +75,7 @@ def create(
checkpoint=None,
init_size=None,
config=None,
device=None,
):
self.key_dtype = key_dtype
self.value_dtype = value_dtype
Expand All @@ -82,15 +84,17 @@ def create(
self.checkpoint = checkpoint
self.init_size = init_size
self.config = config
self.device = device

return de.CuckooHashTable(
key_dtype=key_dtype,
value_dtype=value_dtype,
default_value=default_value,
name=name,
checkpoint=checkpoint,
init_size=init_size,
config=config,
key_dtype=self.key_dtype,
value_dtype=self.value_dtype,
default_value=self.default_value,
name=self.name,
checkpoint=self.checkpoint,
init_size=self.init_size,
config=self.config,
device=self.device,
)

def get_config(self):
Expand All @@ -106,6 +110,7 @@ def get_config(self):
'checkpoint': self.checkpoint,
'init_size': self.init_size,
'config': self.config,
'device': self.device,
}
return config

Expand Down Expand Up @@ -181,16 +186,42 @@ def create(
checkpoint=None,
init_size=None,
config=None,
device=None,
):
real_config = config if config is not None else self.config
if not isinstance(real_config, RedisTableConfig):
self.key_dtype = key_dtype
self.value_dtype = value_dtype
self.default_value = default_value
self.name = name
self.checkpoint = checkpoint
self.init_size = init_size
self.config = config if config is not None else self.config
self.device = device
if not isinstance(self.config, RedisTableConfig):
raise TypeError("config should be instance of 'config', but got ",
str(type(real_config)))
str(type(self.config)))
return de.RedisTable(
key_dtype=key_dtype,
value_dtype=value_dtype,
default_value=default_value,
name=name,
checkpoint=checkpoint,
key_dtype=self.key_dtype,
value_dtype=self.value_dtype,
default_value=self.default_value,
name=self.name,
checkpoint=self.checkpoint,
config=self.config,
device=self.device,
)

def get_config(self):
  """Return the constructor keyword arguments of this creator as a dict.

  Serialization relies on `.numpy()` on the eager `default_value`
  tensor, so this is only supported in eager mode.

  Raises:
    RuntimeError: if called while not executing eagerly.
  """
  if not context.executing_eagerly():
    raise RuntimeError(
        'Unsupported to serialize python object of RedisTableCreator.')

  # `default_value` is materialized to a numpy array; every other field
  # is passed through unchanged, in the same key order as before.
  serialized = {'key_dtype': self.key_dtype}
  serialized['value_dtype'] = self.value_dtype
  serialized['default_value'] = self.default_value.numpy()
  serialized['name'] = self.name
  serialized['checkpoint'] = self.checkpoint
  serialized['init_size'] = self.init_size
  serialized['config'] = self.config
  serialized['device'] = self.device
  return serialized
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def _get_default_devices():
name=self._make_name(idx),
checkpoint=self.checkpoint,
init_size=int(self.init_size / self.shard_num),
device=self.devices[idx],
)
self._tables.append(mht)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def __init__(
name="RedisTable",
checkpoint=False,
config=None,
device='',
):
"""
Creates an empty `RedisTable` object.
Expand Down Expand Up @@ -150,6 +151,7 @@ def __init__(
self._checkpoint = checkpoint
self._key_dtype = key_dtype
self._value_dtype = value_dtype
self._device = device
self._name = name
self._embedding_name = (self._name.split('_mht_', 1))[0]
self._config = config
Expand Down Expand Up @@ -251,15 +253,17 @@ def _create_resource(self):
# training to work correctly. Use the node name if no shared_name has been
# explicitly specified.
use_node_name_sharing = self._checkpoint and self._shared_name is None
table_ref = redis_table_ops.tfra_redis_table_of_tensors(
shared_name=self._shared_name,
use_node_name_sharing=use_node_name_sharing,
key_dtype=self._key_dtype,
value_dtype=self._value_dtype,
value_shape=self._default_value.get_shape(),
embedding_name=self._embedding_name,
redis_config_abs_dir=self._config.redis_config_abs_dir,
redis_config_abs_dir_env=self._config.redis_config_abs_dir_env)

with ops.device(self._device):
table_ref = redis_table_ops.tfra_redis_table_of_tensors(
shared_name=self._shared_name,
use_node_name_sharing=use_node_name_sharing,
key_dtype=self._key_dtype,
value_dtype=self._value_dtype,
value_shape=self._default_value.get_shape(),
embedding_name=self._embedding_name,
redis_config_abs_dir=self._config.redis_config_abs_dir,
redis_config_abs_dir_env=self._config.redis_config_abs_dir_env)

if context.executing_eagerly():
self._table_name = None
Expand Down

0 comments on commit 77500c7

Please sign in to comment.