
Enable direct use of the dynamic_embedding.embedding_lookup series APIs inside a tf.function scope
Lifann committed Mar 22, 2022
1 parent 7a14cd4 commit 7cc07c3
Showing 5 changed files with 211 additions and 22 deletions.
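The net effect of the change: `embedding_lookup` now creates or reuses a shadow trainable variable when tracing under `tf.function` (or running eagerly) instead of failing. A minimal sketch of the newly supported pattern, following the usage in the test added below (the table name and lookup name here are illustrative, not taken from the commit):

```python
import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

# A dynamic embedding table; 'user_embedding' below names the shadow trainable.
params = de.get_variable('demo_table', dim=4,
                         initializer=tf.keras.initializers.Zeros())
optimizer = de.DynamicEmbeddingOptimizer(tf.keras.optimizers.Adam(1E-3))

@tf.function  # lookups inside tf.function are what this commit enables
def loss_fn(ids):
  emb = de.embedding_lookup(params, ids, name='user_embedding')
  return tf.reduce_mean(tf.reduce_sum(emb, axis=1))

ids = tf.constant([1, 2, 3], dtype=tf.int64)
# var_list is a callable: the shadow only exists after the first trace.
optimizer.minimize(lambda: loss_fn(ids),
                   lambda: list(params.trainable_store.values()))
```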
Changed file 1 of 5
@@ -23,6 +23,7 @@
import math
import numpy as np
import os
import tensorflow as tf

from tensorflow_recommenders_addons import dynamic_embedding as de

@@ -47,9 +48,12 @@
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
- from tensorflow.python.ops import variable_scope
+ from tensorflow.python.ops import nn_ops
+ from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import script_ops
+ from tensorflow.python.ops import variables
+ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.training import device_setter
from tensorflow.python.training import server_lib
@@ -1296,5 +1300,84 @@ def test_colocate_to_ids(self):
self.assertAllEqual(tw_q.device, '/job:dist/task:1')


@test_util.run_all_in_graph_and_eager_modes
class EmbeddingLookupEagerTest(test.TestCase):

def _create_input_and_params(self,
name,
batch_size=4,
nids=64,
embedding_size=1):
assert nids % batch_size == 0
ids = math_ops.range(0, nids, dtype=dtypes.int64)
ids = array_ops.reshape(ids, (batch_size, -1))
labels = array_ops.zeros((batch_size,), dtype=dtypes.float32)
devar = de.get_variable(name + '/dynamic_embedding',
dim=embedding_size,
initializer=tf.keras.initializers.Zeros())
tfvar = tf.Variable(tf.keras.initializers.Zeros()((nids, embedding_size),
dtype=tf.float32))
return ids, labels, devar, tfvar

def _loss_fn(self, params, ids, labels):

if isinstance(params, de.Variable):
embedding = de.embedding_lookup(params, ids)
elif isinstance(
params, (resource_variable_ops.ResourceVariable, variables.Variable)):
embedding = embedding_ops.embedding_lookup(params, ids)
else:
raise TypeError('params must be a de.Variable or a tf.Variable.')

logits = math_ops.reduce_mean(math_ops.reduce_sum(embedding, 1), 1)
entropy = nn_impl.sigmoid_cross_entropy_with_logits(logits=logits,
labels=labels)
loss = math_ops.reduce_mean(entropy)
return loss

def test_run_training_eagerly(self):
if not context.executing_eagerly():
self.skipTest('Only test functional API in eager mode.')

batch_size = 4
ids, labels, devar, tfvar = self._create_input_and_params('vns079',
embedding_size=1)
nsteps = 10

loss_fn = tf.function()(self._loss_fn)

def sorted_dynamic_embedding_value():
embedding_var = devar
optimizer = tf.keras.optimizers.Adam(1E-3)
optimizer = de.DynamicEmbeddingOptimizer(optimizer)

def var_fn():
return list(embedding_var.trainable_store.values())

for _ in range(nsteps):
optimizer.minimize(lambda: loss_fn(embedding_var, ids, labels), var_fn)

keys, values = embedding_var.export()
order = tf.argsort(keys)
return array_ops.gather(values, order)

def sorted_static_embedding_value():
embedding_var = tfvar
optimizer = tf.keras.optimizers.Adam(1E-3)
optimizer = de.DynamicEmbeddingOptimizer(optimizer)

def var_fn():
return [embedding_var]

for _ in range(nsteps):
optimizer.minimize(lambda: loss_fn(embedding_var, ids, labels), var_fn)

return embedding_var.read_value()

de_values = sorted_dynamic_embedding_value()
tf_values = sorted_static_embedding_value()
self.assertAllClose(de_values, tf_values)


if __name__ == "__main__":
test.main()
Changed file 2 of 5
@@ -47,6 +47,8 @@
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export

_ANONYMOUS_TRAINABLE_STORE_KEY = '_anonymous_trainable_store_key'


class TrainableWrapper(resource_variable_ops.ResourceVariable):
"""
@@ -548,28 +550,58 @@ def initial_value():
if params.trainable:
collections += [ops.GraphKeys.TRAINABLE_VARIABLES]

-  def _create_trainable(trainable_name):
-    return de.TrainableWrapper(params,
-                               ids,
-                               max_norm=max_norm,
-                               initial_value=initial_value,
-                               dtype=params.value_dtype,
-                               trainable=params.trainable,
-                               collections=collections,
-                               model_mode=ModelMode.CURRENT_SETTING,
-                               name=trainable_name)
+  def _create_or_get_trainable(trainable_name):
+    if trainable_name is None:
+      if context.executing_eagerly():
+        raise ValueError(
+            'Must provide a name for embedding_lookup when using eager execution.'
+        )
+      trainable_name = ops.get_default_graph().unique_name(
+          _ANONYMOUS_TRAINABLE_STORE_KEY)
+    if not context.executing_eagerly() and not ops.inside_function():
+      wrapper = de.TrainableWrapper(params,
+                                    ids,
+                                    max_norm=max_norm,
+                                    initial_value=initial_value,
+                                    dtype=params.value_dtype,
+                                    trainable=params.trainable,
+                                    collections=collections,
+                                    model_mode=ModelMode.CURRENT_SETTING,
+                                    name=trainable_name)
+      params._trainable_store[trainable_name] = wrapper
+      return wrapper
+    else:
+      with ops.init_scope():
+        shadow = params._trainable_store.get(trainable_name, None)
+        if shadow is None:
+          shadow = de.shadow_ops.ShadowVariable(
+              params,
+              name=trainable_name,
+              max_norm=max_norm,
+              trainable=params.trainable,
+              model_mode=ModelMode.CURRENT_SETTING)
+          params._trainable_store[trainable_name] = shadow
+        return shadow

with ops.colocate_with(ids, ignore_existing=True):
-    if context.executing_eagerly():
-      trainable_ = params._trainable_store.get(name, None)
-      if trainable_ is None:
-        trainable_ = _create_trainable(name)
-        params._trainable_store[name] = trainable_
-      else:
-        trainable_._reset_ids(ids)
-    else:
-      trainable_ = _create_trainable(name)
-      params._trainable_store[name] = trainable_
+    trainable_ = _create_or_get_trainable(name)
+
+    if isinstance(trainable_, de.shadow_ops.ShadowVariable):
+      embeddings = de.shadow_ops.embedding_lookup(
+          trainable_,
+          ids,
+          partition_strategy=partition_strategy,
+          name=name,
+          validate_indices=validate_indices)
+      if return_trainable:
+        if not context.executing_eagerly():
+          raise NotImplementedError(
+              'return_trainable is currently not implemented when using tf.function.'
+              ' Please use `Variable.trainable_store` or `Variable.get_trainable_by_name`'
+              ' to access the shadow trainable variable when calling `embedding_lookup`'
+              ' series APIs inside a tf.function scope.')
+        return embeddings, trainable_
+      return embeddings

embeddings = array_ops.identity(trainable_)
embeddings = array_ops.reshape(embeddings, shape=embeddings_shape)
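As the `NotImplementedError` above spells out, code traced by `tf.function` retrieves the trainable through the variable itself rather than via `return_trainable=True`. A sketch of that workaround, reusing `params` from the sketch near the top of this page (the lookup name is again illustrative):

```python
@tf.function
def forward(ids):
  # return_trainable=True would raise NotImplementedError in here.
  return de.embedding_lookup(params, ids, name='user_embedding')

_ = forward(tf.constant([1, 2], dtype=tf.int64))  # first call registers the shadow
shadow = params.get_trainable_by_name('user_embedding')
# equivalently: shadow = params.trainable_store['user_embedding']
```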
@@ -738,7 +770,7 @@ def embedding_lookup_sparse(
embeddings, trainable_ = embedding_lookup(
params,
ids,
name=name + "/embedding_lookup",
name=name + '/embedding_lookup',
partition_strategy=partition_strategy,
max_norm=max_norm,
return_trainable=True,
Changed file 3 of 5
@@ -659,6 +659,43 @@ def get_slot_variables(self, optimizer):
continue
return slots

  def get_trainable_by_name(self, name):
    """
    Get the trainable shadow variable registered under `name` by a
    previous call to an `embedding_lookup` series API.

    Example:

    ```python
    from tensorflow_recommenders_addons import dynamic_embedding as de

    init = tf.keras.initializers.RandomNormal()
    params = de.get_variable('foo', dim=4, initializer=init)
    optimizer = tf.keras.optimizers.Adam(1E-3)
    optimizer = de.DynamicEmbeddingOptimizer(optimizer)
    ids = tf.constant([[1, 2], [3, 4]], dtype=tf.int64)

    @tf.function
    def loss_fn(ids):
      emb = de.embedding_lookup(params, ids, name='user_embedding')
      emb = tf.math.reduce_sum(emb, axis=1)
      loss = tf.reduce_mean(emb)
      return loss

    for i in range(10):
      # var_list is a callable: the shadow only exists after the first
      # call to loss_fn has traced the lookup.
      optimizer.minimize(lambda: loss_fn(ids),
                         lambda: [params.get_trainable_by_name('user_embedding')])
    ```

    Args:
      name: str. Name of the trainable shadow variable to fetch.

    Returns:
      The ShadowVariable registered under `name`, or None if no lookup
      with that name has run yet.

    Raises:
      TypeError: if `name` is not a string.
    """
    if not isinstance(name, str):
      raise TypeError('name should be a string')
    return self._trainable_store.get(name, None)

def _gather_saveables_for_checkpoint(self):
g = ops.get_default_graph()
if context.executing_eagerly() or g._functions:
@@ -678,6 +715,10 @@ def _gather_saveables_for_checkpoint(self):
saveables[saveable.keywords["name"]] = saveable
return saveables

@property
def trainable_store(self):
return self._trainable_store
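The store is a plain mapping from lookup names to shadow trainables, so training code can assemble its `var_list` from it directly (a sketch continuing the examples above):

```python
print(list(params.trainable_store.keys()))         # e.g. ['user_embedding']
var_list_fn = lambda: list(params.trainable_store.values())
optimizer.minimize(lambda: loss_fn(ids), var_list_fn)
```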


@tf_export("dynamic_embedding.get_variable")
def get_variable(
Changed file 4 of 5
Expand Up @@ -48,6 +48,7 @@
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training.tracking import base as trackable

@@ -172,6 +173,36 @@ def value(self, do_prefetch=False):
with ops.colocate_with(None, ignore_existing=True):
return self._read_variable_op(do_prefetch=do_prefetch)

def assign(self, value, use_locking=None, name=None, read_value=True):
"""
Assigns a new value to this variable.
Unlike a plain ResourceVariable, the shadow always uses a variant
space to hold the temporary embedding-lookup buffer.
Args:
value: A `Tensor`. The new value for this variable.
use_locking: If `True`, use locking during the assignment.
name: The name to use for the assignment.
read_value: A `bool`. Whether to read and return the new value of the
variable or not.
Returns:
If `read_value` is `True`, this method will return the new value of the
variable after the assignment has completed. Otherwise, when in graph mode
it will return the `Operation` that does the assignment, and when in eager
mode it will return `None`.
"""
# Note: not depending on the cached value here since this can be used to
# initialize the variable.
with resource_variable_ops._handle_graph(self.handle):
value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
assign_op = gen_resource_variable_ops.assign_variable_op(self.handle,
value_tensor,
name=name)
if read_value:
return self._lazy_read(assign_op)
return assign_op

def _reset_ids(self, ids):
return self.ids.assign(ids, use_locking=True)
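For context on how these two methods interact: a wrapper reused across calls re-points its key buffer with `_reset_ids` (the eager path removed from `embedding_lookup` above called `trainable_._reset_ids(ids)` the same way), and `assign` is the raw write underneath. A rough sketch, assuming `shadow` was fetched via `params.get_trainable_by_name(...)` from a table with `dim=4`:

```python
new_ids = tf.constant([7, 8, 9], dtype=tf.int64)
shadow._reset_ids(new_ids)        # internally: shadow.ids.assign(new_ids, use_locking=True)
shadow.assign(tf.zeros((3, 4)),   # overwrite the temporary lookup buffer in place;
              read_value=False)   # returns the assign op (None in eager) instead of the value
```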

Changed file 5 of 5
@@ -18,6 +18,8 @@
from __future__ import division
from __future__ import print_function

import threading

from tensorflow_recommenders_addons import dynamic_embedding as de
from tensorflow_recommenders_addons import embedding_variable as ev

