Multi gpu #745

Open
wants to merge 2 commits into dev-static
7 changes: 5 additions & 2 deletions fluid/neural_machine_translation/transformer/config.py
@@ -1,10 +1,13 @@
class TrainTaskConfig(object):
use_gpu = False
use_gpu = True
# the epoch number to train.
pass_num = 2

# number of sequences contained in a mini-batch.
batch_size = 64
batch_size = 32

# number of gpu devices
num_gpus = 4

# the hyper params for Adam optimizer.
learning_rate = 0.001
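
Note on the new settings: with batch_size = 32 and num_gpus = 4, each GPU processes 8 sequences per step. Below is a minimal sketch, not part of this PR, that derives the per-device batch size and fails early if the global batch does not split evenly (it assumes config.py is importable as config, as model.py assumes).

# Minimal sketch, not part of this PR: derive the per-GPU batch size from
# TrainTaskConfig and fail early on an uneven split.
from config import TrainTaskConfig


def per_device_batch_size(cfg=TrainTaskConfig):
    assert cfg.batch_size % cfg.num_gpus == 0, (
        "batch_size (%d) must be divisible by num_gpus (%d)" %
        (cfg.batch_size, cfg.num_gpus))
    return cfg.batch_size // cfg.num_gpus  # 32 // 4 == 8 sequences per GPU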
169 changes: 114 additions & 55 deletions fluid/neural_machine_translation/transformer/model.py
@@ -1,3 +1,5 @@
import sys

from functools import partial
import numpy as np

@@ -8,6 +10,7 @@

# FIXME(guosheng): Remove out the batch_size from the model.
batch_size = TrainTaskConfig.batch_size
num_gpus = TrainTaskConfig.num_gpus


def position_encoding_init(n_position, d_pos_vec):
@@ -84,7 +87,8 @@ def __split_heads(x, n_head):
hidden_size = x.shape[-1]
# FIXME(guosheng): Decouple the program desc with batch_size.
reshaped = layers.reshape(
x=x, shape=[batch_size, -1, n_head, hidden_size // n_head])
x=x, shape=[batch_size / num_gpus, -1, n_head,
hidden_size // n_head])

# permuate the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
@@ -104,7 +108,8 @@ def __combine_heads(x):
return layers.reshape(
x=trans_x,
shape=map(int,
[batch_size, -1, trans_x.shape[2] * trans_x.shape[3]]))
[batch_size / num_gpus, -1,
trans_x.shape[2] * trans_x.shape[3]]))

def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
"""
@@ -230,7 +235,8 @@ def prepare_encoder(src_word,
enc_input = src_word_emb + src_pos_enc

# FIXME(guosheng): Decouple the program desc with batch_size.
enc_input = layers.reshape(x=enc_input, shape=[batch_size, -1, src_emb_dim])
enc_input = layers.reshape(x=enc_input,
shape=[batch_size / num_gpus, -1, src_emb_dim])
return layers.dropout(
enc_input, dropout_prob=dropout,
is_test=False) if dropout else enc_input
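
All three reshape changes above follow the same rule: once the model is built inside a ParallelDo block (see the transformer() diff below), each device only receives 1/num_gpus of the mini-batch along the leading dimension, so every shape that used to hard-code batch_size now hard-codes batch_size / num_gpus. These scripts target Python 2 (train.py uses xrange, and __combine_heads passes the result of map directly as a shape), so that division is integer division. A rough sketch of the shape bookkeeping, using illustrative hyper-parameter values that are assumptions rather than values taken from this PR:

# Rough sketch of the shape bookkeeping behind the reshape changes above.
# The head count and hidden size are illustrative assumptions.
batch_size = 32       # TrainTaskConfig.batch_size
num_gpus = 4          # TrainTaskConfig.num_gpus
n_head = 8            # assumed number of attention heads
hidden_size = 512     # assumed d_model

per_device_batch = batch_size // num_gpus                 # 8 sequences per GPU
split_heads_shape = [per_device_batch, -1, n_head,
                     hidden_size // n_head]                # [8, -1, 8, 64]
combine_heads_shape = [per_device_batch, -1,
                       n_head * (hidden_size // n_head)]   # [8, -1, 512]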
@@ -442,64 +448,15 @@ def transformer(
dtype="float32",
append_batch_size=False)

enc_input = prepare_encoder(
src_word,
src_pos,
src_vocab_size,
d_model,
src_pad_idx,
max_length,
dropout_rate, )
enc_output = encoder(
enc_input,
src_slf_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate, )

dec_input = prepare_decoder(
trg_word,
trg_pos,
trg_vocab_size,
d_model,
trg_pad_idx,
max_length,
dropout_rate, )
dec_output = decoder(
dec_input,
enc_output,
trg_slf_attn_bias,
trg_src_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate, )

# TODO(guosheng): Share the weight matrix between the embedding layers and
# the pre-softmax linear transformation.
predict = layers.reshape(
x=layers.fc(input=dec_output,
size=trg_vocab_size,
param_attr=fluid.initializer.Xavier(uniform=False),
bias_attr=False,
num_flatten_dims=2),
shape=[-1, trg_vocab_size],
act="softmax")
# The actual shape of gold in runtime is:
# [batch_size * max_trg_length_in_a_batch, 1].
gold = layers.data(
name=input_data_names[7],
shape=[batch_size * max_length, 1],
dtype="int64",
append_batch_size=False)
cost = layers.cross_entropy(input=predict, label=gold)

# The actual shape of weights in runtime is:
# [batch_size * max_trg_length_in_a_batch, 1].
# Padding index do not contribute to the total loss. This Weight is used to
@@ -509,5 +466,107 @@ def transformer(
shape=[batch_size * max_length, 1],
dtype="float32",
append_batch_size=False)
weighted_cost = cost * weights
return layers.reduce_sum(weighted_cost)

places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places, use_nccl=True)

src_word = fluid.layers.reshape(x=src_word,
shape=[batch_size, -1, 1])
src_pos = fluid.layers.reshape(x=src_pos,
shape=[batch_size, -1, 1])
trg_word = fluid.layers.reshape(x=trg_word,
shape=[batch_size, -1, 1])
trg_pos = fluid.layers.reshape(x=trg_pos,
shape=[batch_size, -1, 1])
gold = fluid.layers.reshape(x=gold,
shape=[batch_size, -1, 1])
weights = fluid.layers.reshape(x=weights,
shape=[batch_size, -1, 1])

with pd.do():
src_word = pd.read_input(src_word)
src_pos = pd.read_input(src_pos)
trg_word = pd.read_input(trg_word)
trg_pos = pd.read_input(trg_pos)
gold = pd.read_input(gold)
weights = pd.read_input(weights)
src_slf_attn_bias = pd.read_input(src_slf_attn_bias)
trg_slf_attn_bias = pd.read_input(trg_slf_attn_bias)
trg_src_attn_bias = pd.read_input(trg_src_attn_bias)

src_word = fluid.layers.reshape(
x=src_word, shape=[-1, 1])
src_word.stop_gradient = True
src_pos = fluid.layers.reshape(
x=src_pos, shape=[-1, 1])
src_pos.stop_gradient = True
trg_word = fluid.layers.reshape(
x=trg_word, shape=[-1, 1])
trg_word.stop_gradient = True
trg_pos = fluid.layers.reshape(
x=trg_pos, shape=[-1, 1])
trg_pos.stop_gradient = True
gold = fluid.layers.reshape(
x=gold, shape=[-1, 1])
gold.stop_gradient = True
weights = fluid.layers.reshape(
x=weights, shape=[-1, 1])
weights.stop_gradient = True

enc_input = prepare_encoder(
src_word,
src_pos,
src_vocab_size,
d_model,
src_pad_idx,
max_length,
dropout_rate, )
enc_output = encoder(
enc_input,
src_slf_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate, )

dec_input = prepare_decoder(
trg_word,
trg_pos,
trg_vocab_size,
d_model,
trg_pad_idx,
max_length,
dropout_rate, )
dec_output = decoder(
dec_input,
enc_output,
trg_slf_attn_bias,
trg_src_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate, )

# TODO(guosheng): Share the weight matrix between the embedding layers and
# the pre-softmax linear transformation.
predict = layers.reshape(
x=layers.fc(input=dec_output,
size=trg_vocab_size,
param_attr=fluid.initializer.Xavier(uniform=False),
bias_attr=False,
num_flatten_dims=2),
shape=[-1, trg_vocab_size],
act="softmax")
cost = layers.cross_entropy(input=predict, label=gold)

weighted_cost = cost * weights
cost = layers.reduce_sum(weighted_cost)
pd.write_output(cost)
cost = pd()
return fluid.layers.mean(x=cost)
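
The block above is the heart of the change, and it follows the standard ParallelDo recipe: get_places() enumerates the visible devices, every input is reshaped so that its leading dimension is the batch dimension (ParallelDo splits its inputs along dimension 0), pd.read_input() hands each device its sub-batch, pd.write_output() emits the per-device summed loss, and calling pd() gathers the outputs so they can be averaged with fluid.layers.mean. A stripped-down sketch of the same recipe on a toy classifier rather than the transformer, illustrative only and using the same old Fluid APIs that appear in the diff:

# Stripped-down sketch of the ParallelDo recipe used above, applied to a
# toy classifier instead of the transformer. Illustrative only.
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[784], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='int64')

places = fluid.layers.get_places()            # one place per visible device
pd = fluid.layers.ParallelDo(places, use_nccl=True)

with pd.do():
    # Each device receives a slice of the batch along dimension 0.
    x_ = pd.read_input(x)
    y_ = pd.read_input(y)
    predict = fluid.layers.fc(input=x_, size=10, act='softmax')
    loss = fluid.layers.cross_entropy(input=predict, label=y_)
    loss = fluid.layers.reduce_sum(loss)      # per-device summed loss
    pd.write_output(loss)

losses = pd()                                 # one entry per device
avg_loss = fluid.layers.mean(x=losses)        # what the optimizer minimizes

Two details of the transformer version are worth calling out. The inputs are flattened back to [-1, 1] inside pd.do() because the model itself expects the flattened [batch * seq_len, 1] layout; the earlier reshape to [batch_size, -1, 1] exists only so ParallelDo can split along the batch dimension. And the returned value is now the mean over devices of per-device summed losses, rather than the single reduce_sum over the whole batch that the old code returned.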
43 changes: 32 additions & 11 deletions fluid/neural_machine_translation/transformer/train.py
@@ -1,7 +1,10 @@
import numpy as np
import sys
import time

import paddle.v2 as paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler

from model import transformer, position_encoding_init
from optim import LearningRateScheduler
@@ -127,23 +130,41 @@ def main():
position_encoding_init(ModelHyperParams.max_length + 1,
ModelHyperParams.d_model), place)

def fn(pass_id, batch_id, data):
t1 = time.time()
data_input = prepare_batch_input(
data, input_data_names, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
ModelHyperParams.n_head, place)
lr_scheduler.update_learning_rate(data_input)
outs = exe.run(fluid.framework.default_main_program(),
feed=data_input,
fetch_list=[cost],
use_program_cache=True)
cost_val = np.array(outs[0])
print("pass_id = " + str(pass_id) + " batch = " + str(batch_id) +
" cost = " + str(cost_val))
return time.time() - t1

# with open('/tmp/program', 'w') as f:
# f.write('%s' % fluid.framework.default_main_program())
total_time = 0.0
count = 0
for pass_id in xrange(TrainTaskConfig.pass_num):
for batch_id, data in enumerate(train_data()):
# The current program desc is coupled with batch_size, thus all
# mini-batches must have the same number of instances currently.
if len(data) != TrainTaskConfig.batch_size:
continue
data_input = prepare_batch_input(
data, input_data_names, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
ModelHyperParams.n_head, place)
lr_scheduler.update_learning_rate(data_input)
outs = exe.run(fluid.framework.default_main_program(),
feed=data_input,
fetch_list=[cost])
cost_val = np.array(outs[0])
print("pass_id = " + str(pass_id) + " batch = " + str(batch_id) +
" cost = " + str(cost_val))
if pass_id == 0 and batch_id >= 10 and batch_id < 12:
with profiler.profiler('All', 'total', '/tmp/transformer'):
duration = fn(pass_id, batch_id, data)
else:
duration = fn(pass_id, batch_id, data)
count += 1
total_time += duration
print("avg: " + str(total_time / count) + " cur: " + str(duration))
sys.stdout.flush()


if __name__ == "__main__":
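
The train.py changes factor one training step into fn(), which runs the batch and returns its wall-clock duration, and wrap batches 10 and 11 of the first pass in the Fluid profiler. A generic sketch of that timing-plus-windowed-profiling pattern follows; run_batch, train_loop, and batches are hypothetical placeholders, not names from this PR.

# Generic sketch of the per-batch timing / windowed profiling pattern used in
# train.py above. run_batch, train_loop, and batches are hypothetical.
import time

import paddle.fluid.profiler as profiler


def run_batch(batch):
    start = time.time()
    # ... prepare the feed dict and run one iteration of the program ...
    return time.time() - start


def train_loop(batches):
    total_time, count = 0.0, 0
    for batch_id, batch in enumerate(batches):
        if 10 <= batch_id < 12:               # profile a short window only
            with profiler.profiler('All', 'total', '/tmp/transformer'):
                duration = run_batch(batch)
        else:
            duration = run_batch(batch)
        count += 1
        total_time += duration
        print("avg: " + str(total_time / count) + " cur: " + str(duration))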