Skip to content

Commit

Permalink
add HkvHashTableExportWithScores op
Browse files Browse the repository at this point in the history
  • Loading branch information
jq committed Jul 17, 2024
1 parent e30aa41 commit a846b09
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,26 @@ REGISTER_OP(PREFIX_OP_NAME(HkvHashTableSaveToFileSystem))
.Attr("dirpath_env: string")
.Attr("append_to_file: bool")
.Attr("buffer_size: int >= 1");
// Registers the op that exports the keys, values and scores held by an HKV
// hash table resource.
//
// Inputs:
//   table_handle: scalar resource handle of the table.
// Outputs:
//   keys:   rank-1 tensor of `key_dtype`.
//   values: tensor of `value_dtype`; declared rank 2 here on the assumption
//           that the kernel emits one row per exported key
//           (NOTE(review): confirm against the kernel implementation).
//   scores: rank-1 int64 tensor.
// Attrs:
//   split_size: integer forwarded to the kernel.
REGISTER_OP(PREFIX_OP_NAME(HkvHashTableExportWithScores))
    .Input("table_handle: resource")
    .Output("keys: key_dtype")
    .Output("values: value_dtype")
    .Output("scores: int64")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .Attr("split_size: int")
    .SetShapeFn([](InferenceContext* c) {
      // The resource handle itself must be a scalar.
      ShapeHandle handle;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
      // Only the ranks are declared; every dimension stays unknown at graph
      // construction time.
      c->set_output(0, c->UnknownShapeOfRank(1));
      // Values carry an embedding per key, so advertise rank 2 rather than
      // rank 1 to keep static shape inference consistent with the data.
      c->set_output(1, c->UnknownShapeOfRank(2));
      c->set_output(2, c->UnknownShapeOfRank(1));
      return TFOkStatus;
    });
REGISTER_OP(PREFIX_OP_NAME(HkvHashTableExportKeysAndScores))
.Input("table_handle: resource")
.Output("keys: key_dtype")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,29 @@ def test_backward(self):
model.fit(x, y, verbose=0)
self.assertAllEqual(emb_layer.params.size(), start)

def test_backward_adagrad(self):
  """Fit an embedding model with a wrapped Adagrad optimizer.

  Each step feeds a fresh, contiguous range of int64 ids and then checks
  that the embedding layer's parameter table size equals the total number
  of ids fed so far. Skipped outside eager mode.
  """
  if not context.executing_eagerly():
    self.skipTest('Only test in eager mode')
  model = get_sequential_model(
      de.keras.layers.Embedding,
      4,
      initializer=tf.keras.initializers.RandomNormal(seed=0),
      bp_v2=False,
      name='go582')
  optimizer = de.DynamicEmbeddingOptimizer(tf.keras.optimizers.Adagrad(1E-4))
  embedding = model.layers[0]
  model.compile(optimizer=optimizer, loss='binary_crossentropy')

  batch_size = 10
  ids_seen = 0
  for step in range(1, 10):
    # Every step introduces batch_size * step ids that were never used before.
    num_new_ids = batch_size * step
    ids = math_ops.range(ids_seen, ids_seen + num_new_ids, dtype=dtypes.int64)
    features = tf.reshape(ids, (batch_size, -1))
    ids_seen += num_new_ids
    labels = tf.zeros((batch_size, 1), dtype=dtypes.float32)
    model.fit(features, labels, verbose=0)
    self.assertAllEqual(embedding.params.size(), ids_seen)

def test_backward_bp_v2(self):
if not context.executing_eagerly():
self.skipTest('Only test in eager mode')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _check_saveable_and_redirect_new_de_dir(hvd_rank=0):
if hasattr(de_var, 'saveable'):
de_var.saveable._saver_config.save_path = de_dir

def _traverse_emb_layers_and_save(hvd_rank=0):
def _traverse_emb_layers_and_save(proc_size=1, proc_rank=0):
for var in model.variables:
if not hasattr(var, "params"):
continue
Expand All @@ -117,24 +117,24 @@ def _traverse_emb_layers_and_save(hvd_rank=0):
a2a_emb.optimizer_vars, "as_list") else a2a_emb.optimizer_vars
for de_opt_var in de_opt_vars:
de_opt_var.save_to_file_system(dirpath=de_dir,
proc_size=hvd.size(),
proc_rank=hvd.rank())
if hvd_rank == 0:
proc_size=proc_size,
proc_rank=proc_rank)
if proc_rank == 0:
# FileSystemSaver works well at rank 0.
continue
# save Dynamic Embedding Parameters
de_var.save_to_file_system(dirpath=de_dir,
proc_size=hvd.size(),
proc_rank=hvd.rank())
proc_size=proc_size,
proc_rank=proc_rank)

if hvd is None:
call_original_save_func()
_traverse_emb_layers_and_save(0)
_traverse_emb_layers_and_save()
else:
_check_saveable_and_redirect_new_de_dir(hvd.rank())
if hvd.rank() == 0:
call_original_save_func()
_traverse_emb_layers_and_save(hvd.rank())
# hvd.size is a function: call it so proc_size receives the number of
# processes rather than the function object itself.
_traverse_emb_layers_and_save(hvd.size(), hvd.rank())
hvd.join() # Sync for avoiding rank conflict


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,50 @@ def test_export_keys_and_scores(self):

