[demo] add a distributed training demo on movielens for dynamic_embedding
rhdong committed May 6, 2021
1 parent cd07126 commit 7a9323d
Showing 4 changed files with 248 additions and 0 deletions.
18 changes: 18 additions & 0 deletions demo/movielens-1m-ps/README.md
@@ -0,0 +1,18 @@
# A distributed training demo for `tfra.dynamic_embedding`:

- dataset: [movielens/1m-ratings](https://www.tensorflow.org/datasets/catalog/movielens#movielens1m-ratings)
- model: DNN
- running mode: graph mode, using `MonitoredTrainingSession`

## Start training
By default, this script starts a training job with 2 parameter servers (PS) and 2 workers on the local machine.

```shell
sh start.sh
```
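
`start.sh` simply launches each task as a separate background process. The same flags can also be passed by hand; for example, a minimal sketch with a single PS and a single chief worker on the default local ports:

```shell
python movielens-1m-ps.py --ps_list="localhost:2220" --worker_list="localhost:2230" \
    --task_mode="ps" --task_id=0 &
python movielens-1m-ps.py --ps_list="localhost:2220" --worker_list="localhost:2230" \
    --task_mode="worker" --task_id=0 --is_chief=True &
```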

## Stop training

```shell
sh stop.sh
```
214 changes: 214 additions & 0 deletions demo/movielens-1m-ps/movielens-1m-ps.py
@@ -0,0 +1,214 @@
import os, sys
import tensorflow as tf
from tensorflow.keras.layers import Dense
import tensorflow_datasets as tfds
import tensorflow_recommenders_addons as tfra

tf.compat.v1.disable_v2_behavior()
tf.compat.v1.disable_eager_execution()
tf.compat.v1.disable_resource_variables()
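# The disable_* calls above force classic TF1 graph-mode behavior so the demo
# can build an explicit PS/worker graph and drive it with sessions and a
# MonitoredTrainingSession.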

flags = tf.compat.v1.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string(
    'ps_list', "localhost:2220, localhost:2221",
    'ps_list: a comma-separated string, '
    'e.g. "localhost:2220, localhost:2221"')
flags.DEFINE_string(
'worker_list', "localhost:2230",
    'worker_list: a comma-separated string, '
    'e.g. "localhost:2230, localhost:2231"')
flags.DEFINE_string('task_mode', "worker", 'task_mode: "ps" or "worker".')
flags.DEFINE_integer('task_id', 0,
                     'task_id: task index; workers use it to shard the dataset.')
flags.DEFINE_bool('is_chief', False,
                  'is_chief: if True, this task runs init_op and save/restore.')


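# Trainer owns the per-worker input pipeline and the rating-prediction DNN.
# Its dynamic embedding tables are sharded across the PS devices listed in
# `self.devices`.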
class Trainer():

def __init__(self, worker_id, worker_num, ps_num, batch_size, ckpt_dir=None):
self.embedding_size = 32
self.worker_id = worker_id
self.worker_num = worker_num
self.batch_size = batch_size
self.devices = [
"/job:ps/replica:0/task:{}".format(idx) for idx in range(ps_num)
]
self.ckpt_dir = ckpt_dir
if self.ckpt_dir:
os.makedirs(os.path.split(self.ckpt_dir)[0], exist_ok=True)

def read_batch(self):
split_size = int(100 / self.worker_num)
split_start = split_size * self.worker_id
    split = 'train[{}%:{}%]'.format(split_start, split_start + split_size)
print("dataset split, worker{}: {}".format(self.worker_id, split))
ratings = tfds.load("movielens/1m-ratings", split=split)
ratings = ratings.map(
lambda x: {
"movie_id": tf.strings.to_number(x["movie_id"], tf.int64),
"user_id": tf.strings.to_number(x["user_id"], tf.int64),
"user_rating": x["user_rating"]
})
shuffled = ratings.shuffle(1_000_000,
seed=2021,
reshuffle_each_iteration=False)
dataset_train = shuffled.batch(self.batch_size)
train_iter = tf.compat.v1.data.make_initializable_iterator(dataset_train)
return train_iter

def build_graph(self, batch):
movie_id = batch["movie_id"]
user_id = batch["user_id"]
rating = batch["user_rating"]

d0 = Dense(256,
activation='relu',
kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1),
bias_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1))
d1 = Dense(64,
activation='relu',
kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1),
bias_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1))
d2 = Dense(1,
kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1),
bias_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1))
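    # Dynamic embedding tables: hash-table backed variables that grow as new
    # user/movie ids are seen, sharded across the PS devices given above.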
user_embeddings = tfra.dynamic_embedding.get_variable(
name="user_dynamic_embeddings",
dim=self.embedding_size,
devices=self.devices,
initializer=tf.keras.initializers.RandomNormal(-1, 1))
movie_embeddings = tfra.dynamic_embedding.get_variable(
        name="movie_dynamic_embeddings",
dim=self.embedding_size,
devices=self.devices,
initializer=tf.keras.initializers.RandomNormal(-1, 1))

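    # Deduplicate the ids in the batch, look up (and lazily create) their
    # embeddings, then gather back to the original batch order.
    # `return_trainable=True` also returns a TrainableWrapper so the
    # embeddings touched in this step can be updated by the optimizer.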
user_id_val, user_id_idx = tf.unique(user_id)
user_id_weights, user_id_trainable_wrapper = tfra.dynamic_embedding.embedding_lookup(
params=user_embeddings,
ids=user_id_val,
name="user-id-weights",
return_trainable=True)
user_id_weights = tf.gather(user_id_weights, user_id_idx)

movie_id_val, movie_id_idx = tf.unique(movie_id)
movie_id_weights, movie_id_trainable_wrapper = tfra.dynamic_embedding.embedding_lookup(
params=movie_embeddings,
ids=movie_id_val,
name="movie-id-weights",
return_trainable=True)
movie_id_weights = tf.gather(movie_id_weights, movie_id_idx)

embeddings = tf.concat([user_id_weights, movie_id_weights], axis=1)
dnn = d0(embeddings)
dnn = d1(dnn)
dnn = d2(dnn)
predict = tf.reshape(dnn, shape=[-1])
loss = tf.keras.losses.MeanSquaredError()(rating, predict)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=0.001)
optimizer = tfra.dynamic_embedding.DynamicEmbeddingOptimizer(optimizer)
update = optimizer.minimize(
loss, global_step=tf.compat.v1.train.get_or_create_global_step())
return {
"update": update,
"predict": predict,
"loss": loss,
"size": user_embeddings.size(),
}


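# Each worker starts its own gRPC server, builds the model with a
# replica_device_setter (variables on the PS tasks, ops on this worker) and
# runs the training loop under a MonitoredTrainingSession; the chief performs
# variable initialization and, when a checkpoint_dir is set, checkpointing.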
def start_worker(worker_id, config):
print("worker config", config)
ps_list = config['cluster']['ps']
worker_list = config['cluster']['worker']

num_ps_tasks = len(ps_list)
num_worker_tasks = len(worker_list)
sess_config = tf.compat.v1.ConfigProto()
sess_config.intra_op_parallelism_threads = 1
sess_config.inter_op_parallelism_threads = 1
cluster = tf.train.ClusterSpec(config['cluster'])
server = tf.distribute.Server(cluster,
protocol="grpc",
job_name="worker",
task_index=worker_id,
config=sess_config)
with tf.compat.v1.device("/job:worker/replica:0/task:{}".format(worker_id)):
trainer = Trainer(worker_id=worker_id,
worker_num=num_worker_tasks,
ps_num=num_ps_tasks,
batch_size=64,
ckpt_dir=None)
train_iter = trainer.read_batch()
train_data = train_iter.get_next()

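  # Place variables round-robin on the PS tasks and keep compute on this
  # worker.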
device_setter = tf.compat.v1.train.replica_device_setter(
ps_tasks=num_ps_tasks,
worker_device="/job:worker/replica:0/task:{}".format(worker_id),
ps_device="/job:ps")

with tf.compat.v1.device(device_setter):
outputs = trainer.build_graph(train_data)

with tf.compat.v1.train.MonitoredTrainingSession(
master=server.target,
is_chief=FLAGS.is_chief,
checkpoint_dir=trainer.ckpt_dir if FLAGS.is_chief else None,
config=sess_config,
) as sess:
sess.run([train_iter.initializer])

step = 0
while True:
step += 1
try:
_, _loss, _pred = sess.run(
[outputs["update"], outputs["loss"], outputs["predict"]])

_size = sess.run(outputs["size"])
if step % 100 == 0:
print("[worker{}]step{}:\tloss={:.4f}\t size={}".format(
worker_id, step, float(_loss), _size))
except tf.errors.OutOfRangeError:
print("[worker{}]no more data!".format(worker_id))
break


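# A PS task only hosts variables: it starts its gRPC server and blocks
# forever in server.join().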
def start_ps(task_id, config):
print("ps config", config)
cluster = tf.train.ClusterSpec(config["cluster"])

sess_config = tf.compat.v1.ConfigProto()
sess_config.intra_op_parallelism_threads = 1
sess_config.inter_op_parallelism_threads = 1
server = tf.distribute.Server(cluster,
config=sess_config,
protocol='grpc',
job_name="ps",
task_index=task_id)
server.join()


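# Parse the comma-separated host lists into a cluster spec and run this
# process as a PS or a worker, depending on --task_mode.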
def main(argv):
ps_list = FLAGS.ps_list.replace(' ', '').split(',')
worker_list = FLAGS.worker_list.replace(' ', '').split(',')
task_mode = FLAGS.task_mode
task_id = FLAGS.task_id

print('ps_list: ', ps_list)
print('worker_list: ', worker_list)

if task_mode == 'ps':
ps_config = {"cluster": {"ps": ps_list, "worker": worker_list}}
start_ps(task_id, ps_config)
elif task_mode == 'worker':
worker_config = {"cluster": {"ps": ps_list, "worker": worker_list}}
start_worker(task_id, worker_config)
else:
print('invalid task_mode. Options include "ps" and "worker".')
sys.exit(1)


if __name__ == "__main__":
tf.compat.v1.app.run()
12 changes: 12 additions & 0 deletions demo/movielens-1m-ps/start.sh
@@ -0,0 +1,12 @@
#!/bin/bash
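# Launch a local cluster of 2 PS and 2 worker processes in the background;
# worker 0 is the chief and is started last.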
rm -rf ./ckpt
sh stop.sh
sleep 1
python movielens-1m-ps.py --ps_list="localhost:2220,localhost:2221" --worker_list="localhost:2230,localhost:2231" --task_mode="ps" --task_id=0 &
sleep 1
python movielens-1m-ps.py --ps_list="localhost:2220,localhost:2221" --worker_list="localhost:2230,localhost:2231" --task_mode="ps" --task_id=1 &
sleep 1
python movielens-1m-ps.py --ps_list="localhost:2220,localhost:2221" --worker_list="localhost:2230,localhost:2231" --task_mode="worker" --task_id=1 --is_chief=False &
sleep 1
python movielens-1m-ps.py --ps_list="localhost:2220,localhost:2221" --worker_list="localhost:2230,localhost:2231" --task_mode="worker" --task_id=0 --is_chief=True &
echo "ok"
4 changes: 4 additions & 0 deletions demo/movielens-1m-ps/stop.sh
@@ -0,0 +1,4 @@
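# Kill every running movielens demo process (PS and workers), then list what
# is still running.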
ps -ef|grep "movielens-"|grep -v grep|awk '{print $2}'| xargs kill -9
sleep 1
echo "result"
ps -ef|grep "movielens-"
