-
Notifications
You must be signed in to change notification settings - Fork 139
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
493 additions
and
3 deletions.
There are no files selected for viewing
19 changes: 19 additions & 0 deletions
19
demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# A distributed synchronous training demo based on Horovod for `tfra.dynamic_embedding`: | ||
|
||
- dataset: [movielens/100k-ratings](https://www.tensorflow.org/datasets/catalog/movielens#movielens100k-ratings) | ||
- model: DNN | ||
- Running API: using estimator APIs | ||
|
||
## Requirements | ||
- Horovod Version: 0.23.0 | ||
- OpenMPI Version: 4.1.2 | ||
## start train: | ||
By default, this script starts a training task with 1 PS and 2 workers (the workers are launched through `mpirun`) on the local machine. | ||
sh train.sh | ||
|
||
## start export for serving: | ||
By default, this script starts an export-for-serving task with 1 PS and 1 worker on the local machine. | ||
sh export.sh | ||
|
||
## stop train: | ||
run sh stop.sh |
13 changes: 13 additions & 0 deletions
13
demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/export.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#!/usr/bin/env bash
# Export the trained model for serving: launch one process with a "ps" task
# in the background, then run worker 0 under mpirun with --mode serving.

# Start from a clean export directory and run stop.sh before launching.
rm -rf ./export_dir
sh stop.sh

# The worker's output is redirected into ./log; create the directory up front
# so the redirection cannot fail when it does not exist yet.
mkdir -p log

sleep 1
export TF_CONFIG='{"cluster": {"ps": ["localhost:2223"], "chief": ["localhost:2228"], "worker":["localhost:2240"]}, "task": {"type": "ps", "index": 0}}'
python movielens-100k-estimator.py --mode serving &

# -x TF_CONFIG forwards the worker's TF_CONFIG into the mpirun-launched shell.
export TF_CONFIG='{"cluster": {"ps": ["localhost:2223"], "chief": ["localhost:2228"], "worker":["localhost:2240"]}, "task": {"type": "worker", "index": 0}}'
mpirun -np 1 -H localhost:1 --allow-run-as-root -bind-to none -map-by slot -x TF_CONFIG sh -c 'python movielens-100k-estimator.py --mode serving> log/worker_0.log 2>&1'

echo "ok"
|
215 changes: 215 additions & 0 deletions
215
.../dynamic_embedding/movielens-100k-sync-estimator-with-horovod/movielens-100k-estimator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,215 @@ | ||
import json | ||
import os | ||
|
||
import tensorflow as tf | ||
from tensorflow.keras.layers import Dense | ||
|
||
import tensorflow_datasets as tfds | ||
import tensorflow_recommenders_addons as tfra | ||
|
||
from absl import app | ||
from absl import flags | ||
|
||
from session_hook import CustomSaveHook, HorovodSyncHook | ||
|
||
# Run TensorFlow in TF1-compatible mode: v2 behavior, eager execution and
# resource variables are all switched off before any graph is built.
tf.compat.v1.disable_v2_behavior()
tf.compat.v1.disable_eager_execution()
tf.compat.v1.disable_resource_variables()

# Command-line flags. main() dispatches on --mode being "train" or "serving".
flags.DEFINE_string('model_dir', "./ckpt",
                    'Directory where model checkpoints are stored.')
flags.DEFINE_string('export_dir', "./export_dir",
                    'Directory the SavedModel is exported to.')
flags.DEFINE_string('mode', "train", 'Run mode: "train" or "serving".')

FLAGS = flags.FLAGS
|
||
|
||
def input_fn():
    """Build the MovieLens-100k ratings input pipeline.

    Returns:
      A `tf.data.Dataset` yielding batches of 256 examples. Each element is
      a dict with "movie_id" and "user_id" parsed from strings to int64 and
      "user_rating" passed through unchanged.
    """

    def _to_features(example):
        # The ids arrive as strings; convert them to int64 keys.
        return {
            "movie_id": tf.strings.to_number(example["movie_id"], tf.int64),
            "user_id": tf.strings.to_number(example["user_id"], tf.int64),
            "user_rating": example["user_rating"]
        }

    raw = tfds.load("movielens/100k-ratings", split="train")
    parsed = raw.map(_to_features)
    # Fixed seed and no reshuffling between iterations.
    shuffled = parsed.shuffle(1_000_000,
                              seed=2021,
                              reshuffle_each_iteration=False)
    return shuffled.batch(256)
|
||
|
||
def model_fn(features, labels, mode, params):
    """Estimator model function: dynamic-embedding lookups feeding a small DNN.

    Args:
      features: dict holding the "movie_id", "user_id" and "user_rating"
        tensors.
      labels: not used; the rating target is read from `features`.
      mode: a `tf.estimator.ModeKeys` value.
      params: dict providing "task_idx" and "ps_num", plus "ckpt_dir" when
        training.

    Returns:
      A `tf.estimator.EstimatorSpec` for EVAL, TRAIN or PREDICT. Any other
      mode falls through and returns None.
    """
    embedding_dim = 32
    movie_ids = features["movie_id"]
    user_ids = features["user_id"]
    ratings = features["user_rating"]
    task_idx = params["task_idx"]
    training = mode == tf.estimator.ModeKeys.TRAIN

    if training:
        # Training: one table shard per PS task, DNN pinned to this worker.
        table_devices = [
            f"/job:ps/replica:0/task:{i}/CPU:0"
            for i in range(params["ps_num"])
        ]
        table_initializer = tf.keras.initializers.RandomNormal(-1, 1)
        dense_device = f"/job:worker/replica:0/task:{task_idx}/CPU:0"
    else:
        # Otherwise everything lives on the local CPU device.
        table_devices = (["/job:localhost/replica:0/task:0/CPU:0"] *
                         params["ps_num"])
        table_initializer = tf.keras.initializers.Zeros()
        dense_device = None

    def _make_table(name):
        return tfra.dynamic_embedding.get_variable(
            name=name,
            dim=embedding_dim,
            devices=table_devices,
            initializer=table_initializer)

    def _lookup(table, ids, name):
        # Look up each distinct id once, then expand back to batch order.
        unique_ids, positions = tf.unique(tf.concat(ids, axis=0))
        weights, _ = tfra.dynamic_embedding.embedding_lookup(
            params=table, ids=unique_ids, name=name, return_trainable=True)
        return tf.gather(weights, positions)

    def _dense(units, activation=None):
        # Fresh initializer instances for every layer.
        return Dense(
            units,
            activation=activation,
            kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1),
            bias_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1))

    user_table = _make_table("user_dynamic_embeddings")
    # The misspelled name is a runtime identifier and is kept unchanged.
    movie_table = _make_table("moive_dynamic_embeddings")

    user_vectors = _lookup(user_table, user_ids, "user-id-weights")
    movie_vectors = _lookup(movie_table, movie_ids, "movie-id-weights")
    embeddings = tf.concat([user_vectors, movie_vectors], axis=1)

    with tf.device(dense_device):
        tower = [_dense(256, 'relu'), _dense(64, 'relu'), _dense(1)]
        hidden = embeddings
        for layer in tower:
            hidden = layer(hidden)
        out = tf.reshape(hidden, shape=[-1])
        loss = tf.keras.losses.MeanSquaredError()(ratings, out)
    predictions = {"out": out}

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          eval_metric_ops={})

    if training:
        global_step = tf.compat.v1.train.get_or_create_global_step()
        # The hooks are constructed before the optimizer builds its ops.
        save_hook = CustomSaveHook(params['ckpt_dir'], global_step, task_idx)
        sync_hook = HorovodSyncHook(device=dense_device)
        adam = tf.compat.v1.train.AdamOptimizer(learning_rate=0.001)
        optimizer = tfra.dynamic_embedding.DynamicEmbeddingOptimizer(
            adam, horovod_synchronous=True)
        train_op = optimizer.minimize(loss, global_step=global_step)
        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=predictions,
                                          loss=loss,
                                          train_op=train_op,
                                          training_hooks=[sync_hook, save_hook])

    if mode == tf.estimator.ModeKeys.PREDICT:
        export_outputs = {
            "predict_export_outputs":
                tf.estimator.export.PredictOutput(outputs=predictions)
        }
        return tf.estimator.EstimatorSpec(mode,
                                          predictions=predictions,
                                          export_outputs=export_outputs)
