Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SyncBatchNorm #26032

Merged
merged 12 commits into from
Aug 19, 2020
4 changes: 4 additions & 0 deletions paddle/fluid/pybind/op_function_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"batch_norm",
{"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance",
"ReserveSpace"}},
{"sync_batch_norm",
{"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance",
"ReserveSpace"}},
};

// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
Expand All @@ -75,6 +78,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
{"momentum", {"ParamOut", "VelocityOut"}},
{"batch_norm", {"MeanOut", "VarianceOut"}},
{"sync_batch_norm", {"MeanOut", "VarianceOut"}},
{"accuracy", {"Correct", "Total"}},
{"fill_constant", {"Out"}},
{"matmul", {"Out"}},
Expand Down
220 changes: 215 additions & 5 deletions python/paddle/fluid/dygraph/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
'Conv2D', 'Conv3D', 'Pool2D', 'Linear', 'BatchNorm', 'Dropout', 'Embedding',
'GRUUnit', 'InstanceNorm', 'LayerNorm', 'NCE', 'PRelu',
'BilinearTensorProduct', 'Conv2DTranspose', 'Conv3DTranspose', 'GroupNorm',
'SpectralNorm', 'TreeConv', 'Flatten'
'SpectralNorm', 'TreeConv', 'Flatten', 'SyncBatchNorm'
]


Expand Down Expand Up @@ -3184,6 +3184,220 @@ def forward(self, nodes_vector, edge_set):
return self._helper.append_activation(pre_activation, act=self._act)


class SyncBatchNorm(layers.Layer):
"""
:alias_main: paddle.nn.SyncBatchNorm
:alias: paddle.nn.SyncBatchNorm,paddle.nn.layer.SyncBatchNorm,paddle.nn.layer.norm.SyncBatchNorm
:old_api: paddle.fluid.dygraph.SyncBatchNorm
ceci3 marked this conversation as resolved.
Show resolved Hide resolved

This interface is used to construct a callable object of the ``SyncBatchNorm`` class.
For more details, refer to code examples.
ceci3 marked this conversation as resolved.
Show resolved Hide resolved
It implements the function of the Batch Normalization Layer and can be used
ceci3 marked this conversation as resolved.
Show resolved Hide resolved
as a normalizer function for conv2d and fully connected operations.
The data is normalized by the mean and variance of the channel based on all mini-batches
of the same process groups.
ceci3 marked this conversation as resolved.
Show resolved Hide resolved
Refer to `Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift <https://arxiv.org/pdf/1502.03167.pdf>`_
for more details.

When model in train mode, the :math:`\\mu_{\\beta}`
and :math:`\\sigma_{\\beta}^{2}` are the statistics of all mini-batches in the same process groups.
Calculated as follows:

.. math::

\\mu_{\\beta} &\\gets \\frac{1}{m} \\sum_{i=1}^{m} x_i \\qquad &//\\
\ mini-batch\ mean \\\\
\\sigma_{\\beta}^{2} &\\gets \\frac{1}{m} \\sum_{i=1}^{m}(x_i - \\
\\mu_{\\beta})^2 \\qquad &//\ mini-batch\ variance \\\\

- :math:`x` : mini-batch data
- :math:`m` : the size of the mini-batch data

When model in eval mode, the :math:`\\mu_{\\beta}`
and :math:`\\sigma_{\\beta}^{2}` are global or running statistics (moving_mean and moving_variance).
ceci3 marked this conversation as resolved.
Show resolved Hide resolved
It usually got from the pre-trained model. Calculated as follows:
ceci3 marked this conversation as resolved.
Show resolved Hide resolved

.. math::
moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global mean \\
moving\_variance = moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \quad &// global variance \\

The normalization function formula is as follows:
ceci3 marked this conversation as resolved.
Show resolved Hide resolved

.. math::

\\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\
\\sigma_{\\beta}^{2} + \\eps}} \\qquad &//\ normalize \\\\
y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift

- :math:`\\eps` : add a smaller value to the variance to prevent division by zero
- :math:`\\gamma` : trainable proportional parameter
ceci3 marked this conversation as resolved.
Show resolved Hide resolved
- :math:`\\beta` : trainable deviation parameter

**Note**:
moving mean and moving variance will be calculate whether `track_running_stats` is set to `True`
ceci3 marked this conversation as resolved.
Show resolved Hide resolved
or `False`, we will fix it in the next time.

Parameters:
num_features(int): Indicate the number of channels of the input ``Tensor``.
eps(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
ceci3 marked this conversation as resolved.
Show resolved Hide resolved
momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
scale_attr(ParamAttr, optional): The parameter attribute for Parameter `scale`
ceci3 marked this conversation as resolved.
Show resolved Hide resolved
of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr(ParamAttr, optional): The parameter attribute for the bias of batch_norm.
If it is set to None or one attribute of ParamAttr, batch_norm
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
track_running_stats(bool, optional): Whether to compute global stats, which including running mean and
running variance. Default: True.

Returns:
None

Examples:
.. code-block:: python

import paddle.nn as nn
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable
ceci3 marked this conversation as resolved.
Show resolved Hide resolved
ceci3 marked this conversation as resolved.
Show resolved Hide resolved
import numpy as np

x = np.random.random(size=(3, 10, 3, 7)).astype('float32')
with fluid.dygraph.guard():
ceci3 marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

示例可以直接用paddle2.0-alpha的api
with paddle.imperative.guard():

x = to_variable(x)
sync_batch_norm = nn.SyncBatchNorm(10)
hidden1 = sync_batch_norm(x)
willthefrog marked this conversation as resolved.
Show resolved Hide resolved
"""
ceci3 marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self,
num_features,
eps=1e-05,
momentum=0.9,
track_running_stats=True,
scale_attr=None,
bias_attr=None,
name=None):
super(SyncBatchNorm, self).__init__()
ceci3 marked this conversation as resolved.
Show resolved Hide resolved
self._scale_attr = scale_attr
self._bias_attr = bias_attr

