-
Notifications
You must be signed in to change notification settings - Fork 138
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
ethan01.zhan
committed
Mar 7, 2022
1 parent
6872e62
commit 52c9f63
Showing
3 changed files
with
386 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
218 changes: 218 additions & 0 deletions
218
tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/warm_start_util_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
# Copyright 2022 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Unit tests of warm-start util""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import glob | ||
import itertools | ||
import math | ||
import numpy as np | ||
import os | ||
import shutil | ||
|
||
from tensorflow_recommenders_addons import dynamic_embedding as de | ||
|
||
try: | ||
from tensorflow.python.keras.initializers import initializers_v2 as kinit2 | ||
except ImportError: | ||
kinit2 = None | ||
pass # for compatible with TF < 2.3.x | ||
|
||
from tensorflow.core.protobuf import cluster_pb2 | ||
from tensorflow.core.protobuf import config_pb2 | ||
from tensorflow.python.eager import context | ||
from tensorflow.python.framework import constant_op | ||
from tensorflow.python.framework import dtypes | ||
from tensorflow.python.framework import ops | ||
from tensorflow.python.framework import sparse_tensor | ||
from tensorflow.python.framework import tensor_shape | ||
from tensorflow.python.framework import test_util | ||
from tensorflow.python.keras import initializers as keras_init_ops | ||
from tensorflow.python.ops import array_ops | ||
from tensorflow.python.ops import embedding_ops | ||
from tensorflow.python.ops import gen_array_ops | ||
from tensorflow.python.ops import init_ops | ||
from tensorflow.python.ops import math_ops | ||
from tensorflow.python.ops import resources | ||
from tensorflow.python.ops import script_ops | ||
from tensorflow.python.ops import variables | ||
from tensorflow.python.ops import variable_scope | ||
from tensorflow.python.platform import test | ||
from tensorflow.python.training import device_setter | ||
from tensorflow.python.training import saver | ||
from tensorflow.python.training import server_lib | ||
from tensorflow.python.util import compat | ||
from tensorflow_recommenders_addons import dynamic_embedding as de | ||
|
||
import tensorflow as tf | ||
|
||
|
||
@test_util.deprecated_graph_mode_only | ||
class WarmStartUtilTest(test.TestCase): | ||
|
||
def _test_warm_start(self, num_shards, use_regex): | ||
devices = ["/cpu:0" for _ in range(num_shards)] | ||
ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt") | ||
id_list = [x for x in range(100)] | ||
val_list = [[x] for x in range(100)] | ||
|
||
emb_name = "t100_{}_{}".format(num_shards, use_regex) | ||
with self.session(graph=ops.Graph()) as sess: | ||
embeddings = de.get_variable(emb_name, | ||
dtypes.int64, | ||
dtypes.float32, | ||
devices=devices, | ||
initializer=0.0) | ||
ids = constant_op.constant(id_list, dtype=dtypes.int64) | ||
vals = constant_op.constant(val_list, dtype=dtypes.float32) | ||
self.evaluate(embeddings.upsert(ids, vals)) | ||
save = saver.Saver(var_list=[embeddings]) | ||
save.save(sess, ckpt_prefix) | ||
|
||
with self.session(graph=ops.Graph()) as sess: | ||
embeddings = de.get_variable(emb_name, | ||
dtypes.int64, | ||
dtypes.float32, | ||
devices=devices, | ||
initializer=0.0) | ||
ids = constant_op.constant(id_list, dtype=dtypes.int64) | ||
emb = de.embedding_lookup(embeddings, ids, name="lookup") | ||
sess.graph.add_to_collection(de.GraphKeys.DYNAMIC_EMBEDDING_VARIABLES, | ||
embeddings) | ||
vars_to_warm_start = [embeddings] | ||
if use_regex: | ||
vars_to_warm_start = [".*t100.*"] | ||
|
||
restore_op = de.warm_start(ckpt_to_initialize_from=ckpt_prefix, | ||
vars_to_warm_start=vars_to_warm_start) | ||
self.evaluate(restore_op) | ||
self.assertAllEqual(emb, val_list) | ||
|
||
def _test_warm_start_rename(self, num_shards, use_regex): | ||
devices = ["/cpu:0" for _ in range(num_shards)] | ||
ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt") | ||
id_list = [x for x in range(100)] | ||
val_list = [[x] for x in range(100)] | ||
|
||
emb_name = "t200_{}_{}".format(num_shards, use_regex) | ||
with self.session(graph=ops.Graph()) as sess: | ||
embeddings = de.get_variable("save_{}".format(emb_name), | ||
dtypes.int64, | ||
dtypes.float32, | ||
devices=devices, | ||
initializer=0.0) | ||
ids = constant_op.constant(id_list, dtype=dtypes.int64) | ||
vals = constant_op.constant(val_list, dtype=dtypes.float32) | ||
self.evaluate(embeddings.upsert(ids, vals)) | ||
save = saver.Saver(var_list=[embeddings]) | ||
save.save(sess, ckpt_prefix) | ||
|
||
with self.session(graph=ops.Graph()) as sess: | ||
embeddings = de.get_variable("restore_{}".format(emb_name), | ||
dtypes.int64, | ||
dtypes.float32, | ||
devices=devices, | ||
initializer=0.0) | ||
ids = constant_op.constant(id_list, dtype=dtypes.int64) | ||
emb = de.embedding_lookup(embeddings, ids, name="lookup") | ||
sess.graph.add_to_collection(de.GraphKeys.DYNAMIC_EMBEDDING_VARIABLES, | ||
embeddings) | ||
vars_to_warm_start = [embeddings] | ||
if use_regex: | ||
vars_to_warm_start = [".*t200.*"] | ||
|
||
restore_op = de.warm_start(ckpt_to_initialize_from=ckpt_prefix, | ||
vars_to_warm_start=vars_to_warm_start, | ||
var_name_to_prev_var_name={ | ||
"restore_{}".format(emb_name): | ||
"save_{}".format(emb_name) | ||
}) | ||
self.evaluate(restore_op) | ||
self.assertAllEqual(emb, val_list) | ||
|
||
def _test_warm_start_estimator(self, num_shards, use_regex): | ||
devices = ["/cpu:0" for _ in range(num_shards)] | ||
ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt") | ||
id_list = [x for x in range(100)] | ||
val_list = [[x] for x in range(100)] | ||
|
||
emb_name = "t300_{}_{}".format(num_shards, use_regex) | ||
with self.session(graph=ops.Graph()) as sess: | ||
embeddings = de.get_variable(emb_name, | ||
dtypes.int64, | ||
dtypes.float32, | ||
devices=devices, | ||
initializer=0.0) | ||
ids = constant_op.constant(id_list, dtype=dtypes.int64) | ||
vals = constant_op.constant(val_list, dtype=dtypes.float32) | ||
self.evaluate(embeddings.upsert(ids, vals)) | ||
save = saver.Saver(var_list=[embeddings]) | ||
save.save(sess, ckpt_prefix) | ||
|
||
def _input_fn(): | ||
dataset = tf.data.Dataset.from_tensor_slices({ | ||
'ids': | ||
constant_op.constant([[x] for x in id_list], dtype=dtypes.int64) | ||
}) | ||
return dataset | ||
|
||
def _model_fn(features, labels, mode, params): | ||
ids = features['ids'] | ||
embeddings = de.get_variable(emb_name, | ||
dtypes.int64, | ||
dtypes.float32, | ||
devices=devices, | ||
initializer=0.0) | ||
emb = de.embedding_lookup(embeddings, ids, name="lookup") | ||
emb.graph.add_to_collection(de.GraphKeys.DYNAMIC_EMBEDDING_VARIABLES, | ||
embeddings) | ||
vars_to_warm_start = [embeddings] | ||
if use_regex: | ||
vars_to_warm_start = [".*t300.*"] | ||
|
||
warm_start_hook = de.WarmStartHook(ckpt_to_initialize_from=ckpt_prefix, | ||
vars_to_warm_start=vars_to_warm_start) | ||
return tf.estimator.EstimatorSpec(mode=tf.estimator.ModeKeys.PREDICT, | ||
predictions=emb, | ||
prediction_hooks=[warm_start_hook]) | ||
|
||
predictor = tf.estimator.Estimator(model_fn=_model_fn) | ||
predictions = predictor.predict(_input_fn) | ||
pred_vals = [] | ||
for pred in predictions: | ||
pred_vals.append(pred) | ||
self.assertAllEqual(pred_vals, val_list) | ||
|
||
def test_warm_start(self): | ||
for num_shards in [1, 3]: | ||
self._test_warm_start(num_shards, True) | ||
self._test_warm_start(num_shards, False) | ||
|
||
def test_warm_start_rename(self): | ||
for num_shards in [1, 3]: | ||
self._test_warm_start_rename(num_shards, True) | ||
self._test_warm_start_rename(num_shards, False) | ||
|
||
def test_warm_start_estimator(self): | ||
for num_shards in [1, 3]: | ||
self._test_warm_start_estimator(num_shards, True) | ||
self._test_warm_start_estimator(num_shards, False) | ||
|
||
|
||
if __name__ == "__main__": | ||
test.main() |
166 changes: 166 additions & 0 deletions
166
tensorflow_recommenders_addons/dynamic_embedding/python/ops/warm_start_util.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
# Copyright 2022 The TensorFlow Recommenders-Addons Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# lint-as: python3 | ||
"""warm-start util""" | ||
|
||
import collections | ||
import six | ||
import re | ||
|
||
from tensorflow.python.framework import errors | ||
from tensorflow.python.framework import ops | ||
from tensorflow.python.ops import control_flow_ops | ||
from tensorflow.python.ops import io_ops | ||
from tensorflow.python.ops import state_ops | ||
from tensorflow.python.ops import variables as variables_lib | ||
from tensorflow.python.ops import variable_scope | ||
from tensorflow.python.platform import tf_logging as logging | ||
from tensorflow.python.training import checkpoint_ops | ||
from tensorflow.python.training import checkpoint_utils | ||
from tensorflow.python.training import saver as saver_lib | ||
from tensorflow.python.training.saving import saveable_object_util | ||
from tensorflow.python.training.session_run_hook import SessionRunHook | ||
from tensorflow.python.util.tf_export import tf_export | ||
|
||
from tensorflow_recommenders_addons import dynamic_embedding as de | ||
|
||
|
||
def _get_de_variables(vars_to_warm_start): | ||
if isinstance(vars_to_warm_start, | ||
six.string_types) or vars_to_warm_start is None: | ||
list_of_vars = ops.get_collection(de.GraphKeys.DYNAMIC_EMBEDDING_VARIABLES, | ||
scope=vars_to_warm_start) | ||
elif isinstance(vars_to_warm_start, list): | ||
if all(isinstance(v, six.string_types) for v in vars_to_warm_start): | ||
list_of_vars = [] | ||
for v in vars_to_warm_start: | ||
list_of_vars += ops.get_collection( | ||
de.GraphKeys.DYNAMIC_EMBEDDING_VARIABLES, scope=v) | ||
elif all(isinstance(v, de.Variable) for v in vars_to_warm_start): | ||
list_of_vars = vars_to_warm_start | ||
else: | ||
raise ValueError("If `vars_to_warm_start` is a list, it must be a " | ||
"`de.Variable` or `str`. Given types are {}".format( | ||
type(vars_to_warm_start))) | ||
else: | ||
raise ValueError("`vars_to_warm_start must be a `list` or `str`. Given " | ||
"type is {}".format(type(vars_to_warm_start))) | ||
|
||
de_variables = [] | ||
for v in list_of_vars: | ||
t = [v] if not isinstance(v, list) else v | ||
de_variables.append(v) | ||
|
||
return de_variables | ||
|
||
|
||
def warm_start(ckpt_to_initialize_from, | ||
vars_to_warm_start=".*", | ||
var_name_to_prev_var_name=None): | ||
"""Warm-starts de.Variable using the given settings. | ||
Args: | ||
ckpt_to_initialize_from: [Required] A string specifying the directory with | ||
checkpoint file(s) or path to checkpoint from which to warm-start the | ||
model parameters. | ||
vars_to_warm_start: [Optional] One of the following: | ||
- A regular expression (string) that captures which variables to | ||
warm-start (see tf.compat.v1.get_collection). This expression will only | ||
consider variables in the TRAINABLE_VARIABLES collection -- if you need | ||
to warm-start non_TRAINABLE vars (such as optimizer accumulators or | ||
batch norm statistics), please use the below option. | ||
- A list of strings, each a regex scope provided to | ||
tf.compat.v1.get_collection with GLOBAL_VARIABLES (please see | ||
tf.compat.v1.get_collection). For backwards compatibility reasons, | ||
this is separate from the single-string argument type. | ||
- A list of Variables to warm-start. If you do not have access to the | ||
`Variable` objects at the call site, please use the above option. | ||
- `None`, in which case only TRAINABLE variables specified in | ||
`var_name_to_vocab_info` will be warm-started. | ||
Defaults to `'.*'`, which warm-starts all variables in the | ||
TRAINABLE_VARIABLES collection. Note that this excludes variables such | ||
as accumulators and moving statistics from batch norm. | ||
Raises: | ||
ValueError: If saveable's spec.name not match pattern | ||
defined by de.Variable._make_name. | ||
""" | ||
|
||
def _replace_var_in_spec_name(spec_name, var_name): | ||
|
||
def _replace(m): | ||
return '{}_mht_{}of{}'.format(var_name, m.groups()[1], m.groups()[2]) | ||
|
||
out = re.sub(r'(\w+)_mht_(\d+)of(\d+)', _replace, spec_name) | ||
if out is None: | ||
raise ValueError( | ||
"Invalid sepc name, should match `{}_mht_{}of{}`, given %s" % | ||
spec_name) | ||
return out | ||
|
||
logging.info("Warm-starting from: {}".format(ckpt_to_initialize_from)) | ||
|
||
de_variables = _get_de_variables(vars_to_warm_start) | ||
if not var_name_to_prev_var_name: | ||
var_name_to_prev_var_name = {} | ||
|
||
ckpt_file = checkpoint_utils._get_checkpoint_filename(ckpt_to_initialize_from) | ||
assign_ops = [] | ||
for variable in de_variables: | ||
var_name = variable.name | ||
prev_var_name = var_name_to_prev_var_name.get(var_name) | ||
if prev_var_name: | ||
logging.debug("Warm-start variable: {}: prev_var_name: {}".format( | ||
var_name, prev_var_name or "Unchanged")) | ||
else: | ||
prev_var_name = var_name | ||
|
||
saveables = saveable_object_util.validate_and_slice_inputs([variable]) | ||
for saveable in saveables: | ||
restore_specs = [] | ||
for spec in saveable.specs: | ||
restore_specs.append((_replace_var_in_spec_name(spec.name, | ||
prev_var_name), | ||
spec.slice_spec, spec.dtype)) | ||
|
||
names, slices, dtypes = zip(*restore_specs) | ||
# Load tensors in cuckoo_hashtable op's device | ||
with ops.colocate_with(saveable.op._resource_handle.op): | ||
saveable_tensors = io_ops.restore_v2(ckpt_file, names, slices, dtypes) | ||
assign_ops.append(saveable.restore(saveable_tensors, None)) | ||
|
||
return control_flow_ops.group(assign_ops) | ||
|
||
|
||
class WarmStartHook(SessionRunHook): | ||
"""Warm-start hook for tf.estimator.Estimator | ||
""" | ||
|
||
def __init__(self, | ||
ckpt_to_initialize_from, | ||
vars_to_warm_start, | ||
var_name_to_prev_var_name=None): | ||
self._ckpt_to_initialize_from = ckpt_to_initialize_from | ||
self._vars_to_warm_start = vars_to_warm_start | ||
self._var_name_to_prev_var_name = var_name_to_prev_var_name | ||
|
||
def begin(self): | ||
self._restore_op = warm_start( | ||
ckpt_to_initialize_from=self._ckpt_to_initialize_from, | ||
vars_to_warm_start=self._vars_to_warm_start, | ||
var_name_to_prev_var_name=self._var_name_to_prev_var_name) | ||
|
||
def after_create_session(self, session, coord): | ||
session.run(self._restore_op) |