|
||
|
||
def train(model_dir, ps_num, task_idx):
    """Build the estimator and run `tf.estimator.train_and_evaluate`.

    Args:
      model_dir: estimator model directory; the final checkpoint prefix is
        "<model_dir>/model-ckpt".
      ps_num: number of parameter-server tasks, forwarded to `model_fn`.
      task_idx: index of this task, forwarded to `model_fn`.
    """
    # Both periodic checkpoint triggers are passed as None.
    run_config = tf.estimator.RunConfig(log_step_count_steps=100,
                                        save_summary_steps=100,
                                        save_checkpoints_steps=None,
                                        save_checkpoints_secs=None,
                                        keep_checkpoint_max=2)

    # Force the private chief flag on for this process.
    # NOTE(review): presumably so every Horovod worker takes the chief code
    # path -- confirm against the estimator's behavior.
    run_config._is_chief = True

    estimator_params = {
        "ps_num": ps_num,
        "task_idx": task_idx,
        "ckpt_dir": f"{model_dir}/model-ckpt"
    }
    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=model_dir,
                                       params=estimator_params,
                                       config=run_config)

    # Training and evaluation read from the same input pipeline.
    tf.estimator.train_and_evaluate(
        estimator,
        tf.estimator.TrainSpec(input_fn=input_fn),
        tf.estimator.EvalSpec(input_fn=input_fn))
