From de683692ad6a679c790c65143fe54664223c4075 Mon Sep 17 00:00:00 2001
From: Xin Pan
Date: Mon, 19 Mar 2018 22:01:38 -0700
Subject: [PATCH 1/2] 2 device hack

---
 .../transformer/config.py |   4 +-
 .../transformer/model.py  | 166 ++++++++++++------
 .../transformer/train.py  |  41 +++--
 3 files changed, 143 insertions(+), 68 deletions(-)

diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py
index 091ea17529..737568ba35 100644
--- a/fluid/neural_machine_translation/transformer/config.py
+++ b/fluid/neural_machine_translation/transformer/config.py
@@ -1,10 +1,10 @@
 class TrainTaskConfig(object):
-    use_gpu = False
+    use_gpu = True
     # the epoch number to train.
     pass_num = 2

     # number of sequences contained in a mini-batch.
-    batch_size = 64
+    batch_size = 32

     # the hyper params for Adam optimizer.
     learning_rate = 0.001
diff --git a/fluid/neural_machine_translation/transformer/model.py b/fluid/neural_machine_translation/transformer/model.py
index 379a17221c..ef49a4a42e 100644
--- a/fluid/neural_machine_translation/transformer/model.py
+++ b/fluid/neural_machine_translation/transformer/model.py
@@ -1,3 +1,5 @@
+import sys
+
 from functools import partial
 import numpy as np

@@ -84,7 +86,7 @@ def __split_heads(x, n_head):
         hidden_size = x.shape[-1]
         # FIXME(guosheng): Decouple the program desc with batch_size.
         reshaped = layers.reshape(
-            x=x, shape=[batch_size, -1, n_head, hidden_size // n_head])
+            x=x, shape=[batch_size / 2, -1, n_head, hidden_size // n_head])

         # permuate the dimensions into:
         # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
@@ -104,7 +106,7 @@ def __combine_heads(x):
         return layers.reshape(
             x=trans_x,
             shape=map(int,
-                      [batch_size, -1, trans_x.shape[2] * trans_x.shape[3]]))
+                      [batch_size / 2, -1, trans_x.shape[2] * trans_x.shape[3]]))

     def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
         """
@@ -230,7 +232,8 @@ def prepare_encoder(src_word,
     enc_input = src_word_emb + src_pos_enc

     # FIXME(guosheng): Decouple the program desc with batch_size.
-    enc_input = layers.reshape(x=enc_input, shape=[batch_size, -1, src_emb_dim])
+    enc_input = layers.reshape(x=enc_input,
+                               shape=[batch_size / 2, -1, src_emb_dim])
     return layers.dropout(
         enc_input, dropout_prob=dropout,
         is_test=False) if dropout else enc_input
@@ -442,56 +445,7 @@ def transformer(
         dtype="float32",
         append_batch_size=False)

-    enc_input = prepare_encoder(
-        src_word,
-        src_pos,
-        src_vocab_size,
-        d_model,
-        src_pad_idx,
-        max_length,
-        dropout_rate, )
-    enc_output = encoder(
-        enc_input,
-        src_slf_attn_bias,
-        n_layer,
-        n_head,
-        d_key,
-        d_value,
-        d_model,
-        d_inner_hid,
-        dropout_rate, )
-
-    dec_input = prepare_decoder(
-        trg_word,
-        trg_pos,
-        trg_vocab_size,
-        d_model,
-        trg_pad_idx,
-        max_length,
-        dropout_rate, )
-    dec_output = decoder(
-        dec_input,
-        enc_output,
-        trg_slf_attn_bias,
-        trg_src_attn_bias,
-        n_layer,
-        n_head,
-        d_key,
-        d_value,
-        d_model,
-        d_inner_hid,
-        dropout_rate, )
-    # TODO(guosheng): Share the weight matrix between the embedding layers and
-    # the pre-softmax linear transformation.
-    predict = layers.reshape(
-        x=layers.fc(input=dec_output,
-                    size=trg_vocab_size,
-                    param_attr=fluid.initializer.Xavier(uniform=False),
-                    bias_attr=False,
-                    num_flatten_dims=2),
-        shape=[-1, trg_vocab_size],
-        act="softmax")

     # The actual shape of gold in runtime is:
     # [batch_size * max_trg_length_in_a_batch, 1].
 gold = layers.data(
@@ -499,7 +453,7 @@ def transformer(
         shape=[batch_size * max_length, 1],
         dtype="int64",
         append_batch_size=False)
-    cost = layers.cross_entropy(input=predict, label=gold)
+
     # The actual shape of weights in runtime is:
     # [batch_size * max_trg_length_in_a_batch, 1].
     # Padding index do not contribute to the total loss. This Weight is used to
@@ -509,5 +463,107 @@ def transformer(
         shape=[batch_size * max_length, 1],
         dtype="float32",
         append_batch_size=False)
-    weighted_cost = cost * weights
-    return layers.reduce_sum(weighted_cost)
+
+    places = fluid.layers.get_places()
+    pd = fluid.layers.ParallelDo(places, use_nccl=False)
+
+    src_word = fluid.layers.reshape(x=src_word,
+                                    shape=[batch_size, -1, 1])
+    src_pos = fluid.layers.reshape(x=src_pos,
+                                   shape=[batch_size, -1, 1])
+    trg_word = fluid.layers.reshape(x=trg_word,
+                                    shape=[batch_size, -1, 1])
+    trg_pos = fluid.layers.reshape(x=trg_pos,
+                                   shape=[batch_size, -1, 1])
+    gold = fluid.layers.reshape(x=gold,
+                                shape=[batch_size, -1, 1])
+    weights = fluid.layers.reshape(x=weights,
+                                   shape=[batch_size, -1, 1])
+
+    with pd.do():
+        src_word = pd.read_input(src_word)
+        src_pos = pd.read_input(src_pos)
+        trg_word = pd.read_input(trg_word)
+        trg_pos = pd.read_input(trg_pos)
+        gold = pd.read_input(gold)
+        weights = pd.read_input(weights)
+        src_slf_attn_bias = pd.read_input(src_slf_attn_bias)
+        trg_slf_attn_bias = pd.read_input(trg_slf_attn_bias)
+        trg_src_attn_bias = pd.read_input(trg_src_attn_bias)
+
+        src_word = fluid.layers.reshape(
+            x=src_word, shape=[-1, 1])
+        src_word.stop_gradient = True
+        src_pos = fluid.layers.reshape(
+            x=src_pos, shape=[-1, 1])
+        src_pos.stop_gradient = True
+        trg_word = fluid.layers.reshape(
+            x=trg_word, shape=[-1, 1])
+        trg_word.stop_gradient = True
+        trg_pos = fluid.layers.reshape(
+            x=trg_pos, shape=[-1, 1])
+        trg_pos.stop_gradient = True
+        gold = fluid.layers.reshape(
+            x=gold, shape=[-1, 1])
+        gold.stop_gradient = True
+        weights = fluid.layers.reshape(
+            x=weights, shape=[-1, 1])
+        weights.stop_gradient = True
+
+        enc_input = prepare_encoder(
+            src_word,
+            src_pos,
+            src_vocab_size,
+            d_model,
+            src_pad_idx,
+            max_length,
+            dropout_rate, )
+        enc_output = encoder(
+            enc_input,
+            src_slf_attn_bias,
+            n_layer,
+            n_head,
+            d_key,
+            d_value,
+            d_model,
+            d_inner_hid,
+            dropout_rate, )
+
+        dec_input = prepare_decoder(
+            trg_word,
+            trg_pos,
+            trg_vocab_size,
+            d_model,
+            trg_pad_idx,
+            max_length,
+            dropout_rate, )
+        dec_output = decoder(
+            dec_input,
+            enc_output,
+            trg_slf_attn_bias,
+            trg_src_attn_bias,
+            n_layer,
+            n_head,
+            d_key,
+            d_value,
+            d_model,
+            d_inner_hid,
+            dropout_rate, )
+
+        # TODO(guosheng): Share the weight matrix between the embedding layers and
+        # the pre-softmax linear transformation.
+        predict = layers.reshape(
+            x=layers.fc(input=dec_output,
+                        size=trg_vocab_size,
+                        param_attr=fluid.initializer.Xavier(uniform=False),
+                        bias_attr=False,
+                        num_flatten_dims=2),
+            shape=[-1, trg_vocab_size],
+            act="softmax")
+        cost = layers.cross_entropy(input=predict, label=gold)
+
+        weighted_cost = cost * weights
+        cost = layers.reduce_sum(weighted_cost)
+        pd.write_output(cost)
+    cost = pd()
+    return fluid.layers.mean(x=cost)
diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py
index 19835c486e..5ce85d94a9 100644
--- a/fluid/neural_machine_translation/transformer/train.py
+++ b/fluid/neural_machine_translation/transformer/train.py
@@ -1,7 +1,10 @@
 import numpy as np
+import sys
+import time

 import paddle.v2 as paddle
 import paddle.fluid as fluid
+import paddle.fluid.profiler as profiler

 from model import transformer, position_encoding_init
 from optim import LearningRateScheduler
@@ -127,23 +130,39 @@ def main():
             position_encoding_init(ModelHyperParams.max_length + 1,
                                    ModelHyperParams.d_model), place)

+    def fn(pass_id, batch_id, data):
+        t1 = time.time()
+        data_input = prepare_batch_input(
+            data, input_data_names, ModelHyperParams.src_pad_idx,
+            ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
+            ModelHyperParams.n_head, place)
+        lr_scheduler.update_learning_rate(data_input)
+        outs = exe.run(fluid.framework.default_main_program(),
+                       feed=data_input,
+                       fetch_list=[cost],
+                       use_program_cache=True)
+        cost_val = np.array(outs[0])
+        print("pass_id = " + str(pass_id) + " batch = " + str(batch_id) +
+              " cost = " + str(cost_val))
+        return time.time() - t1
+
+    total_time = 0.0
+    count = 0
     for pass_id in xrange(TrainTaskConfig.pass_num):
         for batch_id, data in enumerate(train_data()):
             # The current program desc is coupled with batch_size, thus all
             # mini-batches must have the same number of instances currently.
             if len(data) != TrainTaskConfig.batch_size:
                 continue
-            data_input = prepare_batch_input(
-                data, input_data_names, ModelHyperParams.src_pad_idx,
-                ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
-                ModelHyperParams.n_head, place)
-            lr_scheduler.update_learning_rate(data_input)
-            outs = exe.run(fluid.framework.default_main_program(),
-                           feed=data_input,
-                           fetch_list=[cost])
-            cost_val = np.array(outs[0])
-            print("pass_id = " + str(pass_id) + " batch = " + str(batch_id) +
-                  " cost = " + str(cost_val))
+            if pass_id == 0 and batch_id >= 10 and batch_id < 12:
+                with profiler.profiler('All', 'total', '/tmp/transformer'):
+                    duration = fn(pass_id, batch_id, data)
+            else:
+                duration = fn(pass_id, batch_id, data)
+            count += 1
+            total_time += duration
+            print("avg: " + str(total_time / count) + " cur: " + str(duration))
+            sys.stdout.flush()


 if __name__ == "__main__":
From 88974072224cfb3bca4e6074646f08eefbf06438 Mon Sep 17 00:00:00 2001
From: Xin Pan
Date: Tue, 20 Mar 2018 18:19:54 -0700
Subject: [PATCH 2/2] Better usage for multi-gpu

Users must set num_gpus in config.py to the number of GPUs available.
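Editor's note, not part of the commit message or the diff below: the snippet is an
illustrative sketch of the config edit this patch asks for, with num_gpus = 4 copied
from the diff that follows. The divisibility remark is an assumption inferred from the
batch_size / num_gpus reshapes in model.py; nothing in the code enforces it.

    # fluid/neural_machine_translation/transformer/config.py (sketch)
    class TrainTaskConfig(object):
        use_gpu = True
        # number of sequences contained in a mini-batch.
        # Keep batch_size divisible by num_gpus: each device's shard is
        # reshaped with batch_size / num_gpus inside model.py.
        batch_size = 32
        # number of gpu devices
        num_gpus = 4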
---
 .../neural_machine_translation/transformer/config.py  |  3 +++
 fluid/neural_machine_translation/transformer/model.py | 11 +++++++----
 fluid/neural_machine_translation/transformer/train.py |  2 ++
 3 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py
index 737568ba35..ecd436e03b 100644
--- a/fluid/neural_machine_translation/transformer/config.py
+++ b/fluid/neural_machine_translation/transformer/config.py
@@ -6,6 +6,9 @@ class TrainTaskConfig(object):
     # number of sequences contained in a mini-batch.
     batch_size = 32

+    # number of gpu devices
+    num_gpus = 4
+
     # the hyper params for Adam optimizer.
     learning_rate = 0.001
     beta1 = 0.9
diff --git a/fluid/neural_machine_translation/transformer/model.py b/fluid/neural_machine_translation/transformer/model.py
index ef49a4a42e..288cd58fee 100644
--- a/fluid/neural_machine_translation/transformer/model.py
+++ b/fluid/neural_machine_translation/transformer/model.py
@@ -10,6 +10,7 @@

 # FIXME(guosheng): Remove out the batch_size from the model.
 batch_size = TrainTaskConfig.batch_size
+num_gpus = TrainTaskConfig.num_gpus


 def position_encoding_init(n_position, d_pos_vec):
@@ -86,7 +87,8 @@ def __split_heads(x, n_head):
         hidden_size = x.shape[-1]
         # FIXME(guosheng): Decouple the program desc with batch_size.
         reshaped = layers.reshape(
-            x=x, shape=[batch_size / 2, -1, n_head, hidden_size // n_head])
+            x=x, shape=[batch_size / num_gpus, -1, n_head,
+                        hidden_size // n_head])

         # permuate the dimensions into:
         # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
@@ -106,7 +108,8 @@ def __combine_heads(x):
         return layers.reshape(
             x=trans_x,
             shape=map(int,
-                      [batch_size / 2, -1, trans_x.shape[2] * trans_x.shape[3]]))
+                      [batch_size / num_gpus, -1,
+                       trans_x.shape[2] * trans_x.shape[3]]))

     def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
         """
@@ -233,7 +236,7 @@ def prepare_encoder(src_word,

     # FIXME(guosheng): Decouple the program desc with batch_size.
     enc_input = layers.reshape(x=enc_input,
-                               shape=[batch_size / 2, -1, src_emb_dim])
+                               shape=[batch_size / num_gpus, -1, src_emb_dim])
     return layers.dropout(
         enc_input, dropout_prob=dropout,
         is_test=False) if dropout else enc_input
@@ -465,7 +468,7 @@ def transformer(
         append_batch_size=False)

     places = fluid.layers.get_places()
-    pd = fluid.layers.ParallelDo(places, use_nccl=False)
+    pd = fluid.layers.ParallelDo(places, use_nccl=True)

     src_word = fluid.layers.reshape(x=src_word,
                                     shape=[batch_size, -1, 1])
diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py
index 5ce85d94a9..d4ddd78b77 100644
--- a/fluid/neural_machine_translation/transformer/train.py
+++ b/fluid/neural_machine_translation/transformer/train.py
@@ -146,6 +146,8 @@ def fn(pass_id, batch_id, data):
               " cost = " + str(cost_val))
         return time.time() - t1

+    # with open('/tmp/program', 'w') as f:
+    #     f.write('%s' % fluid.framework.default_main_program())
     total_time = 0.0
     count = 0
     for pass_id in xrange(TrainTaskConfig.pass_num):
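Editor's note, not part of either commit: both patches hang off the same ParallelDo
data-parallel pattern, so a stripped-down sketch may help reviewers see the shape of
the change without the transformer-specific code. The names feature, label and
build_cost below are placeholders invented for the sketch, not functions from this
repo; in the actual diff the per-device sub-graph is the prepare_encoder/encoder/
prepare_decoder/decoder stack and the per-device output is the weighted
cross-entropy sum.

    import paddle.fluid as fluid

    def parallel_cost(feature, label, build_cost):
        # Replicate the enclosed sub-graph over every visible device.
        places = fluid.layers.get_places()
        pd = fluid.layers.ParallelDo(places, use_nccl=True)
        with pd.do():
            # read_input hands each device its slice of the fed tensors,
            # split along the leading batch dimension (which is why the
            # patch reshapes every input to [batch_size, -1, 1] first).
            f = pd.read_input(feature)
            l = pd.read_input(label)
            # Per-device forward pass and loss.
            cost = build_cost(f, l)
            # Each device contributes its partial cost.
            pd.write_output(cost)
        # pd() gathers the per-device costs; averaging them yields the
        # scalar the optimizer sees, as in the patched transformer().
        return fluid.layers.mean(x=pd())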