diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc
index 7412eede118d1..50554aaf9a8ed 100644
--- a/paddle/fluid/pybind/op_function_generator.cc
+++ b/paddle/fluid/pybind/op_function_generator.cc
@@ -56,6 +56,9 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
     {"batch_norm",
      {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance",
       "ReserveSpace"}},
+    {"sync_batch_norm",
+     {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance",
+      "ReserveSpace"}},
 };
 
 // NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
@@ -75,6 +78,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
     {"momentum", {"ParamOut", "VelocityOut"}},
     {"batch_norm", {"MeanOut", "VarianceOut"}},
+    {"sync_batch_norm", {"MeanOut", "VarianceOut"}},
     {"accuracy", {"Correct", "Total"}},
     {"fill_constant", {"Out"}},
     {"matmul", {"Out"}},
diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py
index cc2b746b0c1e9..3205c7921d246 100644
--- a/python/paddle/fluid/dygraph/nn.py
+++ b/python/paddle/fluid/dygraph/nn.py
@@ -35,7 +35,7 @@
     'Conv2D', 'Conv3D', 'Pool2D', 'Linear', 'BatchNorm', 'Dropout', 'Embedding',
     'GRUUnit', 'InstanceNorm', 'LayerNorm', 'NCE', 'PRelu',
     'BilinearTensorProduct', 'Conv2DTranspose', 'Conv3DTranspose', 'GroupNorm',
-    'SpectralNorm', 'TreeConv'
+    'SpectralNorm', 'TreeConv', 'SyncBatchNorm'
 ]
 
@@ -3182,3 +3182,216 @@ def forward(self, nodes_vector, edge_set):
         else:
             pre_activation = out
         return self._helper.append_activation(pre_activation, act=self._act)
+
+
+class SyncBatchNorm(layers.Layer):
+    """
+    :alias_main: paddle.nn.SyncBatchNorm
+    :alias: paddle.nn.SyncBatchNorm,paddle.nn.layer.SyncBatchNorm,paddle.nn.layer.norm.SyncBatchNorm
+    :old_api: paddle.fluid.dygraph.SyncBatchNorm
+
+    This interface is used to construct a callable object of the ``SyncBatchNorm`` class.
+    For more details, refer to the code examples.
+    It implements the function of the Batch Normalization Layer and can be used
+    as a normalizer function for conv2d and fully connected operations.
+    The data is normalized by the per-channel mean and variance computed over all
+    mini-batches of the same process group.
+    Refer to `Batch Normalization: Accelerating Deep Network Training by Reducing
+    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_
+    for more details.
+
+    When the model is in training mode, :math:`\\mu_{\\beta}`
+    and :math:`\\sigma_{\\beta}^{2}` are the statistics of all mini-batches in the same process group,
+    calculated as follows:
+
+    ..  math::
+
+        \\mu_{\\beta} &\\gets \\frac{1}{m} \\sum_{i=1}^{m} x_i \\qquad &//\\
+        \ mini-batch\ mean \\\\
+        \\sigma_{\\beta}^{2} &\\gets \\frac{1}{m} \\sum_{i=1}^{m}(x_i - \\
+        \\mu_{\\beta})^2 \\qquad &//\ mini-batch\ variance \\\\
+
+    - :math:`x` : mini-batch data
+    - :math:`m` : the size of the mini-batch data
+
+    When the model is in evaluation mode, :math:`\\mu_{\\beta}`
+    and :math:`\\sigma_{\\beta}^{2}` are the global or running statistics (moving_mean and moving_variance),
+    which are usually obtained from a pre-trained model. They are updated as follows:
+
+    .. math::
+
+        moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global mean \\
+        moving\_variance = moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \quad &// global variance \\
+
+    The normalization function formula is as follows:
+
+    ..  math::
+
+        \\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\
+        \\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\
+        y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift
+
+    - :math:`\\epsilon` : a small value added to the variance to prevent division by zero
+    - :math:`\\gamma` : trainable scale parameter
+    - :math:`\\beta` : trainable shift parameter
+
+    **Note**:
+        The moving mean and moving variance are currently calculated no matter whether
+        `track_running_stats` is set to `True` or `False`; this will be fixed in a future release.
+
+    Parameters:
+        num_features(int): Indicate the number of channels of the input ``Tensor``.
+        eps(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
+        momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
+        scale_attr(ParamAttr, optional): The parameter attribute for the `scale` parameter
+            of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm
+            will create ParamAttr as scale_attr. If the Initializer of the scale_attr
+            is not set, the parameter is initialized with Xavier. Default: None.
+        bias_attr(ParamAttr, optional): The parameter attribute for the bias of batch_norm.
+            If it is set to None or one attribute of ParamAttr, batch_norm
+            will create ParamAttr as bias_attr. If the Initializer of the bias_attr
+            is not set, the bias is initialized to zero. Default: None.
+        track_running_stats(bool, optional): Whether to compute global statistics, i.e. the
+            running mean and running variance. Default: True.
+
+    Returns:
+        None
+
+    Examples:
+        .. code-block:: python
+
+          import numpy as np
+          import paddle.fluid as fluid
+          import paddle.nn as nn
+          from paddle.fluid.dygraph import to_variable
+
+          x = np.random.random(size=(3, 10, 3, 7)).astype('float32')
+          with fluid.dygraph.guard():
+              x = to_variable(x)
+              sync_batch_norm = nn.SyncBatchNorm(10)
+              hidden1 = sync_batch_norm(x)
+    """
+
+    def __init__(self,
+                 num_features,
+                 eps=1e-05,
+                 momentum=0.9,
+                 track_running_stats=True,
+                 scale_attr=None,
+                 bias_attr=None,
+                 name=None):
+        super(SyncBatchNorm, self).__init__()
+        self._scale_attr = scale_attr
+        self._bias_attr = bias_attr
+
+        assert bias_attr is not False, "bias_attr should not be False in SyncBatchNorm."
+
+        self._dtype = "float32"
+
+        param_shape = [num_features]
+
+        # create the scale (gamma) and shift (beta) parameters
+        self.weight = self.create_parameter(
+            attr=self._scale_attr,
+            shape=param_shape,
+            dtype=self._dtype,
+            default_initializer=Constant(1.0))
+        self.weight.stop_gradient = self._scale_attr is not None and self._scale_attr.learning_rate == 0.
+
+        self.bias = self.create_parameter(
+            attr=self._bias_attr,
+            shape=param_shape,
+            dtype=self._dtype,
+            is_bias=True)
+        self.bias.stop_gradient = self._bias_attr is not None and self._bias_attr.learning_rate == 0.
+
+        self._mean = self.create_parameter(
+            attr=ParamAttr(
+                name=None,
+                initializer=Constant(0.0),
+                trainable=False,
+                do_model_average=False),
+            shape=param_shape,
+            dtype=self._dtype)
+        self._mean.stop_gradient = True
+
+        self._variance = self.create_parameter(
+            attr=ParamAttr(
+                name=None,
+                initializer=Constant(1.0),
+                trainable=False,
+                do_model_average=False),
+            shape=param_shape,
+            dtype=self._dtype)
+        self._variance.stop_gradient = True
+
+        self._data_layout = 'NCHW'
+        self._momentum = momentum
+        self._epsilon = eps
+        self._track_running_stats = track_running_stats
+
+    def forward(self, input):
+        # create output
+        # mean and mean_out share the same memory
+        mean_out = self._mean
+        # variance and variance_out share the same memory
+        variance_out = self._variance
+
+        # train mode: use mini-batch stats; eval mode: use global stats
+        if self.training:
+            use_global_stats = False
+            trainable_statistics = False
+        else:
+            use_global_stats = True
+            trainable_statistics = False
+
+        if in_dygraph_mode():
+            attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
+                     "is_test", not self.training, "data_layout",
+                     self._data_layout, "use_mkldnn", False, "fuse_with_relu",
+                     False, "use_global_stats", use_global_stats,
+                     'trainable_statistics', trainable_statistics)
+            sync_batch_norm_out, _, _, _, _, _ = core.ops.sync_batch_norm(
+                input, self.weight, self.bias, self._mean, self._variance,
+                mean_out, variance_out, *attrs)
+
+            return sync_batch_norm_out
+
+        check_variable_and_dtype(input, 'input',
+                                 ['float16', 'float32', 'float64'],
+                                 'SyncBatchNorm')
+
+        attrs = {
+            "momentum": self._momentum,
+            "epsilon": self._epsilon,
+            "is_test": not self.training,
+            "data_layout": self._data_layout,
+            "use_mkldnn": False,
+            "fuse_with_relu": False,
+            "use_global_stats": use_global_stats,
+            "trainable_statistics": trainable_statistics,
+        }
+
+        inputs = {
+            "X": [input],
+            "Scale": [self.weight],
+            "Bias": [self.bias],
+            "Mean": [self._mean],
+            "Variance": [self._variance]
+        }
+
+        saved_mean = self._helper.create_variable_for_type_inference(
+            dtype=self._dtype, stop_gradient=True)
+        saved_variance = self._helper.create_variable_for_type_inference(
+            dtype=self._dtype, stop_gradient=True)
+        sync_batch_norm_out = self._helper.create_variable_for_type_inference(
+            self._dtype)
+
+        outputs = {
+            "Y": [sync_batch_norm_out],
+            "MeanOut": [mean_out],
+            "VarianceOut": [variance_out],
+            "SavedMean": [saved_mean],
+            "SavedVariance": [saved_variance]
+        }
+
+        self._helper.append_op(
+            type="sync_batch_norm", inputs=inputs, outputs=outputs, attrs=attrs)
+        return sync_batch_norm_out
diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_sync_batch_norm.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_sync_batch_norm.py
new file mode 100644
index 0000000000000..96be49e277deb
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_sync_batch_norm.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import os
+import contextlib
+import unittest
+import numpy as np
+import six
+import pickle
+
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.dygraph as dygraph
+from paddle.fluid import core
+from paddle.fluid.optimizer import SGDOptimizer
+#from paddle.nn import Conv2D, Pool2D, Linear, SyncBatchNorm
+from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, SyncBatchNorm
+from paddle.fluid.dygraph.base import to_variable
+
+from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase
+
+
+class TestLayer(fluid.dygraph.Layer):
+    def __init__(self,
+                 num_channels,
+                 num_filters,
+                 filter_size,
+                 stride=1,
+                 groups=1,
+                 act=None):
+        super(TestLayer, self).__init__()
+
+        self._conv = Conv2D(
+            num_channels=num_channels,
+            num_filters=num_filters,
+            filter_size=filter_size,
+            stride=stride,
+            padding=(filter_size - 1) // 2,
+            groups=groups,
+            act=None,
+            bias_attr=False)
+
+        self._sync_batch_norm = SyncBatchNorm(num_filters)
+
+    def forward(self, inputs):
+        y = self._conv(inputs)
+        y = self._sync_batch_norm(y)
+
+        return y
+
+
+class TestSyncBatchNorm(TestParallelDyGraphRunnerBase):
+    def get_model(self):
+        model = TestLayer(3, 64, 7)
+        train_reader = paddle.batch(
+            paddle.dataset.flowers.test(use_xmap=False),
+            batch_size=32,
+            drop_last=True)
+        opt = fluid.optimizer.Adam(
+            learning_rate=1e-3, parameter_list=model.parameters())
+        return model, train_reader, opt
+
+    def run_one_loop(self, model, opt, data):
+        batch_size = len(data)
+        dy_x_data = np.array([x[0].reshape(3, 224, 224)
+                              for x in data]).astype('float32')
+        img = to_variable(dy_x_data)
+        img.stop_gradient = False
+
+        out = model(img)
+
+        out = fluid.layers.mean(out)
+
+        return out
+
+
+if __name__ == "__main__":
+    runtime_main(TestSyncBatchNorm)
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sync_batch_norm.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sync_batch_norm.py
new file mode 100644
index 0000000000000..7d48750b88eb8
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sync_batch_norm.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import unittest
+from test_dist_base import TestDistBase
+import paddle.fluid as fluid
+
+import os
+flag_name = os.path.splitext(__file__)[0]
+
+
+class TestParallelDygraphSyncBatchNorm(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = False
+        self._nccl2_mode = True
+        self._dygraph = True
+
+    def test_sync_batch_norm(self):
+        if fluid.core.is_compiled_with_cuda():
+            self.check_with_place(
+                "parallel_dygraph_sync_batch_norm.py",
+                delta=1e-5,
+                check_error_log=True,
+                log_name=flag_name)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index e074ca66bb1d3..c974fc1414512 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -85,6 +85,7 @@
 from .layer.loss import NLLLoss  #DEFINE_ALIAS
 from .layer.loss import BCELoss  #DEFINE_ALIAS
 from .layer.norm import BatchNorm  #DEFINE_ALIAS
+from .layer.norm import SyncBatchNorm  #DEFINE_ALIAS
 from .layer.norm import GroupNorm  #DEFINE_ALIAS
 from .layer.norm import LayerNorm  #DEFINE_ALIAS
 from .layer.norm import SpectralNorm  #DEFINE_ALIAS
diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py
index 4963ac360804f..674782d474817 100644
--- a/python/paddle/nn/layer/__init__.py
+++ b/python/paddle/nn/layer/__init__.py
@@ -61,6 +61,7 @@
 from .loss import NLLLoss  #DEFINE_ALIAS
 from .loss import BCELoss  #DEFINE_ALIAS
 from .norm import BatchNorm  #DEFINE_ALIAS
+from .norm import SyncBatchNorm  #DEFINE_ALIAS
 from .norm import GroupNorm  #DEFINE_ALIAS
 from .norm import LayerNorm  #DEFINE_ALIAS
 from .norm import SpectralNorm  #DEFINE_ALIAS
diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py
index 1beba62c1809f..1d00f9c7b8b02 100644
--- a/python/paddle/nn/layer/norm.py
+++ b/python/paddle/nn/layer/norm.py
@@ -20,7 +20,9 @@
 from ...fluid.dygraph import GroupNorm  #DEFINE_ALIAS
 from ...fluid.dygraph import LayerNorm  #DEFINE_ALIAS
 from ...fluid.dygraph import SpectralNorm  #DEFINE_ALIAS
+from ...fluid.dygraph import SyncBatchNorm  #DEFINE_ALIAS
 
 __all__ = [
-    'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'InstanceNorm'
+    'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'InstanceNorm',
+    'SyncBatchNorm'
 ]
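
For quick manual verification outside the distributed test harness, a minimal single-card sketch of the new layer's forward pass is given below. This example is not part of the patch; it assumes a CUDA build of Paddle from this branch (the sync_batch_norm op only provides GPU kernels) and uses only the public entry point `paddle.nn.SyncBatchNorm` that this change exports.

    # Minimal sketch (assumes a CUDA build; sync_batch_norm has no CPU kernel).
    # True cross-card synchronization additionally requires one process per card
    # and wrapping the model with fluid.dygraph.DataParallel, as the unit test
    # above does through the test_dist_base harness.
    import numpy as np
    import paddle.fluid as fluid
    import paddle.nn as nn

    with fluid.dygraph.guard(fluid.CUDAPlace(0)):
        data = np.random.random(size=(2, 10, 3, 7)).astype('float32')
        x = fluid.dygraph.to_variable(data)
        sbn = nn.SyncBatchNorm(10)   # num_features matches the channel dim of x
        y = sbn(x)                   # normalized output, same shape as input
        print(y.shape)               # [2, 10, 3, 7]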