|
||
|
||
def serving_input_receiver_dense_fn():
    """Return a raw serving input receiver fn for the model's three features.

    The example tensors describe "movie_id" and "user_id" as int64 and
    "user_rating" as float32.
    """
    example_features = dict(
        movie_id=tf.constant([1], tf.int64),
        user_id=tf.constant([1], tf.int64),
        user_rating=tf.constant([1.0], tf.float32),
    )
    receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
        example_features)
    return receiver_fn
|
||
|
||
def export_for_serving(model_dir, export_dir, ps_num, task_idx):
    """Export the model stored under `model_dir` as a SavedModel.

    Args:
      model_dir: estimator model directory to load from.
      export_dir: base directory the SavedModel is written to.
      ps_num: number of parameter-server tasks, forwarded to `model_fn`.
      task_idx: index of this task, forwarded to `model_fn`.
    """
    run_config = tf.estimator.RunConfig(log_step_count_steps=100,
                                        save_summary_steps=100,
                                        save_checkpoints_steps=None,
                                        save_checkpoints_secs=None)

    estimator_params = {
        "ps_num": ps_num,
        "task_idx": task_idx,
        "ckpt_dir": f"{model_dir}/model-ckpt"
    }
    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=model_dir,
                                       params=estimator_params,
                                       config=run_config)

    receiver_fn = serving_input_receiver_dense_fn()
    estimator.export_saved_model(export_dir, receiver_fn)
|
||
|
||
def main(argv):
    """Read TF_CONFIG and run either training or SavedModel export.

    In "train" mode every task calls `train`. In "serving" mode only the
    task of type "worker" with index 0 performs the export; all other tasks
    return without doing anything.

    Args:
      argv: unused positional command-line arguments from absl.

    Raises:
      ValueError: if TF_CONFIG has no "cluster"/"ps" entry or no
        "task"/"index" entry, or if --mode is neither "train" nor "serving".
    """
    del argv
    tf_config = json.loads(os.environ.get('TF_CONFIG') or '{}')
    task = tf_config.get('task', {})
    task_name = task.get('type')
    task_idx = task.get('index')

    # Fail with an explicit message instead of a bare KeyError/TypeError when
    # TF_CONFIG is unset or incomplete.
    ps_hosts = tf_config.get('cluster', {}).get('ps')
    if ps_hosts is None:
        raise ValueError(
            'TF_CONFIG must define cluster.ps (list of parameter servers).')
    if task_idx is None:
        raise ValueError('TF_CONFIG must define task.index.')
    ps_num = len(ps_hosts)

    if FLAGS.mode == "train":
        train(FLAGS.model_dir, ps_num, task_idx)
    elif FLAGS.mode == "serving":
        if int(task_idx) == 0 and task_name == "worker":
            tfra.dynamic_embedding.enable_inference_mode()
            export_for_serving(FLAGS.model_dir, FLAGS.export_dir, ps_num,
                               task_idx)
    else:
        raise ValueError(
            'Unsupported --mode {!r}; expected "train" or "serving".'.format(
                FLAGS.mode))


if __name__ == "__main__":
    app.run(main)
