diff --git a/tensorflow_recommenders_addons/dynamic_embedding/__init__.py b/tensorflow_recommenders_addons/dynamic_embedding/__init__.py index 4d6d34f5c..a03999c2c 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/__init__.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/__init__.py @@ -80,6 +80,8 @@ Variable,) from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_variable import ( GraphKeys,) +from tensorflow_recommenders_addons.dynamic_embedding.python.ops.warm_start_util import ( + warm_start, WarmStartHook) from tensorflow_recommenders_addons.dynamic_embedding.python.ops.restrict_policies import ( RestrictPolicy, TimestampRestrictPolicy, diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/warm_start_util_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/warm_start_util_test.py new file mode 100644 index 000000000..9b0e7ba3c --- /dev/null +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/warm_start_util_test.py @@ -0,0 +1,218 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests of warm-start util""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import itertools +import math +import numpy as np +import os +import shutil + +from tensorflow_recommenders_addons import dynamic_embedding as de + +try: + from tensorflow.python.keras.initializers import initializers_v2 as kinit2 +except ImportError: + kinit2 = None + pass # for compatible with TF < 2.3.x + +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util +from tensorflow.python.keras import initializers as keras_init_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resources +from tensorflow.python.ops import script_ops +from tensorflow.python.ops import variables +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test +from tensorflow.python.training import device_setter +from tensorflow.python.training import saver +from tensorflow.python.training import server_lib +from tensorflow.python.util import compat +from tensorflow_recommenders_addons import dynamic_embedding as de + +import tensorflow as tf + + +@test_util.deprecated_graph_mode_only +class WarmStartUtilTest(test.TestCase): + + def _test_warm_start(self, num_shards, use_regex): + devices = ["/cpu:0" for _ in range(num_shards)] + ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt") + id_list = [x for x in range(100)] + val_list = [[x] for x in range(100)] + + emb_name = "t100_{}_{}".format(num_shards, use_regex) + with self.session(graph=ops.Graph()) as sess: + embeddings = de.get_variable(emb_name, + dtypes.int64, + dtypes.float32, + devices=devices, + initializer=0.0) + ids = constant_op.constant(id_list, dtype=dtypes.int64) + vals = constant_op.constant(val_list, dtype=dtypes.float32) + self.evaluate(embeddings.upsert(ids, vals)) + save = saver.Saver(var_list=[embeddings]) + save.save(sess, ckpt_prefix) + + with self.session(graph=ops.Graph()) as sess: + embeddings = de.get_variable(emb_name, + dtypes.int64, + dtypes.float32, + devices=devices, + initializer=0.0) + ids = constant_op.constant(id_list, dtype=dtypes.int64) + emb = de.embedding_lookup(embeddings, ids, name="lookup") + sess.graph.add_to_collection(de.GraphKeys.DYNAMIC_EMBEDDING_VARIABLES, + embeddings) + vars_to_warm_start = [embeddings] + if use_regex: + vars_to_warm_start = [".*t100.*"] + + restore_op = de.warm_start(ckpt_to_initialize_from=ckpt_prefix, + vars_to_warm_start=vars_to_warm_start) + self.evaluate(restore_op) + self.assertAllEqual(emb, val_list) + + def _test_warm_start_rename(self, num_shards, use_regex): + devices = ["/cpu:0" for _ in range(num_shards)] + ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt") + id_list = [x for x in range(100)] + val_list = [[x] for x in range(100)] + + emb_name = "t200_{}_{}".format(num_shards, use_regex) + with self.session(graph=ops.Graph()) as sess: + embeddings = de.get_variable("save_{}".format(emb_name), + dtypes.int64, + dtypes.float32, + devices=devices, + initializer=0.0) + ids = constant_op.constant(id_list, dtype=dtypes.int64) + vals = constant_op.constant(val_list, dtype=dtypes.float32) + self.evaluate(embeddings.upsert(ids, vals)) + save = saver.Saver(var_list=[embeddings]) + save.save(sess, ckpt_prefix) + + with self.session(graph=ops.Graph()) as sess: + embeddings = de.get_variable("restore_{}".format(emb_name), + dtypes.int64, + dtypes.float32, + devices=devices, + initializer=0.0) + ids = constant_op.constant(id_list, dtype=dtypes.int64) + emb = de.embedding_lookup(embeddings, ids, name="lookup") + sess.graph.add_to_collection(de.GraphKeys.DYNAMIC_EMBEDDING_VARIABLES, + embeddings) + vars_to_warm_start = [embeddings] + if use_regex: + vars_to_warm_start = [".*t200.*"] + + restore_op = de.warm_start(ckpt_to_initialize_from=ckpt_prefix, + vars_to_warm_start=vars_to_warm_start, + var_name_to_prev_var_name={ + "restore_{}".format(emb_name): + "save_{}".format(emb_name) + }) + self.evaluate(restore_op) + self.assertAllEqual(emb, val_list) + + def _test_warm_start_estimator(self, num_shards, use_regex): + devices = ["/cpu:0" for _ in range(num_shards)] + ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt") + id_list = [x for x in range(100)] + val_list = [[x] for x in range(100)] + + emb_name = "t300_{}_{}".format(num_shards, use_regex) + with self.session(graph=ops.Graph()) as sess: + embeddings = de.get_variable(emb_name, + dtypes.int64, + dtypes.float32, + devices=devices, + initializer=0.0) + ids = constant_op.constant(id_list, dtype=dtypes.int64) + vals = constant_op.constant(val_list, dtype=dtypes.float32) + self.evaluate(embeddings.upsert(ids, vals)) + save = saver.Saver(var_list=[embeddings]) + save.save(sess, ckpt_prefix) + + def _input_fn(): + dataset = tf.data.Dataset.from_tensor_slices({ + 'ids': + constant_op.constant([[x] for x in id_list], dtype=dtypes.int64) + }) + return dataset + + def _model_fn(features, labels, mode, params): + ids = features['ids'] + embeddings = de.get_variable(emb_name, + dtypes.int64, + dtypes.float32, + devices=devices, + initializer=0.0) + emb = de.embedding_lookup(embeddings, ids, name="lookup") + emb.graph.add_to_collection(de.GraphKeys.DYNAMIC_EMBEDDING_VARIABLES, + embeddings) + vars_to_warm_start = [embeddings] + if use_regex: + vars_to_warm_start = [".*t300.*"] + + warm_start_hook = de.WarmStartHook(ckpt_to_initialize_from=ckpt_prefix, + vars_to_warm_start=vars_to_warm_start) + return tf.estimator.EstimatorSpec(mode=tf.estimator.ModeKeys.PREDICT, + predictions=emb, + prediction_hooks=[warm_start_hook]) + + predictor = tf.estimator.Estimator(model_fn=_model_fn) + predictions = predictor.predict(_input_fn) + pred_vals = [] + for pred in predictions: + pred_vals.append(pred) + self.assertAllEqual(pred_vals, val_list) + + def test_warm_start(self): + for num_shards in [1, 3]: + self._test_warm_start(num_shards, True) + self._test_warm_start(num_shards, False) + + def test_warm_start_rename(self): + for num_shards in [1, 3]: + self._test_warm_start_rename(num_shards, True) + self._test_warm_start_rename(num_shards, False) + + def test_warm_start_estimator(self): + for num_shards in [1, 3]: + self._test_warm_start_estimator(num_shards, True) + self._test_warm_start_estimator(num_shards, False) + + +if __name__ == "__main__": + test.main() \ No newline at end of file diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/warm_start_util.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/warm_start_util.py new file mode 100644 index 000000000..def8c79ec --- /dev/null +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/warm_start_util.py @@ -0,0 +1,194 @@ +# Copyright 2022 The TensorFlow Recommenders-Addons Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# lint-as: python3 +"""warm-start util""" + +import collections +import six +import re + +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpoint_ops +from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training.saving import saveable_object_util +from tensorflow.python.training.session_run_hook import SessionRunHook +from tensorflow.python.util.tf_export import tf_export + +from tensorflow_recommenders_addons import dynamic_embedding as de + + +def _get_de_variables(vars_to_warm_start): + if isinstance(vars_to_warm_start, + six.string_types) or vars_to_warm_start is None: + list_of_vars = ops.get_collection(de.GraphKeys.DYNAMIC_EMBEDDING_VARIABLES, + scope=vars_to_warm_start) + elif isinstance(vars_to_warm_start, list): + if all(isinstance(v, six.string_types) for v in vars_to_warm_start): + list_of_vars = [] + for v in vars_to_warm_start: + list_of_vars += ops.get_collection( + de.GraphKeys.DYNAMIC_EMBEDDING_VARIABLES, scope=v) + elif all(isinstance(v, de.Variable) for v in vars_to_warm_start): + list_of_vars = vars_to_warm_start + else: + raise ValueError("If `vars_to_warm_start` is a list, it must be a " + "`de.Variable` or `str`. Given types are {}".format( + type(vars_to_warm_start))) + else: + raise ValueError("`vars_to_warm_start must be a `list` or `str`. Given " + "type is {}".format(type(vars_to_warm_start))) + + de_variables = [] + for v in list_of_vars: + t = [v] if not isinstance(v, list) else v + de_variables.append(v) + + return de_variables + + +def warm_start(ckpt_to_initialize_from, + vars_to_warm_start=".*", + var_name_to_prev_var_name=None): + """Warm-starts de.Variable using the given settings. + + Args: + ckpt_to_initialize_from: [Required] A string specifying the directory with + checkpoint file(s) or path to checkpoint from which to warm-start the + model parameters. + vars_to_warm_start: [Optional] One of the following: + - A regular expression (string) that captures which variables to + warm-start (see tf.compat.v1.get_collection). This expression will only + consider variables in the TRAINABLE_VARIABLES collection -- if you need + to warm-start non_TRAINABLE vars (such as optimizer accumulators or + batch norm statistics), please use the below option. + - A list of strings, each a regex scope provided to + tf.compat.v1.get_collection with GLOBAL_VARIABLES (please see + tf.compat.v1.get_collection). For backwards compatibility reasons, + this is separate from the single-string argument type. + - A list of Variables to warm-start. If you do not have access to the + `Variable` objects at the call site, please use the above option. + - `None`, in which case only TRAINABLE variables specified in + `var_name_to_vocab_info` will be warm-started. + Defaults to `'.*'`, which warm-starts all variables in the + TRAINABLE_VARIABLES collection. Note that this excludes variables such + as accumulators and moving statistics from batch norm. + + Raises: + ValueError: If saveable's spec.name not match pattern + defined by de.Variable._make_name. + """ + + def _replace_var_in_spec_name(spec_name, var_name): + + def _replace(m): + return '{}_mht_{}of{}'.format(var_name, m.groups()[1], m.groups()[2]) + + out = re.sub(r'(\w+)_mht_(\d+)of(\d+)', _replace, spec_name) + if out is None: + raise ValueError( + "Invalid sepc name, should match `{}_mht_{}of{}`, given %s" % + spec_name) + return out + + logging.info("Warm-starting from: {}".format(ckpt_to_initialize_from)) + + de_variables = _get_de_variables(vars_to_warm_start) + if not var_name_to_prev_var_name: + var_name_to_prev_var_name = {} + + ckpt_file = checkpoint_utils._get_checkpoint_filename(ckpt_to_initialize_from) + assign_ops = [] + for variable in de_variables: + var_name = variable.name + prev_var_name = var_name_to_prev_var_name.get(var_name) + if prev_var_name: + logging.debug("Warm-start variable: {}: prev_var_name: {}".format( + var_name, prev_var_name or "Unchanged")) + else: + prev_var_name = var_name + + saveables = saveable_object_util.validate_and_slice_inputs([variable]) + for saveable in saveables: + restore_specs = [] + for spec in saveable.specs: + restore_specs.append((_replace_var_in_spec_name(spec.name, + prev_var_name), + spec.slice_spec, spec.dtype)) + + names, slices, dtypes = zip(*restore_specs) + # Load tensors in cuckoo_hashtable op's device + with ops.colocate_with(saveable.op._resource_handle.op): + saveable_tensors = io_ops.restore_v2(ckpt_file, names, slices, dtypes) + assign_ops.append(saveable.restore(saveable_tensors, None)) + + return control_flow_ops.group(assign_ops) + + +class WarmStartHook(SessionRunHook): + """Warm-start hook for tf.estimator.Estimator + """ + + def __init__(self, + ckpt_to_initialize_from, + vars_to_warm_start, + var_name_to_prev_var_name=None): + """Initializes a `WarmStartHook` + + Args: + ckpt_to_initialize_from: [Required] A string specifying the directory with + checkpoint file(s) or path to checkpoint from which to warm-start the + model parameters. + vars_to_warm_start: [Optional] One of the following: + - A regular expression (string) that captures which variables to + warm-start (see tf.compat.v1.get_collection). This expression will only + consider variables in the TRAINABLE_VARIABLES collection -- if you need + to warm-start non_TRAINABLE vars (such as optimizer accumulators or + batch norm statistics), please use the below option. + - A list of strings, each a regex scope provided to + tf.compat.v1.get_collection with GLOBAL_VARIABLES (please see + tf.compat.v1.get_collection). For backwards compatibility reasons, + this is separate from the single-string argument type. + - A list of Variables to warm-start. If you do not have access to the + `Variable` objects at the call site, please use the above option. + - `None`, in which case only TRAINABLE variables specified in + `var_name_to_vocab_info` will be warm-started. + Defaults to `'.*'`, which warm-starts all variables in the + TRAINABLE_VARIABLES collection. Note that this excludes variables such + as accumulators and moving statistics from batch norm. + + Raises: + ValueError: If saveable's spec.name not match pattern + defined by de.Variable._make_name. + """ + self._ckpt_to_initialize_from = ckpt_to_initialize_from + self._vars_to_warm_start = vars_to_warm_start + self._var_name_to_prev_var_name = var_name_to_prev_var_name + + def begin(self): + self._restore_op = warm_start( + ckpt_to_initialize_from=self._ckpt_to_initialize_from, + vars_to_warm_start=self._vars_to_warm_start, + var_name_to_prev_var_name=self._var_name_to_prev_var_name) + + def after_create_session(self, session, coord): + session.run(self._restore_op) \ No newline at end of file