assert bias_attr is not False, "bias_attr should not be False in batch_norm."
ceci3 marked this conversation as resolved.
Show resolved Hide resolved

self._dtype = "float32"
ceci3 marked this conversation as resolved.
Show resolved Hide resolved

param_shape = [num_features]

# create parameter
self.weight = self.create_parameter(
attr=self._scale_attr,
shape=param_shape,
dtype=self._dtype,
default_initializer=Constant(1.0))
self.weight.stop_gradient = self._scale_attr != None and self._scale_attr.learning_rate == 0.

self.bias = self.create_parameter(
attr=self._bias_attr,
shape=param_shape,
dtype=self._dtype,
is_bias=True)
self.bias.stop_gradient = self._scale_attr != None and self._scale_attr.learning_rate == 0.

self._mean = self.create_parameter(
ceci3 marked this conversation as resolved.
Show resolved Hide resolved
attr=ParamAttr(
name=None,
initializer=Constant(0.0),
trainable=False,
do_model_average=False),
ceci3 marked this conversation as resolved.
Show resolved Hide resolved
shape=param_shape,
dtype=self._dtype)
self._mean.stop_gradient = True

self._variance = self.create_parameter(
attr=ParamAttr(
name=None,
initializer=Constant(1.0),
trainable=False,
do_model_average=False),
shape=param_shape,
dtype=self._dtype)
self._variance.stop_gradient = True

self._data_layout = 'NCHW'
self._momentum = momentum
self._eps = eps
self._track_running_stats = track_running_stats

def forward(self, input):
ceci3 marked this conversation as resolved.
Show resolved Hide resolved
# create output
# mean and mean_out share the same memory
mean_out = self._mean
# variance and variance out share the same memory
variance_out = self._variance

### train mode: use mini-batch stats, eval mode: use global stats
if self.training:
use_global_stats = False
trainable_statistics = False
else:
use_global_stats = True
trainable_statistics = False

if in_dygraph_mode():
willthefrog marked this conversation as resolved.
Show resolved Hide resolved
attrs = ("momentum", self._momentum, "epsilon", self._eps,
"is_test", not self.training, "data_layout",
self._data_layout, "use_mkldnn", False, "fuse_with_relu",
False, "use_global_stats", use_global_stats,
'trainable_statistics', trainable_statistics)
sync_batch_norm_out, _, _, _, _, _ = core.ops.sync_batch_norm(
input, self.weight, self.bias, self._mean, self._variance,
mean_out, variance_out, *attrs)

