diff --git a/demo/bpv2/README.md b/demo/bpv2/README.md new file mode 100644 index 000000000..1e09656cd --- /dev/null +++ b/demo/bpv2/README.md @@ -0,0 +1,6 @@ +# A simple training demo for `bp_v2` + +## Start training: +``` +sh train.sh +``` diff --git a/demo/bpv2/main.py b/demo/bpv2/main.py new file mode 100644 index 000000000..b5d2a2d89 --- /dev/null +++ b/demo/bpv2/main.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import os +import numpy as np +import tensorflow as tf +import tensorflow_recommenders_addons as tfra + +batch_size = 128 +vocab_size = 10000 +embed_size = 64 + + +def make_word_index(): + word_index = tf.keras.datasets.imdb.get_word_index() + word_index = {k: (v + 3) for k, v in word_index.items()} + word_index["<PAD>"] = 0 + word_index["<START>"] = 1 + word_index["<UNK>"] = 2 # unknown + word_index["<UNUSED>"] = 3 + reverse_word_index = dict([ + (value, key) for (key, value) in word_index.items() + ]) + return word_index, reverse_word_index + + +def decode_review(text): + return ' '.join([reverse_word_index.get(i, '?') for i in text]) + + +word_index, reverse_word_index = make_word_index() + + +def get_data(): + (train_data, train_labels), (test_data, + test_labels) = tf.keras.datasets.imdb.load_data( + num_words=10000, path='imdb-0') + train_data = tf.keras.preprocessing.sequence.pad_sequences( + train_data, value=word_index["<PAD>"], padding='post', maxlen=256) + test_data = tf.keras.preprocessing.sequence.pad_sequences( + test_data, value=word_index["<PAD>"], padding='post', maxlen=256) + x_val = train_data[:1024] + x_train = train_data[:1024] + y_val = train_labels[:1024] + y_train = train_labels[:1024] + return x_train, y_train, x_val, y_val + + +x_train, y_train, x_val, y_val = get_data() + + +def input_fn_train(): + dataset = tf.data.Dataset.from_tensor_slices(({'x': x_train}, y_train)) + dataset = dataset.shuffle(1000).repeat().batch(batch_size) + return dataset + + +def input_fn_val(): + dataset = tf.data.Dataset.from_tensor_slices(({'x': x_val}, y_val)) + dataset = dataset.shuffle(1000).repeat().batch(batch_size) + return dataset + + +def model_fn(features, labels, mode): + x = features['x'] + x = tf.reshape(x, [-1]) + uniqx, uniqxidx = tf.unique(x) + w = tfra.dynamic_embedding.get_variable( + name='w', + initializer=tf.random_normal_initializer(0, 0.5), + dim=embed_size, + bp_v2=True, # this is the only change needed to enable bp_v2 + key_dtype=tf.int32) + uniqe = tfra.dynamic_embedding.embedding_lookup(params=w, ids=uniqx, name='a') + e = tf.gather(uniqe, uniqxidx) + e = tf.reshape(e, [-1, 256, embed_size]) + + embmean = tf.reduce_mean(e, axis=1) + fc1 = tf.layers.dense(embmean, 16, activation=tf.nn.relu) + logits = tf.layers.dense(fc1, 2, activation=None) + predictions = { + "classes": tf.argmax(input=logits, axis=1), + "probabilities": tf.nn.softmax(logits, name="softmax_tensor"), + } + + y = tf.one_hot(tf.cast(labels, tf.int32), 2, 1, 0) + loss = tf.losses.softmax_cross_entropy(y, logits) + if mode == tf.estimator.ModeKeys.TRAIN: + opt = tf.compat.v1.train.AdamOptimizer(0.01) + opt = tfra.dynamic_embedding.DynamicEmbeddingOptimizer(opt) + global_step = tf.compat.v1.train.get_or_create_global_step() + with tf.compat.v1.control_dependencies([ + tf.print('step', global_step, 'loss', loss), + ]): + train_op = opt.minimize(loss, global_step=global_step) + return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op) + else: + eval_metric_ops = { + "accuracy": + tf.metrics.accuracy(labels=labels, + predictions=predictions["classes"]) + } + return
tf.estimator.EstimatorSpec(mode=mode, + loss=loss, + eval_metric_ops=eval_metric_ops) + + +config = tf.estimator.RunConfig(save_checkpoints_steps=None, + save_checkpoints_secs=tf.int64.max, + model_dir=None, + log_step_count_steps=1) +classifier = tf.estimator.Estimator(model_fn=model_fn, config=config) + +tf.estimator.train_and_evaluate( + classifier, + train_spec=tf.estimator.TrainSpec(input_fn=input_fn_train, max_steps=3000), + eval_spec=tf.estimator.EvalSpec(input_fn=input_fn_val, steps=1000)) diff --git a/demo/bpv2/train.sh b/demo/bpv2/train.sh new file mode 100644 index 000000000..1fa797ce9 --- /dev/null +++ b/demo/bpv2/train.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash + +python main.py diff --git a/docs/api_docs/tfra/dynamic_embedding/Variable.md b/docs/api_docs/tfra/dynamic_embedding/Variable.md index 9b18e70e7..48f258fef 100644 --- a/docs/api_docs/tfra/dynamic_embedding/Variable.md +++ b/docs/api_docs/tfra/dynamic_embedding/Variable.md @@ -58,7 +58,8 @@ __init__( trainable=True, checkpoint=True, init_size=0, - restrict_policy=None + restrict_policy=None, + bp_v2=False, ) ``` @@ -105,6 +106,9 @@ def default_partition_fn(keys, shard_num): size of variable. If in training program, the variable is updated by optimizer, then the sparse slot variables in optimizer are also be restricted. +* `bp_v2`: update parameters by *adding deltas* instead of *setting* them, which + solves the race condition problem among workers during backpropagation in + large-scale distributed asynchronous training. #### Returns: diff --git a/docs/api_docs/tfra/dynamic_embedding/get_variable.md b/docs/api_docs/tfra/dynamic_embedding/get_variable.md index 58cb16e49..2a9e120fb 100644 --- a/docs/api_docs/tfra/dynamic_embedding/get_variable.md +++ b/docs/api_docs/tfra/dynamic_embedding/get_variable.md @@ -37,7 +37,8 @@ tfra.dynamic_embedding.get_variable( trainable=True, checkpoint=True, init_size=0, - restrict_policy=None + restrict_policy=None, + bp_v2=False, ) ``` @@ -80,6 +81,9 @@ def default_partition_fn(keys, shard_num): size of variable. If in training program, the variable is updated by optimizer, then the sparse slot variables in optimizer are also be restricted. +* `bp_v2`: update parameters by *adding deltas* instead of *setting* them, which + solves the race condition problem among workers during backpropagation in + large-scale distributed asynchronous training.
#### Returns: diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_optimizer_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_optimizer_test.py index 6b1820d79..080ad1b20 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_optimizer_test.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_optimizer_test.py @@ -102,7 +102,7 @@ def _test_dir(temp_dir, test_name): class CommonTrainableTestV1Base(object): - def common_minimize_trainable(self, base_opt, test_opt, name): + def common_minimize_trainable(self, base_opt, test_opt, name, bp_v2): raise NotImplementedError def device_check(self, de): @@ -113,13 +113,19 @@ def device_check(self, de): def test_adadelta_minimize_trainable(self): base_opt = adadelta.AdadeltaOptimizer(1.0) test_opt = adadelta.AdadeltaOptimizer(1.0) - self.common_minimize_trainable(base_opt, test_opt, name="adadelta") + self.common_minimize_trainable(base_opt, + test_opt, + name="adadelta", + bp_v2=False) @test_util.deprecated_graph_mode_only def test_adagrad_minimize_trainable(self): base_opt = adagrad.AdagradOptimizer(1.0) test_opt = adagrad.AdagradOptimizer(1.0) - self.common_minimize_trainable(base_opt, test_opt, name="adagrad") + self.common_minimize_trainable(base_opt, + test_opt, + name="adagrad", + bp_v2=False) @test_util.deprecated_graph_mode_only def test_adagradda_minimize_trainable(self): @@ -127,43 +133,55 @@ def test_adagradda_minimize_trainable(self): base_opt = adagrad_da.AdagradDAOptimizer(1.0, base_gs) test_opt = adagrad_da.AdagradDAOptimizer(1.0, base_gs) - self.common_minimize_trainable(base_opt, test_opt, name="adagrad_da") + self.common_minimize_trainable(base_opt, + test_opt, + name="adagrad_da", + bp_v2=False) @test_util.deprecated_graph_mode_only def test_ftrl_minimize_trainable(self): base_opt = ftrl.FtrlOptimizer(1.0) test_opt = ftrl.FtrlOptimizer(1.0) - self.common_minimize_trainable(base_opt, test_opt, name="ftrl") + self.common_minimize_trainable(base_opt, test_opt, name="ftrl", bp_v2=False) @test_util.deprecated_graph_mode_only def test_proximal_adagrad_minimize_trainable(self): base_opt = proximal_adagrad.ProximalAdagradOptimizer(1.0) test_opt = proximal_adagrad.ProximalAdagradOptimizer(1.0) - self.common_minimize_trainable(base_opt, test_opt, name="proximal_adagrad") + self.common_minimize_trainable(base_opt, + test_opt, + name="proximal_adagrad", + bp_v2=False) @test_util.deprecated_graph_mode_only def test_proximalsgd_minimize_trainable(self): base_opt = pgd.ProximalGradientDescentOptimizer(1.0) test_opt = pgd.ProximalGradientDescentOptimizer(1.0) - self.common_minimize_trainable(base_opt, test_opt, name="proximal_sgd") + self.common_minimize_trainable(base_opt, + test_opt, + name="proximal_sgd", + bp_v2=False) @test_util.deprecated_graph_mode_only def test_momentum_minimize_trainable(self): base_opt = momentum.MomentumOptimizer(1.0, momentum=0.9) test_opt = momentum.MomentumOptimizer(1.0, momentum=0.9) - self.common_minimize_trainable(base_opt, test_opt, name="momentum") + self.common_minimize_trainable(base_opt, + test_opt, + name="momentum", + bp_v2=False) @test_util.deprecated_graph_mode_only def test_sgd_minimize_trainable(self): base_opt = gradient_descent.GradientDescentOptimizer(1.0) test_opt = gradient_descent.GradientDescentOptimizer(1.0) - self.common_minimize_trainable(base_opt, test_opt, name="sgd") + self.common_minimize_trainable(base_opt, 
test_opt, name="sgd", bp_v2=False) @test_util.deprecated_graph_mode_only def test_adam_minimize_trainable(self): base_opt = adam.AdamOptimizer(1.0) test_opt = adam.AdamOptimizer(1.0) - self.common_minimize_trainable(base_opt, test_opt, name="adam") + self.common_minimize_trainable(base_opt, test_opt, name="adam", bp_v2=False) @test_util.deprecated_graph_mode_only def test_rmsprop_minimize_trainable(self): @@ -172,7 +190,92 @@ def test_rmsprop_minimize_trainable(self): test_opt = rmsprop.RMSPropOptimizer(1.0, centered=centered_) self.common_minimize_trainable(base_opt, test_opt, - name="rmsprop" + str(centered_)) + name="rmsprop" + str(centered_), + bp_v2=False) + + @test_util.deprecated_graph_mode_only + def test_adadelta_minimize_trainable_bpv2(self): + base_opt = adadelta.AdadeltaOptimizer(1.0) + test_opt = adadelta.AdadeltaOptimizer(1.0) + self.common_minimize_trainable(base_opt, + test_opt, + name="adadelta", + bp_v2=True) + + @test_util.deprecated_graph_mode_only + def test_adagrad_minimize_trainable_bpv2(self): + base_opt = adagrad.AdagradOptimizer(1.0) + test_opt = adagrad.AdagradOptimizer(1.0) + self.common_minimize_trainable(base_opt, + test_opt, + name="adagrad", + bp_v2=True) + + @test_util.deprecated_graph_mode_only + def test_adagradda_minimize_trainable_bpv2(self): + base_gs = training_util.create_global_step() + + base_opt = adagrad_da.AdagradDAOptimizer(1.0, base_gs) + test_opt = adagrad_da.AdagradDAOptimizer(1.0, base_gs) + self.common_minimize_trainable(base_opt, + test_opt, + name="adagrad_da", + bp_v2=True) + + @test_util.deprecated_graph_mode_only + def test_ftrl_minimize_trainable_bpv2(self): + base_opt = ftrl.FtrlOptimizer(1.0) + test_opt = ftrl.FtrlOptimizer(1.0) + self.common_minimize_trainable(base_opt, test_opt, name="ftrl", bp_v2=True) + + @test_util.deprecated_graph_mode_only + def test_proximal_adagrad_minimize_trainable_bpv2(self): + base_opt = proximal_adagrad.ProximalAdagradOptimizer(1.0) + test_opt = proximal_adagrad.ProximalAdagradOptimizer(1.0) + self.common_minimize_trainable(base_opt, + test_opt, + name="proximal_adagrad", + bp_v2=True) + + @test_util.deprecated_graph_mode_only + def test_proximalsgd_minimize_trainable_bpv2(self): + base_opt = pgd.ProximalGradientDescentOptimizer(1.0) + test_opt = pgd.ProximalGradientDescentOptimizer(1.0) + self.common_minimize_trainable(base_opt, + test_opt, + name="proximal_sgd", + bp_v2=True) + + @test_util.deprecated_graph_mode_only + def test_momentum_minimize_trainable_bpv2(self): + base_opt = momentum.MomentumOptimizer(1.0, momentum=0.9) + test_opt = momentum.MomentumOptimizer(1.0, momentum=0.9) + self.common_minimize_trainable(base_opt, + test_opt, + name="momentum", + bp_v2=True) + + @test_util.deprecated_graph_mode_only + def test_sgd_minimize_trainable_bpv2(self): + base_opt = gradient_descent.GradientDescentOptimizer(1.0) + test_opt = gradient_descent.GradientDescentOptimizer(1.0) + self.common_minimize_trainable(base_opt, test_opt, name="sgd", bp_v2=True) + + @test_util.deprecated_graph_mode_only + def test_adam_minimize_trainable_bpv2(self): + base_opt = adam.AdamOptimizer(1.0) + test_opt = adam.AdamOptimizer(1.0) + self.common_minimize_trainable(base_opt, test_opt, name="adam", bp_v2=True) + + @test_util.deprecated_graph_mode_only + def test_rmsprop_minimize_trainable_bpv2(self): + for centered_ in [False, True]: + base_opt = rmsprop.RMSPropOptimizer(1.0, centered=centered_) + test_opt = rmsprop.RMSPropOptimizer(1.0, centered=centered_) + self.common_minimize_trainable(base_opt, + test_opt, + 
name="rmsprop" + str(centered_), + bp_v2=True) class CommonTrainableTestV2Base(object): @@ -243,7 +346,7 @@ def test_rmsprop_v2_minimize_trainable(self): class EmbeddingLookupTrainableV1Test(test.TestCase, CommonTrainableTestV1Base): - def common_minimize_trainable(self, base_opt, test_opt, name): + def common_minimize_trainable(self, base_opt, test_opt, name, bp_v2): de.enable_train_mode() base_opt = de.DynamicEmbeddingOptimizer(base_opt) test_opt = de.DynamicEmbeddingOptimizer(test_opt) @@ -298,6 +401,7 @@ def common_minimize_trainable(self, base_opt, test_opt, name): devices=_get_devices() * num_shards, initializer=1.0, dim=dim, + bp_v2=bp_v2, ) self.device_check(embeddings) init_ids = constant_op.constant(raw_init_ids, dtype=k_dtype) @@ -436,7 +540,7 @@ def loss_fn(x, trainables): class EmbeddingLookupUniqueTrainableV1Test(test.TestCase, CommonTrainableTestV1Base): - def common_minimize_trainable(self, base_opt, test_opt, name): + def common_minimize_trainable(self, base_opt, test_opt, name, bp_v2): de.enable_train_mode() base_opt = de.DynamicEmbeddingOptimizer(base_opt) test_opt = de.DynamicEmbeddingOptimizer(test_opt) @@ -632,7 +736,7 @@ def loss_fn(x, trainables): class EmbeddingLookupSparseTrainableV1Test(test.TestCase, CommonTrainableTestV1Base): - def common_minimize_trainable(self, base_opt, test_opt, name): + def common_minimize_trainable(self, base_opt, test_opt, name, bp_v2): de.enable_train_mode() base_opt = de.DynamicEmbeddingOptimizer(base_opt) test_opt = de.DynamicEmbeddingOptimizer(test_opt) @@ -889,7 +993,7 @@ class SafeEmbeddingLookupSparseTrainableV1Test(test.TestCase, CommonTrainableTestV1Base): @test_util.deprecated_graph_mode_only - def common_minimize_trainable(self, base_opt, test_opt, name): + def common_minimize_trainable(self, base_opt, test_opt, name, bp_v2): de.enable_train_mode() base_opt = de.DynamicEmbeddingOptimizer(base_opt) test_opt = de.DynamicEmbeddingOptimizer(test_opt) @@ -1190,7 +1294,7 @@ def test_saving_restoring_checkpoint(self): self.device_check(table) - def common_minimize_trainable(self, base_opt, test_opt, name): + def common_minimize_trainable(self, base_opt, test_opt, name, bp_v2): de.enable_train_mode() base_opt = de.DynamicEmbeddingOptimizer(base_opt) test_opt = de.DynamicEmbeddingOptimizer(test_opt) @@ -1318,12 +1422,28 @@ def common_minimize_trainable(self, base_opt, test_opt, name): def test_adam_minimize_trainable(self): base_opt = adam.AdamOptimizer(0.1) test_opt = adam.AdamOptimizer(0.1) - self.common_minimize_trainable(base_opt, test_opt, name="adam") + self.common_minimize_trainable(base_opt, test_opt, name="adam", bp_v2=False) def test_adagrad_minimize_trainable(self): base_opt = adagrad.AdagradOptimizer(0.1) test_opt = adagrad.AdagradOptimizer(0.1) - self.common_minimize_trainable(base_opt, test_opt, name="adagrad") + self.common_minimize_trainable(base_opt, + test_opt, + name="adagrad", + bp_v2=False) + + def test_adam_minimize_trainable_bp_v2(self): + base_opt = adam.AdamOptimizer(0.1) + test_opt = adam.AdamOptimizer(0.1) + self.common_minimize_trainable(base_opt, test_opt, name="adam", bp_v2=True) + + def test_adagrad_minimize_trainable_bp_v2(self): + base_opt = adagrad.AdagradOptimizer(0.1) + test_opt = adagrad.AdagradOptimizer(0.1) + self.common_minimize_trainable(base_opt, + test_opt, + name="adagrad", + bp_v2=True) @test_util.deprecated_graph_mode_only diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_ops.py 
b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_ops.py index 9f2a09914..a373b277c 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_ops.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_ops.py @@ -77,6 +77,7 @@ def __init__(self, params, ids, max_norm, *args, **kwargs): """ self.params = params self.ids = ids + self.exists = None self.max_norm = max_norm self.prefetch_values_op = None self.model_mode = kwargs.get("model_mode") @@ -85,7 +86,11 @@ def __init__(self, params, ids, max_norm, *args, **kwargs): def prefetch_values(self): if self.prefetch_values_op is None: - self.prefetch_values_op = self.transform(self.params.lookup(self.ids)) + if self.params.bp_v2: + r, self.exists = self.params.lookup(self.ids, return_exists=True) + self.prefetch_values_op = self.transform(r) + else: + self.prefetch_values_op = self.transform(self.params.lookup(self.ids)) return self.prefetch_values_op def _init_from_args( @@ -340,8 +345,13 @@ def _init_from_args( cached_value=cached_value, ) - def update_op(self): - update_param_op = self.params.upsert(self.ids, self.read_value(False)) + def update_op(self, v0=None): + v1 = self.read_value(False) + if self.params.bp_v2: + assert v0 is not None + update_param_op = self.params.accum(self.ids, v0, v1, self.exists) + else: + update_param_op = self.params.upsert(self.ids, v1) if self.params.restrict_policy is not None: update_status_op = self.params.restrict_policy.apply_update(self.ids) return control_flow_ops.group([update_param_op, update_status_op]) diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py index 48a111b92..c7482c268 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py @@ -95,7 +95,10 @@ def apply_grad_to_update_var(var, grad): var.params.restrict_policy._track_optimizer_slots(_slots) with ops.control_dependencies([grad]): - _before = [var.read_value()] + [_s.read_value() for _s in _slots] + v0 = var.read_value(do_prefetch=not var.params.bp_v2) + s0 = [_s.read_value() for _s in _slots] + _before = [v0] + s0 + if isinstance(grad, ops.IndexedSlices): if var.constraint is not None: raise RuntimeError( @@ -106,8 +109,9 @@ def apply_grad_to_update_var(var, grad): _apply_op = self._resource_apply_sparse_duplicate_indices( grad.values, var, grad.indices, **apply_kwargs) with ops.control_dependencies([_apply_op]): - _after = control_flow_ops.group([var.update_op()] + - [_s.update_op() for _s in _slots]) + _after = control_flow_ops.group( + [var.update_op(v0=v0)] + + [_s.update_op(v0=s0[si]) for si, _s in enumerate(_slots)]) return _after if "apply_state" in self._dense_apply_args: @@ -119,8 +123,9 @@ def apply_grad_to_update_var(var, grad): return var.assign(var.constraint(var)) else: with ops.control_dependencies([update_op]): - _after = control_flow_ops.group([var.update_op()] + - [_s.update_op() for _s in _slots]) + _after = control_flow_ops.group( + [var.update_op(v0=v0)] + + [_s.update_op(v0=s0[si]) for si, _s in enumerate(_slots)]) return _after update_ops = [] @@ -300,6 +305,7 @@ def create_slots(primary, init, slot_name, op_name): init_size=params_var_.init_size, trainable=False, checkpoint=params_var_.checkpoint, + bp_v2=params_var_.bp_v2, ) 
scope_store._vars[full_name] = slot_variable_ diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py index f3ab4484a..05b5e9a01 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py @@ -151,6 +151,7 @@ def __init__( checkpoint=True, init_size=0, restrict_policy=None, + bp_v2=False, ): """Creates an empty `Variable` object. @@ -196,6 +197,14 @@ def default_partition_fn(keys, shard_num): size of variable. If in training program, the variable is updated by optimizer, then the sparse slot variables in optimizer are also be restricted. + bp_v2: By default with `bp_v2=False`, the optimizer updates dynamic + embedding values by *setting* (key, value) pairs after + `optimizer.apply_gradients`. If one key is used by multiple workers + at the same time, only one of their updates is kept and the others + are overwritten. With `bp_v2=True`, the optimizer updates parameters + by *adding deltas* instead of *setting* them, which solves the race + condition problem among workers during backpropagation in large-scale + distributed asynchronous training. Returns: A `Variable` object. @@ -203,6 +212,7 @@ def default_partition_fn(keys, shard_num): self.key_dtype = key_dtype self.value_dtype = value_dtype self.dim = dim + self.bp_v2 = bp_v2 def _get_default_devices(): gpu_list = [ @@ -592,6 +602,7 @@ def get_variable( checkpoint=True, init_size=0, restrict_policy=None, + bp_v2=False, ): """Gets an `Variable` object with this name if it exists, or create a new one. @@ -628,6 +639,14 @@ def default_partition_fn(keys, shard_num): size of variable. If in training program, the variable is updated by optimizer, then the sparse slot variables in optimizer are also be restricted. + bp_v2: By default with `bp_v2=False`, the optimizer updates dynamic + embedding values by *setting* (key, value) pairs after + `optimizer.apply_gradients`. If one key is used by multiple workers + at the same time, only one of their updates is kept and the others + are overwritten. With `bp_v2=True`, the optimizer updates parameters + by *adding deltas* instead of *setting* them, which solves the race + condition problem among workers during backpropagation in large-scale + distributed asynchronous training. Returns: A `Variable` object.
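For a quick reference alongside the API docs above, a minimal usage sketch of the new flag (illustrative only: the variable name, dtypes, dimension, and ids below are made up, and the snippet assumes the same TF1-style graph-mode setup as the `demo/bpv2` example):

```
import tensorflow as tf
import tensorflow_recommenders_addons as tfra

# A dynamic embedding table whose training updates are applied as deltas
# (bp_v2=True) rather than overwriting the stored values after apply_gradients.
weights = tfra.dynamic_embedding.get_variable(
    name="sparse_weights",  # illustrative name
    key_dtype=tf.int64,
    value_dtype=tf.float32,
    dim=8,
    initializer=tf.random_normal_initializer(0, 0.5),
    bp_v2=True)

ids = tf.constant([1, 7, 7, 42], dtype=tf.int64)  # illustrative ids
emb = tfra.dynamic_embedding.embedding_lookup(params=weights, ids=ids)

# The optimizer must still be wrapped to train dynamic embedding variables.
opt = tfra.dynamic_embedding.DynamicEmbeddingOptimizer(
    tf.compat.v1.train.AdamOptimizer(0.01))
```

Everything other than `bp_v2=True` is unchanged from a regular dynamic embedding setup; the flag only switches how the wrapped optimizer writes trained values back into the table.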
@@ -657,6 +676,7 @@ def default_partition_fn(keys, shard_num): checkpoint=checkpoint, init_size=init_size, restrict_policy=restrict_policy, + bp_v2=bp_v2, ) scope_store._vars[full_name] = var_ return scope_store._vars[full_name] diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py index 97bb9f827..a33676ebd 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py @@ -27,8 +27,8 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import resource_variable_ops as rvo -from tensorflow.python.ops import variables from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables from tensorflow.python.training import device_setter from tensorflow.python.training import optimizer from tensorflow.python.training import slot_creator @@ -56,7 +56,10 @@ def update_op(self, optimizer, g): self._v.params.restrict_policy._track_optimizer_slots(_slots) with ops.control_dependencies([g]): - _before = [self._v.read_value()] + [_s.read_value() for _s in _slots] + v0 = self._v.read_value(do_prefetch=not self._v.params.bp_v2) + s0 = [_s.read_value() for _s in _slots] + _before = [v0] + s0 + if isinstance(g, ops.IndexedSlices): if self._v.constraint is not None: raise RuntimeError( @@ -66,9 +69,11 @@ def update_op(self, optimizer, g): _apply_op = optimizer._resource_apply_sparse_duplicate_indices( g.values, self._v, g.indices) with ops.control_dependencies([_apply_op]): - _after = control_flow_ops.group([self._v.update_op()] + - [_s.update_op() for _s in _slots]) + _after = control_flow_ops.group( + [self._v.update_op(v0=v0)] + + [_s.update_op(v0=s0[si]) for si, _s in enumerate(_slots)]) return _after + with ops.control_dependencies(_before): _apply_op = optimizer._resource_apply_dense(g, self._v) if self._v.constraint is not None: @@ -76,8 +81,9 @@ def update_op(self, optimizer, g): return self._v.assign(self._v.constraint(self._v)) else: with ops.control_dependencies([_apply_op]): - _after = control_flow_ops.group([self._v.update_op()] + - [_s.update_op() for _s in _slots]) + _after = control_flow_ops.group( + [self._v.update_op(v0=v0)] + + [_s.update_op(v0=s0[si]) for si, _s in enumerate(_slots)]) return _after
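To make the semantics of the `accum` path above concrete, here is a toy, pure-Python sketch (not the TFRA kernel; the dict-based table and the worker values are invented for illustration) of why writing the delta `v1 - v0` composes across asynchronous workers while a plain overwrite does not:

```
# Toy parameter-server table holding one embedding value per key.
table = {"key": 1.0}

def upsert(key, v1):
  # bp_v2=False: the worker overwrites the value with its locally computed result.
  table[key] = v1

def accum(key, v0, v1):
  # bp_v2=True: the worker only contributes its own delta (v1 - v0),
  # applied on top of whatever value the table currently holds.
  table[key] += v1 - v0

# Two workers read the same stale value v0 = 1.0 and each compute a
# gradient step of -0.1, i.e. v1 = 0.9.
v0_a, v1_a = 1.0, 0.9
v0_b, v1_b = 1.0, 0.9

upsert("key", v1_a)
upsert("key", v1_b)
print(table["key"])  # 0.9 -- one worker's step is lost (the race condition)

table["key"] = 1.0   # reset, then apply the same two updates as deltas
accum("key", v0_a, v1_a)
accum("key", v0_b, v1_b)
print(table["key"])  # 0.8 -- both workers' steps are preserved
```

The real op additionally receives the `exists` mask returned by `lookup(..., return_exists=True)`, which presumably lets the kernel handle keys that were missing at lookup time instead of accumulating onto a stale baseline.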