From 6872e625f344eb9a7f4c419a3a8a6417e6844d9a Mon Sep 17 00:00:00 2001
From: a6802739
Date: Wed, 29 Dec 2021 11:55:08 +0800
Subject: [PATCH] support horovod sync train

---
 .../README.md                                 |  19 ++
 .../export.sh                                 |  13 ++
 .../movielens-100k-estimator.py               | 215 ++++++++++++++++++
 .../session_hook.py                           |  42 ++++
 .../start_worker.sh                           |   5 +
 .../stop.sh                                   |   4 +
 .../train.sh                                  |  12 +
 .../kernel_tests/horovod_sync_train_test.py   | 105 +++++++++
 .../python/ops/dynamic_embedding_optimizer.py |  78 ++++++-
 tools/testing/build_and_run_tests.sh          |   3 +
 10 files changed, 493 insertions(+), 3 deletions(-)
 create mode 100644 demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/README.md
 create mode 100755 demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/export.sh
 create mode 100644 demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/movielens-100k-estimator.py
 create mode 100644 demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/session_hook.py
 create mode 100644 demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/start_worker.sh
 create mode 100644 demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/stop.sh
 create mode 100644 demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/train.sh
 create mode 100644 tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py

diff --git a/demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/README.md b/demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/README.md
new file mode 100644
index 000000000..62b3dcca6
--- /dev/null
+++ b/demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/README.md
@@ -0,0 +1,19 @@
+# A distributed synchronous training demo based on Horovod for `tfra.dynamic_embedding`
+
+- dataset: [movielens/100k-ratings](https://www.tensorflow.org/datasets/catalog/movielens#movielens100k-ratings)
+- model: DNN
+- Running API: Estimator APIs
+
+## Requirements
+- Horovod version: 0.23.0
+- OpenMPI version: 4.1.2
+## Start training
+By default, this script starts a training job with 1 PS, 2 workers, and 1 chief on the local machine.
+sh train.sh
+
+## Start export for serving
+By default, this script starts an export-for-serving job with 1 PS, 1 worker, and 1 chief on the local machine.
+sh export.sh
+
+## Stop training
+Run sh stop.sh
\ No newline at end of file
diff --git a/demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/export.sh b/demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/export.sh
new file mode 100755
index 000000000..81f4df2bf
--- /dev/null
+++ b/demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/export.sh
@@ -0,0 +1,13 @@
+#!/usr/bin/env bash
+rm -rf ./export_dir
+sh stop.sh
+
+sleep 1
+export TF_CONFIG='{"cluster": {"ps": ["localhost:2223"], "chief": ["localhost:2228"], "worker":["localhost:2240"]}, "task": {"type": "ps", "index": 0}}'
+python movielens-100k-estimator.py --mode serving &
+
+export TF_CONFIG='{"cluster": {"ps": ["localhost:2223"], "chief": ["localhost:2228"], "worker":["localhost:2240"]}, "task": {"type": "worker", "index": 0}}'
+mpirun -np 1 -H localhost:1 --allow-run-as-root -bind-to none -map-by slot -x TF_CONFIG sh -c 'python movielens-100k-estimator.py --mode serving > log/worker_0.log 2>&1'
+
+echo "ok"
+
diff --git a/demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/movielens-100k-estimator.py b/demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/movielens-100k-estimator.py
new file mode 100644
index 000000000..ce1f15b6a
--- /dev/null
+++ b/demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/movielens-100k-estimator.py
@@ -0,0 +1,215 @@
+import json
+import os
+
+import tensorflow as tf
+from tensorflow.keras.layers import Dense
+
+import tensorflow_datasets as tfds
+import tensorflow_recommenders_addons as tfra
+
+from absl import app
+from absl import flags
+
+from session_hook import CustomSaveHook, HorovodSyncHook
+
+tf.compat.v1.disable_v2_behavior()
+tf.compat.v1.disable_eager_execution()
+tf.compat.v1.disable_resource_variables()
+
+flags.DEFINE_string('model_dir', "./ckpt", 'model_dir')
+flags.DEFINE_string('export_dir', "./export_dir", 'export_dir')
+flags.DEFINE_string('mode', "train", 'train or serving')
+
+FLAGS = flags.FLAGS
+
+
+def input_fn():
+  ratings = tfds.load("movielens/100k-ratings", split="train")
+  ratings = ratings.map(
+      lambda x: {
+          "movie_id": tf.strings.to_number(x["movie_id"], tf.int64),
+          "user_id": tf.strings.to_number(x["user_id"], tf.int64),
+          "user_rating": x["user_rating"]
+      })
+  shuffled = ratings.shuffle(1_000_000,
+                             seed=2021,
+                             reshuffle_each_iteration=False)
+  dataset = shuffled.batch(256)
+  return dataset
+
+
+def model_fn(features, labels, mode, params):
+  embedding_size = 32
+  movie_id = features["movie_id"]
+  user_id = features["user_id"]
+  rating = features["user_rating"]
+
+  task_idx = params["task_idx"]
+
+  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
+
+  if is_training:
+    ps_list = [
+        "/job:ps/replica:0/task:{}/CPU:0".format(i)
+        for i in range(params["ps_num"])
+    ]
+    initializer = tf.keras.initializers.RandomNormal(-1, 1)
+  else:
+    ps_list = ["/job:localhost/replica:0/task:0/CPU:0"] * params["ps_num"]
+    initializer = tf.keras.initializers.Zeros()
+
+  user_embeddings = tfra.dynamic_embedding.get_variable(
+      name="user_dynamic_embeddings",
+      dim=embedding_size,
+      devices=ps_list,
+      initializer=initializer)
+  movie_embeddings = tfra.dynamic_embedding.get_variable(
+      name="moive_dynamic_embeddings",
+      dim=embedding_size,
+      devices=ps_list,
+      initializer=initializer)
+
+  user_id_val, user_id_idx = tf.unique(tf.concat(user_id, axis=0))
+  user_id_weights, user_id_trainable_wrapper = tfra.dynamic_embedding.embedding_lookup(
+      params=user_embeddings,
+      ids=user_id_val,
+      name="user-id-weights",
+      return_trainable=True)
+  user_id_weights = tf.gather(user_id_weights, user_id_idx)
+
+  movie_id_val, movie_id_idx = tf.unique(tf.concat(movie_id, axis=0))
+  movie_id_weights, movie_id_trainable_wrapper = tfra.dynamic_embedding.embedding_lookup(
+      params=movie_embeddings,
+      ids=movie_id_val,
+      name="movie-id-weights",
+      return_trainable=True)
+  movie_id_weights = tf.gather(movie_id_weights, movie_id_idx)
+
+  embeddings = tf.concat([user_id_weights, movie_id_weights], axis=1)
+  if is_training:
+    device = "/job:worker/replica:0/task:{}/CPU:0".format(task_idx)
+  else:
+    device = None
+  with tf.device(device):
+    d0 = Dense(256,
+               activation='relu',
+               kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1),
+               bias_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1))
+    d1 = Dense(64,
+               activation='relu',
+               kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1),
+               bias_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1))
+    d2 = Dense(1,
+               kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1),
+               bias_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1))
+    dnn = d0(embeddings)
+    dnn = d1(dnn)
+    dnn = d2(dnn)
+    out = tf.reshape(dnn, shape=[-1])
+    loss = tf.keras.losses.MeanSquaredError()(rating, out)
+  predictions = {"out": out}
+
+  if mode == tf.estimator.ModeKeys.EVAL:
+    eval_metric_ops = {}
+    return tf.estimator.EstimatorSpec(mode=mode,
+                                      loss=loss,
+                                      eval_metric_ops=eval_metric_ops)
+
+  if mode == tf.estimator.ModeKeys.TRAIN:
+    ckpt_dir = params['ckpt_dir']
+    global_step = tf.compat.v1.train.get_or_create_global_step()
+    save_hook = CustomSaveHook(ckpt_dir, global_step, task_idx)
+    sync_hook = HorovodSyncHook(device=device)
+    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=0.001)
+    optimizer = tfra.dynamic_embedding.DynamicEmbeddingOptimizer(
+        optimizer, synchronous=True)
+    train_op = optimizer.minimize(loss, global_step=global_step)
+    return tf.estimator.EstimatorSpec(mode=mode,
+                                      predictions=predictions,
+                                      loss=loss,
+                                      train_op=train_op,
+                                      training_hooks=[sync_hook, save_hook])
+
+  if mode == tf.estimator.ModeKeys.PREDICT:
+    predictions_for_net = {"out": out}
+    export_outputs = {
+        "predict_export_outputs":
+            tf.estimator.export.PredictOutput(outputs=predictions_for_net)
+    }
+    return tf.estimator.EstimatorSpec(mode,
+                                      predictions=predictions_for_net,
+                                      export_outputs=export_outputs)
+
+
+def train(model_dir, ps_num, task_idx):
+  model_config = tf.estimator.RunConfig(log_step_count_steps=100,
+                                        save_summary_steps=100,
+                                        save_checkpoints_steps=None,
+                                        save_checkpoints_secs=None,
+                                        keep_checkpoint_max=2)
+
+  model_config._is_chief = True
+
+  estimator = tf.estimator.Estimator(
+      model_fn=model_fn,
+      model_dir=model_dir,
+      params={
+          "ps_num": ps_num,
+          "task_idx": task_idx,
+          "ckpt_dir": model_dir + "/" + "model-ckpt"
+      },
+      config=model_config)
+
+  train_spec = tf.estimator.TrainSpec(input_fn=input_fn)
+
+  eval_spec = tf.estimator.EvalSpec(input_fn=input_fn)
+
+  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
+
+
+def serving_input_receiver_dense_fn():
+  input_spec = {
+      "movie_id": tf.constant([1], tf.int64),
+      "user_id": tf.constant([1], tf.int64),
+      "user_rating": tf.constant([1.0], tf.float32)
+  }
+  return tf.estimator.export.build_raw_serving_input_receiver_fn(input_spec)
+
+
+def export_for_serving(model_dir, export_dir, ps_num, task_idx):
+  model_config = tf.estimator.RunConfig(log_step_count_steps=100,
+                                        save_summary_steps=100,
+                                        save_checkpoints_steps=None,
+                                        save_checkpoints_secs=None)
+
+  estimator = tf.estimator.Estimator(
+      model_fn=model_fn,
+      model_dir=model_dir,
+      params={
+          "ps_num": ps_num,
+          "task_idx": task_idx,
+          "ckpt_dir": model_dir + "/" + "model-ckpt"
+      },
+      config=model_config)
+
+  estimator.export_saved_model(export_dir, serving_input_receiver_dense_fn())
+
+
+def main(argv):
+  del argv
+  tf_config = json.loads(os.environ.get('TF_CONFIG') or '{}')
+  task_name = tf_config.get('task', {}).get('type')
+  task_idx = tf_config.get('task', {}).get('index')
+
+  ps_num = len(tf_config["cluster"]["ps"])
+
+  if FLAGS.mode == "train":
+    train(FLAGS.model_dir, ps_num, task_idx)
+  if FLAGS.mode == "serving" and int(task_idx) == 0 and task_name == "worker":
+    tfra.dynamic_embedding.enable_inference_mode()
+    export_for_serving(FLAGS.model_dir, FLAGS.export_dir, ps_num, task_idx)
+
+
+if __name__ == "__main__":
+  app.run(main)
diff --git a/demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/session_hook.py b/demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/session_hook.py
new file mode 100644
index 000000000..5ece2877d
--- /dev/null
+++ b/demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/session_hook.py
@@ -0,0 +1,42 @@
+import horovod.tensorflow as hvd
+import tensorflow as tf
+
+
+class CustomSaveHook(tf.compat.v1.train.SessionRunHook):
+
+  def __init__(self, ckpt_dir, global_step, worker_id):
+    self._ckpt_dir = ckpt_dir
+    self._global_step = global_step
+    self._saver = tf.compat.v1.train.Saver(sharded=True,
+                                           allow_empty=True,
+                                           max_to_keep=1)
+    self._worker_id = worker_id
+    super(CustomSaveHook, self).__init__()
+
+  def end(self, session):
+    global_step = session.run(self._global_step)
+    if self._worker_id == 0 and self._ckpt_dir:
+      # Only save the checkpoint once, when training is finished.
+      self._saver.save(session, self._ckpt_dir, global_step)
+
+
+class HorovodSyncHook(tf.compat.v1.train.SessionRunHook):
+
+  def __init__(self, device=''):
+    hvd.init()
+    with tf.device(device):
+      self._bcast_op = hvd.broadcast_global_variables(0)
+      self._exit_op = hvd.join()
+
+    self._broadcast_done = False
+    super(HorovodSyncHook, self).__init__()
+
+  def after_run(self, run_context, run_values):
+    if self._broadcast_done:
+      return
+    run_context.session.run(self._bcast_op)
+
+    self._broadcast_done = True
+
+  def end(self, session):
+    session.run(self._exit_op)
diff --git a/demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/start_worker.sh b/demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/start_worker.sh
new file mode 100644
index 000000000..71ec7fa26
--- /dev/null
+++ b/demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/start_worker.sh
@@ -0,0 +1,5 @@
+#!/usr/bin/env bash
+sleep 1
+TASK_INDEX=$(($OMPI_COMM_WORLD_RANK))
+export TF_CONFIG='{"cluster": {"ps": ["localhost:2223"], "chief": ["localhost:2228"], "worker":["localhost:2240", "localhost:2241"]}, "task": {"type": "worker", "index": '"${TASK_INDEX}"'}}'
+python movielens-100k-estimator.py --mode train
diff --git a/demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/stop.sh b/demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/stop.sh
new file mode 100644
index 000000000..c9532e8f7
--- /dev/null
+++ b/demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/stop.sh
@@ -0,0 +1,4 @@
+ps -ef|grep "movielens-"|grep -v grep|awk '{print $2}'| xargs kill -9
+sleep 1
+echo "result"
+ps -ef|grep "movielens-"
diff --git a/demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/train.sh b/demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/train.sh
new file mode 100644
index 000000000..ae8fbf228
--- /dev/null
+++ b/demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/train.sh
@@ -0,0 +1,12 @@
+#!/usr/bin/env bash
+rm -rf ./ckpt
+rm -rf ./export_dir
+sh stop.sh
+
+sleep 1
+export TF_CONFIG='{"cluster": {"ps": ["localhost:2223"], "chief": ["localhost:2228"], "worker":["localhost:2240", "localhost:2241"]}, "task": {"type": "ps", "index": 0}}'
+python movielens-100k-estimator.py --mode train &
+
+mpirun -np 2 -H localhost:2 --allow-run-as-root -bind-to none -map-by slot sh -c './start_worker.sh > log/worker_$OMPI_COMM_WORLD_RANK.log 2>&1'
+
+echo "ok"
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py
new file mode 100644
index 000000000..8714e6241
--- /dev/null
+++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py
@@ -0,0 +1,105 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""unit tests of Horovod synchronous training with the dynamic embedding optimizer
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+import pytest
+import tensorflow as tf
+import horovod.tensorflow as hvd
+
+from tensorflow_recommenders_addons import dynamic_embedding as de
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import adam
+from tensorflow.python.training import monitored_session
+from tensorflow.python.training import training_util
+
+default_config = config_pb2.ConfigProto(
+    allow_soft_placement=True,
+    gpu_options=config_pb2.GPUOptions(allow_growth=True))
+
+
+class HorovodTest(test.TestCase):
+
+  @test_util.deprecated_graph_mode_only
+  def test_adam_minimize_trainable(self):
+    base_opt = adam.AdamOptimizer(1.0)
+    test_opt = adam.AdamOptimizer(1.0)
+    self.common_minimize_trainable(base_opt, test_opt, name="adam")
+
+  def common_minimize_trainable(self, base_opt, test_opt, name):
+    tf.config.set_soft_device_placement(True)
+    hvd.init()
+    base_opt = de.DynamicEmbeddingOptimizer(base_opt, synchronous=True)
+    for dtype, run_step, dim in itertools.product([dtypes.float32], [1], [10]):
+      x = tf.random.uniform(shape=[32, dim])
+      y = tf.zeros([32, 1])
+
+      global_step = training_util.create_global_step()
+
+      base_weight = tf.compat.v1.get_variable(name="base_weights",
+                                              initializer=tf.ones([10, 1]))
+
+      base_logits = tf.nn.relu(math_ops.matmul(x, base_weight))
+      base_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y,
+                                                          logits=base_logits)
+
+      base_opt_op = base_opt.minimize(base_loss,
+                                      global_step,
+                                      var_list=[base_weight])
+
+      test_weight = tf.compat.v1.get_variable(name="test_weights",
+                                              initializer=tf.ones([10, 1]))
+
+      test_logits = tf.nn.relu(math_ops.matmul(x, test_weight))
+      test_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y,
+                                                          logits=test_logits)
+
+      grads_and_vars = test_opt.compute_gradients(test_loss,
+                                                  var_list=[test_weight])
+      var_list = []
+      aggregated_grad = []
+      for grad, var in grads_and_vars:
+        var_list.append(var)
+        aggregated_grad.append(hvd.allreduce(grad, op=hvd.Sum))
+      aggregated_grads_and_vars = zip(aggregated_grad, var_list)
+      test_opt_op = test_opt.apply_gradients(aggregated_grads_and_vars,
+                                             global_step)
+
+      with monitored_session.MonitoredTrainingSession(
+          is_chief=True, config=default_config) as sess:
+
+        for _ in range(run_step):
+          sess.run(base_opt_op)
+          sess.run(test_opt_op)
+
+        self.assertAllCloseAccordingToType(
+            sess.run(base_weight),
+            sess.run(test_weight),
+            msg="Cond:{},{},{}".format(dtype, run_step, dim),
+        )
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py
index bf4d56e6d..4f2df0c5e 100644
--- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py
+++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py
@@ -41,14 +41,18 @@
 from tensorflow.python.training.tracking import base as trackable
 
 
-def DynamicEmbeddingOptimizer(self, bp_v2=None):
+def DynamicEmbeddingOptimizer(self, bp_v2=False, synchronous=False):
   """ An optimizer wrapper to make any TensorFlow optimizer capable of training
   Dynamic Embeddding Variables.
 
   Args:
     self: a TensorFlow optimizer.
-    bp_v2: By default is None, If None use params_var_.bp_v2 setting
-      (see `tfra.dynamic_embedding_variable.get_variable`)
+    bp_v2: If True, parameters are updated with deltas instead of being overwritten
+      by assignment, which avoids the race condition among workers during back-propagation
+      in large-scale distributed asynchronous training. Reference: https://www.usenix.org/system/files/osdi20-jiang.pdf
+    synchronous: If True, Horovod's all-reduce (SUM by default) is used to aggregate the dense
+      gradients of ordinary model parameters across workers; TrainableWrapper gradients are applied as before.
+
 
   Example usage:
 
     ```python
@@ -60,6 +64,9 @@ def DynamicEmbeddingOptimizer(self, bp_v2=None):
     The optimizer itself but has ability to train Dynamic Embedding Variables.
   """
   self._bp_v2 = bp_v2
+  self._hvd_sync = synchronous
+
+  original_apply_gradients = self.apply_gradients
 
   def _distributed_apply(distribution, grads_and_vars, name, apply_state):
     """`apply_gradients` using a `DistributionStrategy`."""
@@ -281,12 +288,77 @@ def _zeros_slot(var, slot_name, op_name):
       named_slots[optimizer._var_key(var)] = new_slot_variable
     return named_slots[optimizer._var_key(var)]
 
+  def apply_gradients(grads_and_vars, global_step=None, name=None):
+    """Apply gradients to variables.
+    Args:
+      grads_and_vars: List of (gradient, variable) pairs as returned by
+        compute_gradients().
+      global_step: Optional Variable to increment by one after the
+        variables have been updated.
+      name: Optional name for the returned operation. Defaults to the
+        name passed to the Optimizer constructor.
+
+    Returns:
+      train_op: An op applying the gradients, to be executed by each replica.
+
+    Raises:
+      ValueError: If grads_and_vars is empty.
+      ValueError: If global_step is not provided, since staleness cannot be
+        checked without it.
+    """
+    try:
+      import horovod.tensorflow as hvd
+    except ImportError:
+      raise ValueError(
+          "Please install Horovod first if you want to use distributed synchronous training based on Horovod"
+      )
+    if not grads_and_vars:
+      raise ValueError("Must supply at least one variable")
+
+    if global_step is None:
+      raise ValueError("Global step is required to check staleness")
+
+    trainable_grad_and_vars = []
+    aggregated_grad = []
+    var_list = []
+
+    with backend.name_scope(name or self._name):
+      for grad, var in grads_and_vars:
+        if isinstance(var, de.TrainableWrapper):
+          trainable_grad_and_vars.append((grad, var))
+          continue
+        var_list.append(var)
+        with ops.device(var.device):
+          # Dense gradients: all-reduce across workers before applying.
+          if grad is None:
+            aggregated_grad.append(None)  # pass-through.
+            continue
+          else:
+            aggregated_grad.append(hvd.allreduce(grad, op=hvd.Sum))
+
+      aggregated_grads_and_vars = zip(aggregated_grad, var_list)
+      update_op = original_apply_gradients(aggregated_grads_and_vars,
+                                           global_step)
+      if trainable_grad_and_vars:
+        trainable_update_op = original_apply_gradients(trainable_grad_and_vars,
+                                                       global_step)
+        train_op = control_flow_ops.group([update_op, trainable_update_op])
+      else:
+        train_op = update_op
+    return train_op
+
   if isinstance(self, optimizer.Optimizer):
     self._get_or_make_slot = _get_or_make_slot
     self._get_or_make_slot_with_initializer = _get_or_make_slot_with_initializer
     self._zeros_slot = _zeros_slot
+    if self._hvd_sync:
+      self.apply_gradients = apply_gradients
   elif isinstance(self, optimizer_v2.OptimizerV2) or isinstance(
       self, keras_optimizer):
+    if self._hvd_sync:
+      raise Exception(
+          "OptimizerV2 does not support distributed synchronous training yet, please use tf.train.XxxxOptimizer."
+      )
     self.add_slot = add_slot
     self._distributed_apply = _distributed_apply
   else:
diff --git a/tools/testing/build_and_run_tests.sh b/tools/testing/build_and_run_tests.sh
index 04b85d624..3c231a296 100644
--- a/tools/testing/build_and_run_tests.sh
+++ b/tools/testing/build_and_run_tests.sh
@@ -33,4 +33,7 @@ if ! [ -x "$(command -v nvidia-smi)" ]; then
   EXTRA_ARGS="-n auto"
 fi
 
+mpirun -np 2 -H localhost:2 --allow-run-as-root pytest -v ./tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py
+
 python -m pytest -v -s --functions-durations=20 --modules-durations=5 $EXTRA_ARGS ./tensorflow_recommenders_addons
+
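
Reviewer note (not part of the patch): the sketch below illustrates how the new `synchronous=True` option added in this patch is intended to be used outside of the Estimator demo. It is a minimal example under the assumption that TFRA with this change, Horovod, and OpenMPI are installed; the toy model, hyperparameters, and step count are illustrative only.

```python
# Minimal sketch of DynamicEmbeddingOptimizer(..., synchronous=True).
# Assumed launch: mpirun -np 2 python example.py
import horovod.tensorflow as hvd
import tensorflow as tf
import tensorflow_recommenders_addons as tfra

tf.compat.v1.disable_eager_execution()
hvd.init()

# Wrap a tf.compat.v1 optimizer; with this patch, OptimizerV2/Keras
# optimizers raise an exception when synchronous=True.
base_opt = tf.compat.v1.train.AdamOptimizer(learning_rate=0.001)
opt = tfra.dynamic_embedding.DynamicEmbeddingOptimizer(base_opt,
                                                       synchronous=True)

# A toy dense model: its gradients are all-reduced (SUM) across workers
# inside the wrapped apply_gradients, while any TrainableWrapper (dynamic
# embedding) gradients would be applied locally, as before.
w = tf.compat.v1.get_variable("w", initializer=tf.ones([4, 1]))
x = tf.random.uniform([8, 4])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))

# The wrapped apply_gradients requires a global step.
global_step = tf.compat.v1.train.get_or_create_global_step()
train_op = opt.minimize(loss, global_step=global_step)

with tf.compat.v1.train.MonitoredTrainingSession(is_chief=True) as sess:
  for _ in range(3):
    sess.run(train_op)
```

This mirrors the dense-gradient path exercised by `horovod_sync_train_test.py`, which compares the wrapped optimizer against a manual `hvd.allreduce` of the gradients.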