return sync_batch_norm_out

check_variable_and_dtype(input, 'input',
['float16', 'float32', 'float64'], 'BatchNorm')

attrs = {
"momentum": self._momentum,
"epsilon": self._eps,
"is_test": not self.training,
"data_layout": self._data_layout,
"use_mkldnn": False,
"fuse_with_relu": False,
"use_global_stats": use_global_stats,
"trainable_statistics": trainable_statistics,
}

inputs = {
"X": [input],
"Scale": [self.weight],
"Bias": [self.bias],
"Mean": [self._mean],
"Variance": [self._variance]
}

saved_mean = self._helper.create_variable_for_type_inference(
dtype=self._dtype, stop_gradient=True)
saved_variance = self._helper.create_variable_for_type_inference(
dtype=self._dtype, stop_gradient=True)
sync_batch_norm_out = input if self._in_place else self._helper.create_variable_for_type_inference(
self._dtype)

outputs = {
"Y": [sync_batch_norm_out],
"MeanOut": [mean_out],
"VarianceOut": [variance_out],
"SavedMean": [saved_mean],
"SavedVariance": [saved_variance]
}

self._helper.append_op(
type="sync_batch_norm", inputs=inputs, outputs=outputs, attrs=attrs)
return sync_batch_norm_out


class Flatten(layers.Layer):
"""
:alias_main: paddle.nn.Flatten
Expand All @@ -3198,10 +3412,6 @@ class Flatten(layers.Layer):
start_axis(int): first dim to flatten (default = 1)
stop_axis(int): last dim to flatten (default = -1).

Returns:
ceci3 marked this conversation as resolved.
Show resolved Hide resolved
None

Examples:

.. code-block:: python

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copyright (c) 2020

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thanks

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import contextlib
import unittest
import numpy as np
import six
import pickle

import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
from paddle.fluid import core
from paddle.fluid.optimizer import SGDOptimizer
from paddle.nn import Conv2D, Pool2D, Linear, SyncBatchNorm
from paddle.fluid.dygraph.base import to_variable

from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase


class TestLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
super(TestLayer, self).__init__()

self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
bias_attr=False)

self._sync_batch_norm = SyncBatchNorm(num_filters)

def forward(self, inputs):
y = self._conv(inputs)
y = self._sync_batch_norm(y)

return y


class TestSyncBatchNorm(TestParallelDyGraphRunnerBase):
def get_model(self):
model = TestLayer(3, 64, 7)
train_reader = paddle.batch(
paddle.dataset.flowers.test(use_xmap=False),
batch_size=32,
drop_last=True)
opt = fluid.optimizer.Adam(
learning_rate=1e-3, parameter_list=model.parameters())
return model, train_reader, opt

def run_one_loop(self, model, opt, data):
batch_size = len(data)
dy_x_data = np.array([x[0].reshape(3, 224, 224)
for x in data]).astype('float32')
img = to_variable(dy_x_data)
img.stop_gradient = False

out = model(img)

out = fluid.layers.mean(out)

return out


if __name__ == "__main__":
runtime_main(TestSyncBatchNorm)
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
ceci3 marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import unittest
from test_dist_base import TestDistBase
import paddle.fluid as fluid

import os
flag_name = os.path.splitext(__file__)[0]


class TestParallelDygraphMnist(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._nccl2_mode = True
self._dygraph = True

def test_mnist(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_sync_batch_norm.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions python/paddle/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
from .layer.loss import NLLLoss #DEFINE_ALIAS
from .layer.loss import BCELoss #DEFINE_ALIAS
from .layer.norm import BatchNorm #DEFINE_ALIAS
from .layer.norm import SyncBatchNorm #DEFINE_ALIAS
from .layer.norm import GroupNorm #DEFINE_ALIAS
from .layer.norm import LayerNorm #DEFINE_ALIAS
from .layer.norm import SpectralNorm #DEFINE_ALIAS
Expand Down
Loading