del table

@test_util.run_in_graph_and_eager_modes()
def test_export_with_scores(self):
  """Check export_with_scores for every HKV eviction strategy.

  Four key/value pairs are upserted into a fresh table, then the exported
  keys, values and scores are compared with what was inserted. Skipped when
  no GPU is available.
  """
  if not is_gpu_available:
    self.skipTest('Only test when gpu is available.')
  key_dtype = dtypes.int64
  value_dtype = dtypes.int32
  dim = 8
  for strategy in de.HkvEvictStrategy:
    with self.session(use_gpu=True, config=default_config):
      table = de.get_variable(
          str(strategy),
          key_dtype=key_dtype,
          value_dtype=value_dtype,
          initializer=0,
          dim=dim,
          init_size=1024,
          kv_creator=de.HkvHashTableCreator(
              config=de.HkvHashTableConfig(init_capacity=1024,
                                           max_capacity=1024,
                                           max_hbm_for_values=1024 * 64,
                                           evict_strategy=strategy,
                                           gen_scores_fn=gen_scores_fn)))
      keys = constant_op.constant(
          np.array([0, 1, 2, 3]).astype(_type_converter(key_dtype)),
          key_dtype)
      values = constant_op.constant(
          _convert([[0] * dim, [1] * dim, [2] * dim, [3] * dim], value_dtype),
          value_dtype)

      self.evaluate(table.upsert(keys, values))

      exported_keys, exported_values, exported_scores = self.evaluate(
          table.export_with_scores(1))
      # The export order is not assumed to match the insertion order, so
      # keys and their value rows are realigned by ascending key before
      # being compared with the (ascending) inserted data.
      order = np.argsort(exported_keys)
      self.assertAllEqual(exported_keys[order], keys)
      self.assertAllEqual(exported_values[order], values)
      if strategy is de.HkvEvictStrategy.CUSTOMIZED:
        self.assertAllEqual(np.sort(exported_scores), gen_scores_fn(keys))
      elif strategy is de.HkvEvictStrategy.EPOCHLFU:
        self.assertAllEqual(exported_scores, np.full((4), 1))
      elif strategy is de.HkvEvictStrategy.LFU:
        self.assertAllEqual(exported_scores, np.ones(4))

      del table

def test_evict_strategy_lfu(self):
if not is_gpu_available:
self.skipTest('Only test when gpu is available.')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,27 @@ def _convert_anything_to_init(self, raw_init, dim):
else:
raise ValueError
except:
init = array_ops.fill([dim], array_ops.reshape(init, [-1])[0])

def is_indexable_and_nonempty(obj):
  """Return True when `obj` has `__getitem__`, has `__len__`, and len(obj) > 0."""
  indexable = hasattr(obj, '__getitem__')
  nonempty = len(obj) > 0 if hasattr(obj, '__len__') else False
  return indexable and nonempty

if isinstance(init, int) or isinstance(init, float):
first_element = init
elif not isinstance(init, tf.Tensor) and is_indexable_and_nonempty(init):
first_element = init[0]
else:
reshaped_init = array_ops.reshape(init, [-1])
size_of_reshaped_init = tf.size(reshaped_init)

def get_default_value():
  """Build a scalar zero constant of `self.value_dtype`."""
  # A float literal is used for floating dtypes, an int literal otherwise.
  zero = 0.0 if self.value_dtype.is_floating else 0
  return tf.constant(zero, dtype=self.value_dtype)

first_element = tf.cond(tf.greater(size_of_reshaped_init, 0),
lambda: reshaped_init[0], get_default_value)
init = array_ops.fill([dim], first_element)
init = math_ops.cast(init, dtype=self.value_dtype)
return init

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,20 @@ def export_keys_and_scores(self, split_size, name=None):
split_size=split_size)
return keys, scores

def export_with_scores(self, split_size, name=None):
  """Export the table's keys, values and scores.

  Args:
    split_size: positive int forwarded to the export op's `split_size`
      attribute.
    name: optional name for the enclosing name scope.

  Returns:
    A `(keys, values, scores)` tuple with the three outputs of the export op.

  Raises:
    ValueError: if `split_size` is not a positive int.
  """
  # Check the type before comparing: evaluating `split_size > 0` first would
  # raise TypeError for values such as None or a string instead of the
  # intended ValueError.
  if not (isinstance(split_size, int) and split_size > 0):
    raise ValueError('split_size must be positive integer.')

  with ops.name_scope(name, "%s_lookup_table_export_with_scores" % self.name,
                      [self.resource_handle]):
    with ops.colocate_with(self.resource_handle):
      keys, values, scores = hkv_ops.tfra_hkv_hash_table_export_with_scores(
          self.resource_handle,
          key_dtype=self._key_dtype,
          value_dtype=self._value_dtype,
          split_size=split_size)
  return keys, values, scores

def save_to_file_system(self,
dirpath,
file_name=None,
Expand Down

0 comments on commit a846b09

Please sign in to comment.