[Feature] Add truncated normal weight init #935

Merged: 11 commits, May 23, 2021
34 changes: 18 additions & 16 deletions mmcv/cnn/__init__.py
@@ -15,25 +15,27 @@
# yapf: enable
from .resnet import ResNet, make_res_layer
from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
-                    NormalInit, PretrainedInit, UniformInit, XavierInit,
-                    bias_init_with_prob, caffe2_xavier_init, constant_init,
-                    fuse_conv_bn, get_model_complexity_info, initialize,
-                    kaiming_init, normal_init, uniform_init, xavier_init)
+                    NormalInit, PretrainedInit, TruncNormalInit, UniformInit,
+                    XavierInit, bias_init_with_prob, caffe2_xavier_init,
+                    constant_init, fuse_conv_bn, get_model_complexity_info,
+                    initialize, kaiming_init, normal_init, trunc_normal_init,
+                    uniform_init, xavier_init)
from .vgg import VGG, make_vgg_layer

__all__ = [
     'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
-    'constant_init', 'xavier_init', 'normal_init', 'uniform_init',
-    'kaiming_init', 'caffe2_xavier_init', 'bias_init_with_prob', 'ConvModule',
-    'build_activation_layer', 'build_conv_layer', 'build_norm_layer',
-    'build_padding_layer', 'build_upsample_layer', 'build_plugin_layer',
-    'is_norm', 'NonLocal1d', 'NonLocal2d', 'NonLocal3d', 'ContextBlock',
-    'HSigmoid', 'Swish', 'HSwish', 'GeneralizedAttention', 'ACTIVATION_LAYERS',
-    'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS',
-    'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
-    'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
-    'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
-    'MaxPool3d', 'Conv3d', 'initialize', 'INITIALIZERS', 'ConstantInit',
-    'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+    'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
+    'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
+    'bias_init_with_prob', 'ConvModule', 'build_activation_layer',
+    'build_conv_layer', 'build_norm_layer', 'build_padding_layer',
+    'build_upsample_layer', 'build_plugin_layer', 'is_norm', 'NonLocal1d',
+    'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish',
+    'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
+    'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale',
+    'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d',
+    'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d',
+    'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d',
+    'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+    'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
     'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
]
16 changes: 9 additions & 7 deletions mmcv/cnn/utils/__init__.py
@@ -2,15 +2,17 @@
from .flops_counter import get_model_complexity_info
from .fuse_conv_bn import fuse_conv_bn
from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
-                          KaimingInit, NormalInit, PretrainedInit, UniformInit,
-                          XavierInit, bias_init_with_prob, caffe2_xavier_init,
+                          KaimingInit, NormalInit, PretrainedInit,
+                          TruncNormalInit, UniformInit, XavierInit,
+                          bias_init_with_prob, caffe2_xavier_init,
                           constant_init, initialize, kaiming_init, normal_init,
-                          uniform_init, xavier_init)
+                          trunc_normal_init, uniform_init, xavier_init)

__all__ = [
     'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
-    'constant_init', 'kaiming_init', 'normal_init', 'uniform_init',
-    'xavier_init', 'fuse_conv_bn', 'initialize', 'INITIALIZERS',
-    'ConstantInit', 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit',
-    'PretrainedInit', 'Caffe2XavierInit'
+    'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init',
+    'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
+    'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+    'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+    'Caffe2XavierInit'
]
129 changes: 129 additions & 0 deletions mmcv/cnn/utils/weight_init.py
@@ -1,9 +1,12 @@
 # Copyright (c) Open-MMLab. All rights reserved.
 import copy
+import math
+import warnings
 
 import numpy as np
 import torch
 import torch.nn as nn
+from torch import Tensor
 
 from mmcv.utils import Registry, build_from_cfg, get_logger, print_log

@@ -35,6 +38,18 @@ def normal_init(module, mean=0, std=1, bias=0):
        nn.init.constant_(module.bias, bias)


def trunc_normal_init(module: nn.Module,
                      mean: float = 0,
                      std: float = 1,
                      a: float = -2,
                      b: float = 2,
                      bias: float = 0) -> None:
    if hasattr(module, 'weight') and module.weight is not None:
        trunc_normal_(module.weight, mean, std, a, b)  # type: ignore
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)  # type: ignore


def uniform_init(module, a=0, b=1, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.uniform_(module.weight, a, b)
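For orientation, a minimal usage sketch of the new trunc_normal_init helper; the layer below is illustrative and not part of this PR:

import torch.nn as nn

from mmcv.cnn import trunc_normal_init

# Illustrative layer; any module exposing .weight/.bias works.
fc = nn.Linear(16, 8)
# Weights drawn from N(0, 0.02^2), truncated to [-0.04, 0.04]; bias set to 0.
trunc_normal_init(fc, mean=0., std=0.02, a=-0.04, b=0.04, bias=0.)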
@@ -211,6 +226,55 @@ def init(m):
        module.apply(init)


@INITIALIZERS.register_module(name='TruncNormal')
class TruncNormalInit(BaseInit):
    r"""Initialize module parameters with the values drawn from the normal
    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
    outside :math:`[a, b]` redrawn until they are within the bounds.

    Args:
        mean (float): the mean of the normal distribution. Defaults to 0.
        std (float): the standard deviation of the normal distribution.
            Defaults to 1.
        a (float): the minimum cutoff value. Defaults to -2.
        b (float): the maximum cutoff value. Defaults to 2.
        bias (float): the value to fill the bias with, or the
            initialization type for the bias. Defaults to 0.
        bias_prob (float, optional): the probability for bias initialization.
            Defaults to None.
        layer (str | list[str], optional): the layer(s) to be initialized.
            Defaults to None.
    """

    def __init__(self,
                 mean: float = 0,
                 std: float = 1,
                 a: float = -2,
                 b: float = 2,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.mean = mean
        self.std = std
        self.a = a
        self.b = b

    def __call__(self, module: nn.Module) -> None:

        def init(m):
            if self.wholemodule:
                trunc_normal_init(m, self.mean, self.std, self.a, self.b,
                                  self.bias)
            else:
                layername = m.__class__.__name__
                for layer_ in self.layer:
                    if layername == layer_:
                        trunc_normal_init(m, self.mean, self.std, self.a,
                                          self.b, self.bias)

        module.apply(init)
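Like the other BaseInit subclasses, the registered name can also be driven through initialize(); a minimal sketch, assuming the usual init_cfg dict form (the model here is illustrative):

import torch.nn as nn

from mmcv.cnn import initialize

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Linear(8, 2))
# Truncated normal weights on Conv2d/Linear layers, constant bias of 0.
initialize(
    model,
    dict(
        type='TruncNormal',
        mean=0.,
        std=0.02,
        a=-0.04,
        b=0.04,
        bias=0.,
        layer=['Conv2d', 'Linear']))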


@INITIALIZERS.register_module(name='Uniform')
class UniformInit(BaseInit):
    r"""Initialize module parameters with values drawn from the uniform
@@ -468,3 +532,68 @@ def initialize(module, init_cfg):
    else:
        # All attributes in module have same initialization.
        pass


def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
                           b: float) -> Tensor:
    # Method based on
    # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    # Modified from
    # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
            'The distribution of values may be incorrect.',
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [2 * lower - 1,
        # 2 * upper - 1], i.e. the CDF bounds mapped into the domain of erfinv
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor: Tensor,
                  mean: float = 0.,
                  std: float = 1.,
                  a: float = -2.,
                  b: float = 2.) -> Tensor:
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Modified from
    https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py

    Args:
        tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
        mean (float): the mean of the normal distribution.
        std (float): the standard deviation of the normal distribution.
        a (float): the minimum cutoff value.
        b (float): the maximum cutoff value.
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
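The sampling above is plain inverse-CDF sampling: since :math:`\Phi^{-1}(u) = \sqrt{2}\,\mathrm{erfinv}(2u - 1)` for the standard normal CDF :math:`\Phi`, drawing :math:`U` uniformly from :math:`[\Phi((a - \text{mean})/\text{std}),\ \Phi((b - \text{mean})/\text{std})]` and mapping it through :math:`\text{mean} + \text{std}\cdot\sqrt{2}\,\mathrm{erfinv}(2U - 1)` yields exactly the truncated normal; the final clamp only guards against floating-point overshoot. A quick sanity check at the tensor level, importing from the module path added in this diff (trunc_normal_ itself is not re-exported in __all__):

import torch

from mmcv.cnn.utils.weight_init import trunc_normal_

t = torch.empty(10000)
trunc_normal_(t, mean=0., std=1., a=-2., b=2.)
# Every sample must land inside the truncation interval.
assert t.min() >= -2. and t.max() <= 2.
# For a symmetric interval the sample mean should sit near 0.
print(float(t.mean()), float(t.std()))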
1 change: 1 addition & 0 deletions mmcv/commit_id.py
@@ -0,0 +1 @@
commit_id = '59b2b1c'
1 change: 1 addition & 0 deletions requirements/test.txt
@@ -5,4 +5,5 @@ onnxoptimizer
 onnxruntime==1.4.0
 pytest
 PyTurboJPEG
+scipy
 tiffile
2 changes: 1 addition & 1 deletion setup.cfg
@@ -14,6 +14,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools,logging,os,warnings,abc
known_first_party = mmcv
-known_third_party = addict,cv2,m2r,numpy,onnx,onnxruntime,packaging,pytest,recommonmark,resnet_cifar,tensorrt,torch,torchvision,yaml,yapf
+known_third_party = addict,cv2,m2r,numpy,onnx,onnxruntime,packaging,pytest,recommonmark,resnet_cifar,scipy,tensorrt,torch,torchvision,yaml,yapf
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
64 changes: 61 additions & 3 deletions tests/test_cnn/test_weight_init.py
@@ -1,16 +1,18 @@
 # Copyright (c) Open-MMLab. All rights reserved.
+import random
 from tempfile import TemporaryDirectory
 
 import numpy as np
 import pytest
 import torch
+from scipy import stats
 from torch import nn
 
 from mmcv.cnn import (Caffe2XavierInit, ConstantInit, KaimingInit, NormalInit,
-                      PretrainedInit, UniformInit, XavierInit,
+                      PretrainedInit, TruncNormalInit, UniformInit, XavierInit,
                       bias_init_with_prob, caffe2_xavier_init, constant_init,
-                      initialize, kaiming_init, normal_init, uniform_init,
-                      xavier_init)
+                      initialize, kaiming_init, normal_init, trunc_normal_init,
+                      uniform_init, xavier_init)


def test_constant_init():
@@ -47,6 +49,35 @@ def test_normal_init():
    # TODO: sanity check distribution, e.g. mean, std


def test_trunc_normal_init():

    def _random_float(a, b):
        return (b - a) * random.random() + a

    def _is_trunc_normal(tensor, mean, std, a, b):
        # scipy's trunc norm is suited for data drawn from N(0, 1),
        # so we need to transform our data to test it using scipy.
        z_samples = (tensor.view(-1) - mean) / std
        z_samples = z_samples.tolist()
        a0 = (a - mean) / std
        b0 = (b - mean) / std
        p_value = stats.kstest(z_samples, 'truncnorm', args=(a0, b0))[1]
        return p_value > 0.0001

    conv_module = nn.Conv2d(3, 16, 3)
    mean = _random_float(-3, 3)
    std = _random_float(.01, 1)
    a = _random_float(mean - 2 * std, mean)
    b = _random_float(mean, mean + 2 * std)
    trunc_normal_init(conv_module, mean, std, a, b, bias=0.1)
    assert _is_trunc_normal(conv_module.weight, mean, std, a, b)
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))

    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    trunc_normal_init(conv_module_no_bias)
    # TODO: sanity check distribution, e.g. mean, std
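Note that the _is_trunc_normal check above is statistical rather than exact: under the null hypothesis the Kolmogorov-Smirnov p-value is roughly uniform on [0, 1], so the 0.0001 threshold makes a spurious test failure about a one-in-ten-thousand event per run while still catching gross distributional mistakes.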


def test_uniform_init():
    conv_module = nn.Conv2d(3, 16, 3)
    uniform_init(conv_module, bias=0.1)
@@ -168,6 +199,33 @@ def test_normalinit():
    assert model[2].bias.allclose(torch.tensor(res))


def test_truncnormalinit():
    """test TruncNormalInit class."""
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))

    func = TruncNormalInit(
        mean=100, std=1e-5, bias=200, a=0, b=200, layer=['Conv2d', 'Linear'])
    func(model)
    assert model[0].weight.allclose(torch.tensor(100.))
    assert model[2].weight.allclose(torch.tensor(100.))
    assert model[0].bias.allclose(torch.tensor(200.))
    assert model[2].bias.allclose(torch.tensor(200.))

    func = TruncNormalInit(
        mean=300,
        std=1e-5,
        a=100,
        b=400,
        bias_prob=0.01,
        layer=['Conv2d', 'Linear'])
    res = bias_init_with_prob(0.01)
    func(model)
    assert model[0].weight.allclose(torch.tensor(300.))
    assert model[2].weight.allclose(torch.tensor(300.))
    assert model[0].bias.allclose(torch.tensor(res))
    assert model[2].bias.allclose(torch.tensor(res))
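For context, bias_init_with_prob(p) returns the logit -log((1 - p) / p), so the bias checked above is about -4.595 for p = 0.01; the same number can be reproduced by hand:

import math

p = 0.01
bias = -math.log((1 - p) / p)  # what bias_init_with_prob(0.01) computes
print(round(bias, 3))  # -4.595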


def test_uniforminit():
    """test UniformInit class."""
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))