support horovod sync train
a6802739 authored and Lifann committed Feb 16, 2022
1 parent 80697d1 commit 6872e62
Showing 10 changed files with 493 additions and 3 deletions.

README.md
# A distributed synchronous training demo based on Horovod for `tfra.dynamic_embedding`

- dataset: [movielens/100k-ratings](https://www.tensorflow.org/datasets/catalog/movielens#movielens100k-ratings)
- model: DNN
- Running API: TensorFlow Estimator
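
Every process reads the cluster layout from the `TF_CONFIG` environment variable. For example, worker 0 launched by `train.sh` (via `start_worker.sh`, both shown below) sees:

```bash
export TF_CONFIG='{"cluster": {"ps": ["localhost:2223"], "chief": ["localhost:2228"], "worker": ["localhost:2240", "localhost:2241"]}, "task": {"type": "worker", "index": 0}}'
```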

## Requirements
- Horovod Version: 0.23.0
- OpenMPI Version: 4.1.2
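
A minimal environment setup might look like the following (a sketch only: it assumes OpenMPI 4.1.2 is already installed and uses the standard pip package names; other versions are untested):

```bash
# Hypothetical setup sketch; install Horovod with TensorFlow support last.
pip install tensorflow-datasets tensorflow-recommenders-addons
HOROVOD_WITH_TENSORFLOW=1 pip install horovod==0.23.0
```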
## Start training
By default, this script starts a training job with 1 PS and 2 Horovod-synchronized workers on the local machine:

    sh train.sh

## Start export for serving
By default, this script starts an export-for-serving job with 1 PS and 1 worker on the local machine:

    sh export.sh
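
Once the export finishes, the SavedModel can be sanity-checked with TensorFlow's `saved_model_cli` (the timestamped subdirectory name below is illustrative):

```bash
saved_model_cli show --dir ./export_dir/1645000000 --all
```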

## Stop training

    sh stop.sh

export.sh
#!/usr/bin/env bash
rm -rf ./export_dir
sh stop.sh

sleep 1
mkdir -p log
export TF_CONFIG='{"cluster": {"ps": ["localhost:2223"], "chief": ["localhost:2228"], "worker":["localhost:2240"]}, "task": {"type": "ps", "index": 0}}'
python movielens-100k-estimator.py --mode serving &

export TF_CONFIG='{"cluster": {"ps": ["localhost:2223"], "chief": ["localhost:2228"], "worker":["localhost:2240"]}, "task": {"type": "worker", "index": 0}}'
mpirun -np 1 -H localhost:1 --allow-run-as-root -bind-to none -map-by slot -x TF_CONFIG sh -c 'python movielens-100k-estimator.py --mode serving > log/worker_0.log 2>&1'

echo "ok"


movielens-100k-estimator.py
import json
import os

import tensorflow as tf
from tensorflow.keras.layers import Dense

import tensorflow_datasets as tfds
import tensorflow_recommenders_addons as tfra

from absl import app
from absl import flags

from session_hook import CustomSaveHook, HorovodSyncHook

tf.compat.v1.disable_v2_behavior()
tf.compat.v1.disable_eager_execution()
tf.compat.v1.disable_resource_variables()

flags.DEFINE_string('model_dir', "./ckpt", 'model_dir')
flags.DEFINE_string('export_dir', "./export_dir", 'export_dir')
flags.DEFINE_string('mode', "train", 'train or serving')

FLAGS = flags.FLAGS


def input_fn():
  ratings = tfds.load("movielens/100k-ratings", split="train")
  ratings = ratings.map(
      lambda x: {
          "movie_id": tf.strings.to_number(x["movie_id"], tf.int64),
          "user_id": tf.strings.to_number(x["user_id"], tf.int64),
          "user_rating": x["user_rating"]
      })
  shuffled = ratings.shuffle(1_000_000,
                             seed=2021,
                             reshuffle_each_iteration=False)
  dataset = shuffled.batch(256)
  return dataset


def model_fn(features, labels, mode, params):
  embedding_size = 32
  movie_id = features["movie_id"]
  user_id = features["user_id"]
  rating = features["user_rating"]

  task_idx = params["task_idx"]

  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  if is_training:
    # During training, shard the dynamic embedding tables across the PS tasks.
    ps_list = [
        "/job:ps/replica:0/task:{}/CPU:0".format(i)
        for i in range(params["ps_num"])
    ]
    initializer = tf.keras.initializers.RandomNormal(-1, 1)
  else:
    # For export/inference, place every shard on the local host; the values
    # are restored from the checkpoint.
    ps_list = ["/job:localhost/replica:0/task:0/CPU:0"] * params["ps_num"]
    initializer = tf.keras.initializers.Zeros()

  user_embeddings = tfra.dynamic_embedding.get_variable(
      name="user_dynamic_embeddings",
      dim=embedding_size,
      devices=ps_list,
      initializer=initializer)
  movie_embeddings = tfra.dynamic_embedding.get_variable(
      name="movie_dynamic_embeddings",
      dim=embedding_size,
      devices=ps_list,
      initializer=initializer)

  # Deduplicate ids before the lookup, then gather the embeddings back to
  # the original batch order.
  user_id_val, user_id_idx = tf.unique(tf.concat(user_id, axis=0))
  user_id_weights, user_id_trainable_wrapper = tfra.dynamic_embedding.embedding_lookup(
      params=user_embeddings,
      ids=user_id_val,
      name="user-id-weights",
      return_trainable=True)
  user_id_weights = tf.gather(user_id_weights, user_id_idx)

  movie_id_val, movie_id_idx = tf.unique(tf.concat(movie_id, axis=0))
  movie_id_weights, movie_id_trainable_wrapper = tfra.dynamic_embedding.embedding_lookup(
      params=movie_embeddings,
      ids=movie_id_val,
      name="movie-id-weights",
      return_trainable=True)
  movie_id_weights = tf.gather(movie_id_weights, movie_id_idx)

  embeddings = tf.concat([user_id_weights, movie_id_weights], axis=1)
  if is_training:
    device = "/job:worker/replica:0/task:{}/CPU:0".format(task_idx)
  else:
    device = None
  with tf.device(device):
    d0 = Dense(256,
               activation='relu',
               kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1),
               bias_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1))
    d1 = Dense(64,
               activation='relu',
               kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1),
               bias_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1))
    d2 = Dense(1,
               kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1),
               bias_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1))
    dnn = d0(embeddings)
    dnn = d1(dnn)
    dnn = d2(dnn)
  out = tf.reshape(dnn, shape=[-1])
  loss = tf.keras.losses.MeanSquaredError()(rating, out)
  predictions = {"out": out}

  if mode == tf.estimator.ModeKeys.EVAL:
    eval_metric_ops = {}
    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      eval_metric_ops=eval_metric_ops)

  if mode == tf.estimator.ModeKeys.TRAIN:
    ckpt_dir = params['ckpt_dir']
    global_step = tf.compat.v1.train.get_or_create_global_step()
    save_hook = CustomSaveHook(ckpt_dir, global_step, task_idx)
    sync_hook = HorovodSyncHook(device=device)
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=0.001)
    # Wrap the optimizer so the dynamic embedding slot variables are created
    # and updated correctly under Horovod synchronous training.
    optimizer = tfra.dynamic_embedding.DynamicEmbeddingOptimizer(
        optimizer, horovod_synchronous=True)
    train_op = optimizer.minimize(loss, global_step=global_step)
    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions=predictions,
                                      loss=loss,
                                      train_op=train_op,
                                      training_hooks=[sync_hook, save_hook])

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions_for_net = {"out": out}
    export_outputs = {
        "predict_export_outputs":
            tf.estimator.export.PredictOutput(outputs=predictions_for_net)
    }
    return tf.estimator.EstimatorSpec(mode,
                                      predictions=predictions_for_net,
                                      export_outputs=export_outputs)


def train(model_dir, ps_num, task_idx):
  model_config = tf.estimator.RunConfig(log_step_count_steps=100,
                                        save_summary_steps=100,
                                        save_checkpoints_steps=None,
                                        save_checkpoints_secs=None,
                                        keep_checkpoint_max=2)

  # Every worker treats itself as chief: built-in checkpointing is disabled
  # above, and CustomSaveHook saves a single checkpoint from worker 0.
  model_config._is_chief = True

  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=model_dir,
      params={
          "ps_num": ps_num,
          "task_idx": task_idx,
          "ckpt_dir": model_dir + "/" + "model-ckpt"
      },
      config=model_config)

  train_spec = tf.estimator.TrainSpec(input_fn=input_fn)

  eval_spec = tf.estimator.EvalSpec(input_fn=input_fn)

  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)


def serving_input_receiver_dense_fn():
  input_spec = {
      "movie_id": tf.constant([1], tf.int64),
      "user_id": tf.constant([1], tf.int64),
      "user_rating": tf.constant([1.0], tf.float32)
  }
  return tf.estimator.export.build_raw_serving_input_receiver_fn(input_spec)


def export_for_serving(model_dir, export_dir, ps_num, task_idx):
  model_config = tf.estimator.RunConfig(log_step_count_steps=100,
                                        save_summary_steps=100,
                                        save_checkpoints_steps=None,
                                        save_checkpoints_secs=None)

  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=model_dir,
      params={
          "ps_num": ps_num,
          "task_idx": task_idx,
          "ckpt_dir": model_dir + "/" + "model-ckpt"
      },
      config=model_config)

  estimator.export_saved_model(export_dir, serving_input_receiver_dense_fn())


def main(argv):
  del argv
  tf_config = json.loads(os.environ.get('TF_CONFIG') or '{}')
  task_name = tf_config.get('task', {}).get('type')
  task_idx = tf_config.get('task', {}).get('index')

  ps_num = len(tf_config["cluster"]["ps"])

  if FLAGS.mode == "train":
    train(FLAGS.model_dir, ps_num, task_idx)
  if FLAGS.mode == "serving" and int(task_idx) == 0 and task_name == "worker":
    tfra.dynamic_embedding.enable_inference_mode()
    export_for_serving(FLAGS.model_dir, FLAGS.export_dir, ps_num, task_idx)


if __name__ == "__main__":
  app.run(main)

session_hook.py
import horovod.tensorflow as hvd
import tensorflow as tf


class CustomSaveHook(tf.compat.v1.train.SessionRunHook):
  """Saves a single sharded checkpoint from worker 0 when training ends."""

  def __init__(self, ckpt_dir, global_step, worker_id):
    self._ckpt_dir = ckpt_dir
    self._global_step = global_step
    self._saver = tf.compat.v1.train.Saver(sharded=True,
                                           allow_empty=True,
                                           max_to_keep=1)
    self._worker_id = worker_id
    super(CustomSaveHook, self).__init__()

  def end(self, session):
    global_step = session.run(self._global_step)
    if self._worker_id == 0 and self._ckpt_dir:
      # Only save the checkpoint once, when training is finished.
      self._saver.save(session, self._ckpt_dir, global_step)


class HorovodSyncHook(tf.compat.v1.train.SessionRunHook):
  """Broadcasts initial variables from rank 0 and joins workers at the end."""

  def __init__(self, device=''):
    hvd.init()
    with tf.device(device):
      self._bcast_op = hvd.broadcast_global_variables(0)
      self._exit_op = hvd.join()

    self._broadcast_done = False
    super(HorovodSyncHook, self).__init__()

  def after_run(self, run_context, run_values):
    if self._broadcast_done:
      return
    # Broadcast once, after the first step, so every worker starts from the
    # same variable values.
    run_context.session.run(self._bcast_op)

    self._broadcast_done = True

  def end(self, session):
    session.run(self._exit_op)

start_worker.sh
#!/usr/bin/env bash
sleep 1
TASK_INDEX=$(($OMPI_COMM_WORLD_RANK))
export TF_CONFIG='{"cluster": {"ps": ["localhost:2223"], "chief": ["localhost:2228"], "worker":["localhost:2240", "localhost:2241"]}, "task": {"type": "worker", "index": '"${TASK_INDEX}"'}}'
python movielens-100k-estimator.py --mode train

stop.sh
#!/usr/bin/env bash
ps -ef | grep "movielens-" | grep -v grep | awk '{print $2}' | xargs kill -9
sleep 1
echo "result"
ps -ef | grep "movielens-"

train.sh
#!/usr/bin/env bash
rm -rf ./ckpt
rm -rf ./export_dir
sh stop.sh

sleep 1
mkdir -p log
export TF_CONFIG='{"cluster": {"ps": ["localhost:2223"], "chief": ["localhost:2228"], "worker":["localhost:2240", "localhost:2241"]}, "task": {"type": "ps", "index": 0}}'
python movielens-100k-estimator.py --mode train &

mpirun -np 2 -H localhost:2 --allow-run-as-root -bind-to none -map-by slot sh -c './start_worker.sh > log/worker_$OMPI_COMM_WORLD_RANK.log 2>&1'

echo "ok"