[feat] bpv2 (update de by adding delta instead of setting)
nrailg committed Aug 10, 2021
1 parent f7fbbe9 commit c8b84e4
Showing 10 changed files with 330 additions and 34 deletions.
6 changes: 6 additions & 0 deletions demo/bpv2/README.md
@@ -0,0 +1,6 @@
# A simple training demo for `bp_v2`

## Start training:
```
sh train.sh
```
117 changes: 117 additions & 0 deletions demo/bpv2/main.py
@@ -0,0 +1,117 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
import numpy as np
import tensorflow as tf
import tensorflow_recommenders_addons as tfra

batch_size = 128
vocab_size = 10000
embed_size = 64


def make_word_index():
  # IMDB word indices are shifted by 3 to make room for the special tokens below.
  word_index = tf.keras.datasets.imdb.get_word_index()
  word_index = {k: (v + 3) for k, v in word_index.items()}
  word_index["<PAD>"] = 0
  word_index["<START>"] = 1
  word_index["<UNK>"] = 2  # unknown
  word_index["<UNUSED>"] = 3
  reverse_word_index = dict([
      (value, key) for (key, value) in word_index.items()
  ])
  return word_index, reverse_word_index


def decode_review(text):
  return ' '.join([reverse_word_index.get(i, '?') for i in text])


word_index, reverse_word_index = make_word_index()


def get_data():
  (train_data, train_labels), (test_data, test_labels) = \
      tf.keras.datasets.imdb.load_data(num_words=10000, path='imdb-0')
  train_data = tf.keras.preprocessing.sequence.pad_sequences(
      train_data, value=word_index["<PAD>"], padding='post', maxlen=256)
  test_data = tf.keras.preprocessing.sequence.pad_sequences(
      test_data, value=word_index["<PAD>"], padding='post', maxlen=256)
  # Demo-sized subsets: the same 1024 examples are reused for train and eval.
  x_val = train_data[:1024]
  x_train = train_data[:1024]
  y_val = train_labels[:1024]
  y_train = train_labels[:1024]
  return x_train, y_train, x_val, y_val


x_train, y_train, x_val, y_val = get_data()


def input_fn_train():
  dataset = tf.data.Dataset.from_tensor_slices(({'x': x_train}, y_train))
  dataset = dataset.shuffle(1000).repeat().batch(batch_size)
  return dataset


def input_fn_val():
  dataset = tf.data.Dataset.from_tensor_slices(({'x': x_val}, y_val))
  dataset = dataset.shuffle(1000).repeat().batch(batch_size)
  return dataset


def model_fn(features, labels, mode):
  x = features['x']
  x = tf.reshape(x, [-1])
  uniqx, uniqxidx = tf.unique(x)
  w = tfra.dynamic_embedding.get_variable(
      name='w',
      initializer=tf.random_normal_initializer(0, 0.5),
      dim=embed_size,
      bp_v2=True,  # this is the only thing you need to do to enable bpv2
      key_dtype=tf.int32)
  uniqe = tfra.dynamic_embedding.embedding_lookup(params=w, ids=uniqx, name='a')
  e = tf.gather(uniqe, uniqxidx)
  e = tf.reshape(e, [-1, 256, embed_size])

  # Mean-pool the sequence embeddings, then classify with a small MLP.
  embmean = tf.reduce_mean(e, axis=1)
  fc1 = tf.layers.dense(embmean, 16, activation=tf.nn.relu)
  logits = tf.layers.dense(fc1, 2, activation=None)
  predictions = {
      "classes": tf.argmax(input=logits, axis=1),
      "probabilities": tf.nn.softmax(logits, name="softmax_tensor"),
  }

  y = tf.one_hot(tf.cast(labels, tf.int32), 2, 1, 0)
  loss = tf.losses.softmax_cross_entropy(y, logits)
  if mode == tf.estimator.ModeKeys.TRAIN:
    opt = tf.compat.v1.train.AdamOptimizer(0.01)
    opt = tfra.dynamic_embedding.DynamicEmbeddingOptimizer(opt)
    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v1.control_dependencies([
        tf.print('step', global_step, 'loss', loss),
    ]):
      train_op = opt.minimize(loss, global_step=global_step)
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
  else:
    eval_metric_ops = {
        "accuracy":
            tf.metrics.accuracy(labels=labels,
                                predictions=predictions["classes"])
    }
    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      eval_metric_ops=eval_metric_ops)


config = tf.estimator.RunConfig(save_checkpoints_steps=None,
                                save_checkpoints_secs=tf.int64.max,
                                model_dir=None,
                                log_step_count_steps=1)
classifier = tf.estimator.Estimator(model_fn=model_fn, config=config)

tf.estimator.train_and_evaluate(
    classifier,
    train_spec=tf.estimator.TrainSpec(input_fn=input_fn_train, max_steps=3000),
    eval_spec=tf.estimator.EvalSpec(input_fn=input_fn_val, steps=1000))
3 changes: 3 additions & 0 deletions demo/bpv2/train.sh
@@ -0,0 +1,3 @@
#!/usr/bin/env bash

python main.py
6 changes: 5 additions & 1 deletion docs/api_docs/tfra/dynamic_embedding/Variable.md
@@ -58,7 +58,8 @@ __init__(
trainable=True,
checkpoint=True,
init_size=0,
restrict_policy=None
restrict_policy=None,
bp_v2=False,
)
```

@@ -105,6 +106,9 @@ def default_partition_fn(keys, shard_num):
size of variable. If, in a training program, the variable is updated by an
optimizer, then the sparse slot variables in the optimizer are also
restricted.
* <b>`bp_v2`</b>: update parameters by *adding a delta* instead of *setting* the
whole value, which avoids race conditions among workers during backpropagation
in large-scale distributed asynchronous training (see the sketch below).
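
For intuition, here is a minimal sketch in plain TensorFlow; it illustrates the idea only and is not TFRA's internal implementation. Two workers compute updates from the same stale read: writing whole values back loses one of the updates, while applying each worker's delta keeps both.

```
import tensorflow as tf

w = tf.Variable([1.0])
g1, g2 = 0.2, 0.3            # updates computed concurrently from the same snapshot
stale = float(w.numpy()[0])  # both workers read the value 1.0 before either writes

# "setting": each worker writes back a full new value computed from its stale
# read; the last write wins and the other worker's contribution is lost.
set_style_result = stale - g2        # 0.7 -- g1's update is gone

# bp_v2 "adding delta": each worker applies only its own delta, so both survive.
w.assign_sub([g1])
w.assign_sub([g2])
bp_v2_result = float(w.numpy()[0])   # 0.5 -- both updates applied

print(set_style_result, bp_v2_result)
```

With `bp_v2=True`, the dynamic embedding table is updated with the optimizer's per-step delta rather than overwritten with a freshly computed value, so concurrent workers do not clobber each other's updates.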


#### Returns:
6 changes: 5 additions & 1 deletion docs/api_docs/tfra/dynamic_embedding/get_variable.md
@@ -37,7 +37,8 @@ tfra.dynamic_embedding.get_variable(
trainable=True,
checkpoint=True,
init_size=0,
restrict_policy=None
restrict_policy=None,
bp_v2=False,
)
```

@@ -80,6 +81,9 @@ def default_partition_fn(keys, shard_num):
size of variable. If, in a training program, the variable is updated by an
optimizer, then the sparse slot variables in the optimizer are also
restricted.
* <b>`bp_v2`</b>: update parameters by *adding a delta* instead of *setting* the
whole value, which avoids race conditions among workers during backpropagation
in large-scale distributed asynchronous training (a usage sketch follows below).
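
A minimal call sketch, mirroring the demo's `main.py` above (the variable name, dimension, and initializer are just the demo's choices):

```
import tensorflow as tf
import tensorflow_recommenders_addons as tfra

w = tfra.dynamic_embedding.get_variable(
    name='w',
    dim=64,
    key_dtype=tf.int32,
    initializer=tf.random_normal_initializer(0, 0.5),
    bp_v2=True,  # apply gradient updates as deltas instead of overwriting values
)
```

As in the demo, the training optimizer is wrapped with `tfra.dynamic_embedding.DynamicEmbeddingOptimizer` so that gradients flow into the dynamic embedding table.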


#### Returns: