Add some dist-training robust cases into fluid benchmark test #11207

Merged
merged 16 commits on Jun 11, 2018
Changes from 5 commits
6 changes: 5 additions & 1 deletion benchmark/fluid/README.md
@@ -24,10 +24,14 @@ Currently supported `--model` arguments include:

* Run the following command to start a benchmark job locally:
```bash
python fluid_benchmark.py --model mnist --device GPU
```
You can choose to use GPU/CPU training. With GPU training, you can specify
`--gpus <gpu_num>` to run multi-GPU training.
You can enable gradient clipping. With gradient clipping, you can specify
`--gradient_clip_method GlobalNorm` to clip gradients by their global norm.
You can add a regularizer to the optimizer. With regularization, you can specify
`--weight_decay_regularizer_method L1` to apply L1 weight decay to the optimizer.
* Run distributed training with parameter servers:
* start parameter servers:
```bash
69 changes: 69 additions & 0 deletions benchmark/fluid/fluid_benchmark.py
@@ -39,8 +39,67 @@ def parse_args():
help='The model to run benchmark with.')
parser.add_argument(
'--batch_size', type=int, default=32, help='The minibatch size.')
# args related to learning rate
parser.add_argument(
'--learning_rate', type=float, default=0.001, help='The learning rate.')
parser.add_argument(
'--learning_rate_decay_method',
type=str,
default=None,
choices=['exponential', 'natural_exp', 'inverse_time'],
help='Learning rate decay method, can be exponential, natural_exp, inverse_time'
)
parser.add_argument(
'--learning_rate_decay_steps',
type=int,
default=100000,
help='Decay steps for learning rate decay method')
parser.add_argument(
'--learning_rate_decay_rate',
type=float,
default=0.5,
help='Decay rate for learning rate decay method')
# args related to regularization
parser.add_argument(
'--weight_decay_regularizer_method',
type=str,
default=None,
choices=['L1', 'L2'],
help='Weight decay regularizer method, can be L1, L2')
parser.add_argument(
'--weight_decay_regularizer_coeff',
type=float,
default=0.1,
help='Weight decay regularizer coefficient, 0.1 by default')
# args related to gradient clipping
parser.add_argument(
'--gradient_clip_method',
type=str,
default=None,
choices=['Norm', 'GlobalNorm'],
help='Gradient clipping method, can be Norm, GlobalNorm')
parser.add_argument(
'--gradient_clip_norm',
type=float,
default=1.,
help='Gradient clipping norm, 1.0 by default')
# args related to error clipping
parser.add_argument(
'--error_clip_method',
type=str,
default=None,
choices=['Value'],
help='Error clipping method, can be Value')
parser.add_argument(
'--error_clip_min',
type=float,
default=1e-6,
help='Error clipping min value, 1e-6 by default')
parser.add_argument(
'--error_clip_max',
type=float,
default=2e-6,
help='Error clipping max value, 2e-6 by default')
# TODO(wuyi): add "--use_fake_data" option back.
parser.add_argument(
'--skip_batch_num',
@@ -108,6 +167,16 @@ def parse_args():
default='local',
choices=['local', 'pserver', 'nccl2'],
help='Choose parameter update method, can be local, pserver, nccl2.')
parser.add_argument(
'--no_split_var',
action='store_true',
default=False,
help='Do not split variables into blocks when update_method is pserver')
parser.add_argument(
'--async_mode',
action='store_true',
default=False,
help='Whether to start the pserver in async mode to support ASGD')
args = parser.parse_args()
return args

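For context on the decay flags above: `--learning_rate_decay_steps` and `--learning_rate_decay_rate` define a staircase schedule that the benchmark builds through `fluid.layers.<method>_decay`. A minimal plain-Python sketch of the arithmetic for the exponential case (illustration only, not part of the diff):

```python
# Staircase exponential decay implied by the new flags (illustration only).
# The benchmark itself delegates to fluid.layers.exponential_decay; this
# snippet just evaluates the same formula for a few global steps.


def exponential_decay(step, learning_rate=0.001, decay_steps=100000,
                      decay_rate=0.5, staircase=True):
    exponent = step / float(decay_steps)
    if staircase:
        exponent = int(exponent)  # drop the fractional part: decay in discrete jumps
    return learning_rate * decay_rate ** exponent


for step in (0, 100000, 200000, 300000):
    print(step, exponential_decay(step))
# prints 0.001, 0.0005, 0.00025, 0.000125
```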
24 changes: 21 additions & 3 deletions benchmark/fluid/models/machine_translation.py
@@ -26,6 +26,10 @@
import paddle.fluid.core as core
import paddle.fluid.framework as framework
from paddle.fluid.executor import Executor
from models.model_base import get_decay_learning_rate
Contributor:

model_base is not uploaded?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review. I added the benchmark/fluid/models/model_base.py file in the next commit.

from models.model_base import get_regularization
from models.model_base import set_error_clip
from models.model_base import set_gradient_clip


def lstm_step(x_t, hidden_t_prev, cell_t_prev, size):
@@ -50,7 +54,7 @@ def linear(inputs):


def seq_to_seq_net(embedding_dim, encoder_size, decoder_size, source_dict_dim,
target_dict_dim, is_generating, beam_size, max_length):
target_dict_dim, is_generating, beam_size, max_length, args):
"""Construct a seq2seq network."""

def bi_lstm_encoder(input_seq, gate_size):
@@ -99,6 +103,8 @@ def bi_lstm_encoder(input_seq, gate_size):
size=decoder_size,
bias_attr=False,
act='tanh')
set_error_clip(args.error_clip_method, encoded_proj.name,
args.error_clip_min, args.error_clip_max)

def lstm_decoder_with_attention(target_embedding, encoder_vec, encoder_proj,
decoder_boot, decoder_size):
@@ -211,12 +217,24 @@ def get_model(args):
dict_size,
False,
beam_size=beam_size,
max_length=max_length)
max_length=max_length,
args=args)

# clone from default main program
inference_program = fluid.default_main_program().clone()

optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate)
# set gradient clip
set_gradient_clip(args.gradient_clip_method, args.gradient_clip_norm)
Contributor:

Is there a way that we can disable these settings if the args are empty?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If clip_method in args is None, these settings are disabled; if the user does NOT specify --gradient_clip_method, the arg defaults to None.

The code is like below:

def set_gradient_clip(clip_method, clip_norm=1.):
    if not clip_method:
        return None
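
In other words, each model_base helper guards on its method argument, so leaving a flag at its argparse default of None keeps the baseline benchmark behaviour. A small usage sketch of that guard (assuming the benchmark/fluid directory is the working directory so `models.model_base` resolves; not part of the PR thread):

```python
from models.model_base import set_gradient_clip

# Default case: --gradient_clip_method was not passed, so argparse leaves the
# value as None and the helper returns immediately, registering nothing.
assert set_gradient_clip(None) is None

# Explicit case: once the model's parameters exist in the default main program,
# register global-norm clipping; the clipping ops are added when the optimizer
# is applied.
set_gradient_clip("GlobalNorm", clip_norm=1.0)
```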


optimizer = fluid.optimizer.Adam(
learning_rate=get_decay_learning_rate(
decay_method=args.learning_rate_decay_method,
learning_rate=args.learning_rate,
decay_steps=args.learning_rate_decay_steps,
decay_rate=args.learning_rate_decay_rate),
regularization=get_regularization(
regularizer_method=args.weight_decay_regularizer_method,
regularizer_coeff=args.weight_decay_regularizer_coeff))

train_batch_generator = paddle.batch(
paddle.reader.shuffle(
27 changes: 23 additions & 4 deletions benchmark/fluid/models/mnist.py
@@ -24,6 +24,10 @@
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
from models.model_base import get_decay_learning_rate
from models.model_base import get_regularization
from models.model_base import set_error_clip
from models.model_base import set_gradient_clip

SEED = 1
DTYPE = "float32"
@@ -32,7 +36,7 @@
# fluid.default_startup_program().random_seed = SEED


def cnn_model(data):
def cnn_model(data, args):
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=data,
filter_size=5,
@@ -48,6 +52,9 @@ def cnn_model(data):
pool_stride=2,
act="relu")

set_error_clip(args.error_clip_method, conv_pool_1.name,
args.error_clip_min, args.error_clip_max)

# TODO(dzhwinter) : refine the initializer and random seed setting
SIZE = 10
input_shape = conv_pool_2.shape
@@ -73,7 +80,7 @@ def get_model(args):
places = fluid.layers.get_places(args.cpus)
pd = fluid.layers.ParallelDo(places)
with pd.do():
predict = cnn_model(pd.read_input(images))
predict = cnn_model(pd.read_input(images), args)
label = pd.read_input(label)
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
@@ -87,7 +94,7 @@ def get_model(args):
batch_acc = fluid.layers.mean(batch_acc)
else:
# Train program
predict = cnn_model(images)
predict = cnn_model(images, args)
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)

@@ -97,9 +104,21 @@ def get_model(args):
# inference program
inference_program = fluid.default_main_program().clone()

# set gradient clip
# set_gradient_clip(args.gradient_clip_method, args.gradient_clip_norm)

# Optimization
opt = fluid.optimizer.AdamOptimizer(
learning_rate=0.001, beta1=0.9, beta2=0.999)
learning_rate=get_decay_learning_rate(
decay_method=args.learning_rate_decay_method,
learning_rate=0.001,
decay_steps=args.learning_rate_decay_steps,
decay_rate=args.learning_rate_decay_rate),
regularization=get_regularization(
regularizer_method=args.weight_decay_regularizer_method,
regularizer_coeff=args.weight_decay_regularizer_coeff),
beta1=0.9,
beta2=0.999)

# Reader
train_reader = paddle.batch(
86 changes: 86 additions & 0 deletions benchmark/fluid/models/model_base.py
@@ -0,0 +1,86 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import argparse

import paddle.fluid as fluid
from paddle.fluid.regularizer import L1DecayRegularizer
from paddle.fluid.regularizer import L2DecayRegularizer
from paddle.fluid.clip import GradientClipByNorm
from paddle.fluid.clip import GradientClipByGlobalNorm
from paddle.fluid.clip import ErrorClipByValue

__all__ = [
    'get_decay_learning_rate',
    'get_regularization',
    'set_error_clip',
    'set_gradient_clip',
]


def get_decay_learning_rate(decay_method,
                            learning_rate=0.001,
                            decay_steps=100000,
                            decay_rate=0.5,
                            staircase=True):
    """Return a decayed learning rate, or the constant rate if no decay
    method is specified."""
    if not decay_method:
        return learning_rate
    else:
        decay_op = getattr(fluid.layers, "%s_decay" % decay_method)
        return decay_op(
            learning_rate=learning_rate,
            decay_steps=decay_steps,
            decay_rate=decay_rate,
            staircase=staircase)


def get_regularization(regularizer_method, regularizer_coeff=0.1):
    """Return an L1/L2 weight-decay regularizer, or None if no method is
    specified."""
    if not regularizer_method:
        return None
    else:
        RegularizerClazz = globals()["%sDecayRegularizer" % regularizer_method]
        regularizer = RegularizerClazz(regularization_coeff=regularizer_coeff)
        return regularizer


def set_error_clip(clip_method,
                   layer_name,
                   clip_min=-1e-6,
                   clip_max=2e-6,
                   program=None):
    """Attach an error-clip attribute to the named variable; no-op if no
    method is specified."""
    assert clip_min < clip_max
    if not clip_method:
        return None
    else:
        ClipClazz = globals()["ErrorClipBy%s" % clip_method]
        if not program:
            prog = fluid.default_main_program()
        else:
            prog = program
        prog.block(0).var(layer_name).set_error_clip(
            ClipClazz(
                max=clip_max, min=clip_min))


def set_gradient_clip(clip_method, clip_norm=1.):
    """Register gradient clipping by (global) norm; no-op if no method is
    specified."""
    if not clip_method:
        return None
    else:
        ClipClazz = globals()["GradientClipBy%s" % clip_method]
        fluid.clip.set_gradient_clip(ClipClazz(clip_norm=clip_norm))
        return clip_method
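
Putting the new helpers in context: they compose inside each model's `get_model()` the same way the mnist and machine_translation hunks show. A condensed sketch (`build_model_optimizer` is a hypothetical wrapper, not part of the PR, and it assumes the named layer already exists in the default main program):

```python
import paddle.fluid as fluid

from models.model_base import get_decay_learning_rate
from models.model_base import get_regularization
from models.model_base import set_error_clip
from models.model_base import set_gradient_clip


def build_model_optimizer(args, clipped_layer_name):
    # Optionally clip the backpropagated error of one layer's output variable.
    set_error_clip(args.error_clip_method, clipped_layer_name,
                   args.error_clip_min, args.error_clip_max)
    # Optionally register gradient clipping for the program's parameters.
    set_gradient_clip(args.gradient_clip_method, args.gradient_clip_norm)
    # Both helpers below fall back to a plain constant / None when the
    # corresponding flags are omitted, so the baseline run is unchanged.
    return fluid.optimizer.Adam(
        learning_rate=get_decay_learning_rate(
            decay_method=args.learning_rate_decay_method,
            learning_rate=args.learning_rate,
            decay_steps=args.learning_rate_decay_steps,
            decay_rate=args.learning_rate_decay_rate),
        regularization=get_regularization(
            regularizer_method=args.weight_decay_regularizer_method,
            regularizer_coeff=args.weight_decay_regularizer_coeff))
```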
30 changes: 25 additions & 5 deletions benchmark/fluid/models/resnet.py
@@ -26,6 +26,10 @@
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.profiler as profiler
from models.model_base import get_decay_learning_rate
from models.model_base import get_regularization
from models.model_base import set_error_clip
from models.model_base import set_gradient_clip


def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu'):
@@ -70,7 +74,7 @@ def layer_warp(block_func, input, ch_out, count, stride):
return res_out


def resnet_imagenet(input, class_dim, depth=50, data_format='NCHW'):
def resnet_imagenet(input, class_dim, args, depth=50, data_format='NCHW'):

cfg = {
18: ([2, 2, 2, 1], basicblock),
@@ -94,10 +98,12 @@ def resnet_imagenet(input, class_dim, depth=50, data_format='NCHW'):
pool_stride=1,
global_pooling=True)
out = fluid.layers.fc(input=pool2, size=class_dim, act='softmax')
set_error_clip(args.error_clip_method, out.name, args.error_clip_min,
args.error_clip_max)
return out


def resnet_cifar10(input, class_dim, depth=32, data_format='NCHW'):
def resnet_cifar10(input, class_dim, args, depth=32, data_format='NCHW'):
assert (depth - 2) % 6 == 0

n = (depth - 2) // 6
@@ -110,6 +116,8 @@ def resnet_cifar10(input, class_dim, depth=32, data_format='NCHW'):
pool = fluid.layers.pool2d(
input=res3, pool_size=8, pool_type='avg', pool_stride=1)
out = fluid.layers.fc(input=pool, size=class_dim, act='softmax')
set_error_clip(args.error_clip_method, out.name, args.error_clip_min,
args.error_clip_max)
return out


@@ -137,7 +145,7 @@ def get_model(args):
places = fluid.layers.get_places(args.cpus)
pd = fluid.layers.ParallelDo(places)
with pd.do():
predict = model(pd.read_input(input), class_dim)
predict = model(pd.read_input(input), class_dim, args=args)
label = pd.read_input(label)
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
@@ -150,7 +158,7 @@ def get_model(args):
avg_cost = fluid.layers.mean(avg_cost)
batch_acc = fluid.layers.mean(batch_acc)
else:
predict = model(input, class_dim)
predict = model(input, class_dim, args=args)
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
batch_acc = fluid.layers.accuracy(input=predict, label=label)
@@ -160,7 +168,19 @@ def get_model(args):
inference_program = fluid.io.get_inference_program(
target_vars=[batch_acc])

optimizer = fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
# set gradient clip
set_gradient_clip(args.gradient_clip_method, args.gradient_clip_norm)

optimizer = fluid.optimizer.Momentum(
learning_rate=get_decay_learning_rate(
decay_method=args.learning_rate_decay_method,
learning_rate=0.01,
decay_steps=args.learning_rate_decay_steps,
decay_rate=args.learning_rate_decay_rate),
regularization=get_regularization(
regularizer_method=args.weight_decay_regularizer_method,
regularizer_coeff=args.weight_decay_regularizer_coeff),
momentum=0.9)

train_reader = paddle.batch(
paddle.reader.shuffle(