Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip the BatchNorm when feature only have 1 element. #11445

Closed
qingqing01 opened this issue Jun 13, 2018 · 1 comment
Closed

Skip the BatchNorm when feature only have 1 element. #11445

qingqing01 opened this issue Jun 13, 2018 · 1 comment
Assignees
Labels

Comments

@qingqing01
Copy link
Contributor

y = (x - mean(x)) / (std(x) + eps)

If x only have 1 element, mean(x) = x, std(x) = 0. The output will be entirely zero (ignoring the bias). The feature is no meaningless. In this case, we should not use feature-wise batch normalization.

@qingqing01
Copy link
Contributor Author

qingqing01 commented Jun 19, 2018

If there is only one element in norm dimension, for example, feature map is [1, 128, 1, 1] or [1, 32], the moving variance will be NaN. The following test can reproduce this problem:

import numpy as np

import paddle
import paddle.fluid as fluid

def main():
    epoc = 8
    dshape = [1, 128, 1, 1]
    data = fluid.layers.data(name='data', shape=[128, 1, 1], dtype='float32')
    conv = fluid.layers.conv2d(data, 128, 3, stride=1, padding=1)
    norm = fluid.layers.batch_norm(conv)

    place = fluid.CUDAPlace(0)
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    for i in range(epoc):
        input = np.random.random(dshape).astype('float32')
        exe.run(fluid.default_main_program(), feed={'data': input})
        v = fluid.global_scope().find_var('batch_norm_0.w_2').get_tensor()
        v = np.array(v)
        print v
        # import math
        # for it in v.flatten().tolist():
        #    if np.isnan(it):
        #        print 'nan'
         #   if np.isinf(it):
         #       print 'inf'

if __name__ == '__main__':
    main()

Will print:

[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan]
[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan]
...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

No branches or pull requests

1 participant