42 changes: 42 additions & 0 deletions
42
demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/session_hook.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import horovod.tensorflow as hvd | ||
import tensorflow as tf | ||
|
||
|
||
class CustomSaveHook(tf.compat.v1.train.SessionRunHook):
    """Session hook that writes a single checkpoint when the session ends.

    Only an instance whose `worker_id` equals 0 and whose `ckpt_dir` is
    non-empty saves anything.
    """

    def __init__(self, ckpt_dir, global_step, worker_id):
        """Store the save target and build the saver.

        Args:
          ckpt_dir: checkpoint path prefix handed to `Saver.save`.
          global_step: global-step tensor; its value tags the checkpoint.
          worker_id: index of the worker that owns this hook.
        """
        super().__init__()
        self._ckpt_dir = ckpt_dir
        self._global_step = global_step
        self._worker_id = worker_id
        # The saver is created here, at hook construction time.
        self._saver = tf.compat.v1.train.Saver(max_to_keep=1,
                                               allow_empty=True,
                                               sharded=True)

    def end(self, session):
        """Save one checkpoint if this hook belongs to worker 0."""
        step = session.run(self._global_step)
        should_save = self._worker_id == 0 and self._ckpt_dir
        if not should_save:
            return
        # The only checkpoint is written once, after training has finished.
        self._saver.save(session, self._ckpt_dir, step)
|
||
|
||
class HorovodSyncHook(tf.compat.v1.train.SessionRunHook):
    """Session hook that broadcasts variables from rank 0 and joins at exit.

    The broadcast op is run once, after the first `session.run` call
    completes; the join op is run when the session ends.
    """

    def __init__(self, device=''):
        """Initialize Horovod and build the broadcast and join ops.

        Args:
          device: device string under which both ops are created.
        """
        hvd.init()
        # Both ops are built at construction time.
        # NOTE(review): the broadcast presumably covers only variables that
        # already exist when the hook is instantiated -- confirm at call site.
        with tf.device(device):
            self._bcast_op = hvd.broadcast_global_variables(0)
            self._exit_op = hvd.join()

        self._broadcast_done = False
        super().__init__()

    def after_run(self, run_context, run_values):
        """Run the broadcast op the first time this is called, then no-op."""
        if not self._broadcast_done:
            run_context.session.run(self._bcast_op)
            self._broadcast_done = True

    def end(self, session):
        """Run the Horovod join op when the session finishes."""
        session.run(self._exit_op)
5 changes: 5 additions & 0 deletions
5
demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/start_worker.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#!/usr/bin/env bash
# Start one training worker; its task index is the rank assigned by Open MPI.
sleep 1

TASK_INDEX=$((OMPI_COMM_WORLD_RANK))
CLUSTER='{"ps": ["localhost:2223"], "chief": ["localhost:2228"], "worker":["localhost:2240", "localhost:2241"]}'

# Assemble TF_CONFIG from the shared cluster spec and this worker's index.
export TF_CONFIG='{"cluster": '"${CLUSTER}"', "task": {"type": "worker", "index": '"${TASK_INDEX}"'}}'

python movielens-100k-estimator.py --mode train
4 changes: 4 additions & 0 deletions
4
demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/stop.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Kill every process whose command line contains "movielens-", then list
# whatever is still running.

# Collect the PIDs first so that kill is only invoked when something matched;
# piping an empty list into "xargs kill" would run kill with no arguments.
pids=$(ps -ef | grep "movielens-" | grep -v grep | awk '{print $2}')
if [ -n "$pids" ]; then
  # Unquoted on purpose: one argument per PID.
  kill -9 $pids
fi

sleep 1
echo "result"
# Exclude the grep process itself so only real survivors are listed; "|| true"
# keeps the exit status zero when nothing is left.
ps -ef | grep "movielens-" | grep -v grep || true
12 changes: 12 additions & 0 deletions
12
demo/dynamic_embedding/movielens-100k-sync-estimator-with-horovod/train.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
#!/usr/bin/env bash
# Launch a local training run: one process with a "ps" task in the background
# plus two workers started through mpirun.

# Remove the artifacts of earlier runs and run stop.sh before launching.
rm -rf ./ckpt
rm -rf ./export_dir
sh stop.sh

# Worker output is redirected into ./log; create the directory so the
# redirection below does not fail when it is missing.
mkdir -p log

sleep 1
export TF_CONFIG='{"cluster": {"ps": ["localhost:2223"], "chief": ["localhost:2228"], "worker":["localhost:2240", "localhost:2241"]}, "task": {"type": "ps", "index": 0}}'
python movielens-100k-estimator.py --mode train &

# Each rank runs start_worker.sh and logs to its own file.
mpirun -np 2 -H localhost:2 --allow-run-as-root -bind-to none -map-by slot sh -c './start_worker.sh > log/worker_$OMPI_COMM_WORLD_RANK.log 2>&1'

echo "ok"
Oops, something went wrong.