diff --git a/mmcv/cnn/__init__.py b/mmcv/cnn/__init__.py
index 71d2b69357..f7522fa784 100644
--- a/mmcv/cnn/__init__.py
+++ b/mmcv/cnn/__init__.py
@@ -15,25 +15,27 @@
 # yapf: enable
 from .resnet import ResNet, make_res_layer
 from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
-                    NormalInit, PretrainedInit, UniformInit, XavierInit,
-                    bias_init_with_prob, caffe2_xavier_init, constant_init,
-                    fuse_conv_bn, get_model_complexity_info, initialize,
-                    kaiming_init, normal_init, uniform_init, xavier_init)
+                    NormalInit, PretrainedInit, TruncNormalInit, UniformInit,
+                    XavierInit, bias_init_with_prob, caffe2_xavier_init,
+                    constant_init, fuse_conv_bn, get_model_complexity_info,
+                    initialize, kaiming_init, normal_init, trunc_normal_init,
+                    uniform_init, xavier_init)
 from .vgg import VGG, make_vgg_layer
 
 __all__ = [
     'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
-    'constant_init', 'xavier_init', 'normal_init', 'uniform_init',
-    'kaiming_init', 'caffe2_xavier_init', 'bias_init_with_prob', 'ConvModule',
-    'build_activation_layer', 'build_conv_layer', 'build_norm_layer',
-    'build_padding_layer', 'build_upsample_layer', 'build_plugin_layer',
-    'is_norm', 'NonLocal1d', 'NonLocal2d', 'NonLocal3d', 'ContextBlock',
-    'HSigmoid', 'Swish', 'HSwish', 'GeneralizedAttention', 'ACTIVATION_LAYERS',
-    'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS',
-    'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
-    'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
-    'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
-    'MaxPool3d', 'Conv3d', 'initialize', 'INITIALIZERS', 'ConstantInit',
-    'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+    'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
+    'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
+    'bias_init_with_prob', 'ConvModule', 'build_activation_layer',
+    'build_conv_layer', 'build_norm_layer', 'build_padding_layer',
+    'build_upsample_layer', 'build_plugin_layer', 'is_norm', 'NonLocal1d',
+    'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish',
+    'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
+    'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale',
+    'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d',
+    'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d',
+    'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d',
+    'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+    'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
     'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
 ]
diff --git a/mmcv/cnn/utils/__init__.py b/mmcv/cnn/utils/__init__.py
index 18efa4135f..c8a4bd51f8 100644
--- a/mmcv/cnn/utils/__init__.py
+++ b/mmcv/cnn/utils/__init__.py
@@ -2,15 +2,17 @@
 from .flops_counter import get_model_complexity_info
 from .fuse_conv_bn import fuse_conv_bn
 from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
-                          KaimingInit, NormalInit, PretrainedInit, UniformInit,
-                          XavierInit, bias_init_with_prob, caffe2_xavier_init,
+                          KaimingInit, NormalInit, PretrainedInit,
+                          TruncNormalInit, UniformInit, XavierInit,
+                          bias_init_with_prob, caffe2_xavier_init,
                           constant_init, initialize, kaiming_init, normal_init,
-                          uniform_init, xavier_init)
+                          trunc_normal_init, uniform_init, xavier_init)
 
 __all__ = [
     'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
-    'constant_init', 'kaiming_init', 'normal_init', 'uniform_init',
-    'xavier_init', 'fuse_conv_bn', 'initialize', 'INITIALIZERS',
-    'ConstantInit', 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit',
-    'PretrainedInit', 'Caffe2XavierInit'
+    'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init',
+    'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
+    'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+    'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+    'Caffe2XavierInit'
 ]
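For reference, the re-exports above make the new initializers importable from the package root. A minimal usage sketch, assuming a build of mmcv that includes this patch (the numeric values are arbitrary illustrations):

import torch.nn as nn

from mmcv.cnn import TruncNormalInit, trunc_normal_init

# Functional form: initialize a single module in place.
conv = nn.Conv2d(3, 16, 3)
trunc_normal_init(conv, mean=0., std=0.02, a=-0.04, b=0.04, bias=0.)

# Class form: initialize every matching layer inside a model.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Linear(8, 4))
TruncNormalInit(mean=0., std=0.02, layer=['Conv2d', 'Linear'])(model)
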
diff --git a/mmcv/cnn/utils/weight_init.py b/mmcv/cnn/utils/weight_init.py
index 6de857e73f..a05d774444 100644
--- a/mmcv/cnn/utils/weight_init.py
+++ b/mmcv/cnn/utils/weight_init.py
@@ -1,9 +1,12 @@
 # Copyright (c) Open-MMLab. All rights reserved.
 import copy
+import math
 import warnings
 
 import numpy as np
+import torch
 import torch.nn as nn
+from torch import Tensor
 
 from mmcv.utils import Registry, build_from_cfg, get_logger, print_log
 
@@ -35,6 +38,18 @@ def normal_init(module, mean=0, std=1, bias=0):
         nn.init.constant_(module.bias, bias)
 
 
+def trunc_normal_init(module: nn.Module,
+                      mean: float = 0,
+                      std: float = 1,
+                      a: float = -2,
+                      b: float = 2,
+                      bias: float = 0) -> None:
+    if hasattr(module, 'weight') and module.weight is not None:
+        trunc_normal_(module.weight, mean, std, a, b)  # type: ignore
+    if hasattr(module, 'bias') and module.bias is not None:
+        nn.init.constant_(module.bias, bias)  # type: ignore
+
+
 def uniform_init(module, a=0, b=1, bias=0):
     if hasattr(module, 'weight') and module.weight is not None:
         nn.init.uniform_(module.weight, a, b)
@@ -211,6 +226,55 @@ def init(m):
     module.apply(init)
 
 
+@INITIALIZERS.register_module(name='TruncNormal')
+class TruncNormalInit(BaseInit):
+    r"""Initialize module parameters with the values drawn from the normal
+    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
+    outside :math:`[a, b]` redrawn until they are within the bounds.
+
+    Args:
+        mean (float): the mean of the normal distribution. Defaults to 0.
+        std (float): the standard deviation of the normal distribution.
+            Defaults to 1.
+        a (float): the minimum cutoff value. Defaults to -2.
+        b (float): the maximum cutoff value. Defaults to 2.
+        bias (float): the value to fill the bias or define
+            initialization type for bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer to be initialized.
+            Defaults to None.
+
+    """
+
+    def __init__(self,
+                 mean: float = 0,
+                 std: float = 1,
+                 a: float = -2,
+                 b: float = 2,
+                 **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.mean = mean
+        self.std = std
+        self.a = a
+        self.b = b
+
+    def __call__(self, module: nn.Module) -> None:
+
+        def init(m):
+            if self.wholemodule:
+                trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+                                  self.bias)
+            else:
+                layername = m.__class__.__name__
+                for layer_ in self.layer:
+                    if layername == layer_:
+                        trunc_normal_init(m, self.mean, self.std, self.a,
+                                          self.b, self.bias)
+
+        module.apply(init)
+
+
 @INITIALIZERS.register_module(name='Uniform')
 class UniformInit(BaseInit):
     r"""Initialize module parameters with values drawn from the uniform
@@ -468,3 +532,68 @@ def initialize(module, init_cfg):
         else:
             # All attributes in module have same initialization.
             pass
+
+
+def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
+                           b: float) -> Tensor:
+    # Method based on
+    # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    # Modified from
+    # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn(
+            'mean is more than 2 std from [a, b] in trunc_normal_. '
+            'The distribution of values may be incorrect.',
+            stacklevel=2)
+
+    with torch.no_grad():
+        # Values are generated by using a truncated uniform distribution and
+        # then using the inverse CDF for the normal distribution.
+        # Get upper and lower cdf values
+        lower = norm_cdf((a - mean) / std)
+        upper = norm_cdf((b - mean) / std)
+
+        # Uniformly fill tensor with values from [lower, upper], then
+        # translate to [2 * lower - 1, 2 * upper - 1].
+        tensor.uniform_(2 * lower - 1, 2 * upper - 1)
+
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        tensor.erfinv_()
+
+        # Transform to proper mean, std
+        tensor.mul_(std * math.sqrt(2.))
+        tensor.add_(mean)
+
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+        return tensor
+
+
+def trunc_normal_(tensor: Tensor,
+                  mean: float = 0.,
+                  std: float = 1.,
+                  a: float = -2.,
+                  b: float = 2.) -> Tensor:
+    r"""Fills the input Tensor with values drawn from a truncated
+    normal distribution. The values are effectively drawn from the
+    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \leq \text{mean} \leq b`.
+
+    Modified from
+    https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+
+    Args:
+        tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
+        mean (float): the mean of the normal distribution.
+        std (float): the standard deviation of the normal distribution.
+        a (float): the minimum cutoff value.
+        b (float): the maximum cutoff value.
+    """
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
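Because the class registers itself under the name 'TruncNormal', it should also be reachable through the dict-based initialize API shown in the context above. A sketch with illustrative numbers (the config keys mirror TruncNormalInit's constructor arguments):

import torch.nn as nn

from mmcv.cnn import initialize

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Linear(8, 4))
# 'TruncNormal' is the name registered by the decorator above; initialize
# resolves it through the INITIALIZERS registry.
initialize(
    model,
    dict(type='TruncNormal', mean=0., std=0.02, a=-2, b=2, bias=0.,
         layer=['Conv2d', 'Linear']))

The sampler itself hinges on the inverse-CDF identity x = sqrt(2) * erfinv(2 * Phi(x) - 1), which is why erfinv_ followed by mul_(std * sqrt(2)) recovers truncated normal samples from the uniform draw. A tiny self-contained check of that identity, independent of mmcv:

import math

import torch

def norm_cdf(x):
    # Same helper as in _no_grad_trunc_normal_ above.
    return (1. + math.erf(x / math.sqrt(2.))) / 2.

x = 0.7
u = norm_cdf(x)
# erfinv inverts the CDF: sqrt(2) * erfinv(2u - 1) should return x.
x_back = math.sqrt(2.) * torch.erfinv(torch.tensor(2. * u - 1.)).item()
assert abs(x - x_back) < 1e-6
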
diff --git a/mmcv/commit_id.py b/mmcv/commit_id.py
new file mode 100644
index 0000000000..12b62a3b86
--- /dev/null
+++ b/mmcv/commit_id.py
@@ -0,0 +1 @@
+commit_id = '59b2b1c'
diff --git a/requirements/test.txt b/requirements/test.txt
index fe41ebe185..ab4ecbd5c1 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -5,4 +5,5 @@ onnxoptimizer
 onnxruntime==1.4.0
 pytest
 PyTurboJPEG
+scipy
 tiffile
diff --git a/setup.cfg b/setup.cfg
index 25825f09aa..fbd78ef0e6 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -14,6 +14,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = pkg_resources,setuptools,logging,os,warnings,abc
 known_first_party = mmcv
-known_third_party = addict,cv2,m2r,numpy,onnx,onnxruntime,packaging,pytest,recommonmark,resnet_cifar,tensorrt,torch,torchvision,yaml,yapf
+known_third_party = addict,cv2,m2r,numpy,onnx,onnxruntime,packaging,pytest,recommonmark,resnet_cifar,scipy,tensorrt,torch,torchvision,yaml,yapf
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
diff --git a/tests/test_cnn/test_weight_init.py b/tests/test_cnn/test_weight_init.py
index 343079c45e..ea22843995 100644
--- a/tests/test_cnn/test_weight_init.py
+++ b/tests/test_cnn/test_weight_init.py
@@ -1,16 +1,18 @@
 # Copyright (c) Open-MMLab. All rights reserved.
+import random
 from tempfile import TemporaryDirectory
 
 import numpy as np
 import pytest
 import torch
+from scipy import stats
 from torch import nn
 
 from mmcv.cnn import (Caffe2XavierInit, ConstantInit, KaimingInit, NormalInit,
-                      PretrainedInit, UniformInit, XavierInit,
+                      PretrainedInit, TruncNormalInit, UniformInit, XavierInit,
                       bias_init_with_prob, caffe2_xavier_init, constant_init,
-                      initialize, kaiming_init, normal_init, uniform_init,
-                      xavier_init)
+                      initialize, kaiming_init, normal_init, trunc_normal_init,
+                      uniform_init, xavier_init)
 
 
 def test_constant_init():
@@ -47,6 +49,35 @@ def test_normal_init():
     # TODO: sanity check distribution, e.g. mean, std
 
 
+def test_trunc_normal_init():
+
+    def _random_float(a, b):
+        return (b - a) * random.random() + a
+
+    def _is_trunc_normal(tensor, mean, std, a, b):
+        # scipy's trunc norm is suited for data drawn from N(0, 1),
+        # so we need to transform our data to test it using scipy.
+        z_samples = (tensor.view(-1) - mean) / std
+        z_samples = z_samples.tolist()
+        a0 = (a - mean) / std
+        b0 = (b - mean) / std
+        p_value = stats.kstest(z_samples, 'truncnorm', args=(a0, b0))[1]
+        return p_value > 0.0001
+
+    conv_module = nn.Conv2d(3, 16, 3)
+    mean = _random_float(-3, 3)
+    std = _random_float(.01, 1)
+    a = _random_float(mean - 2 * std, mean)
+    b = _random_float(mean, mean + 2 * std)
+    trunc_normal_init(conv_module, mean, std, a, b, bias=0.1)
+    assert _is_trunc_normal(conv_module.weight, mean, std, a, b)
+    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
+
+    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
+    trunc_normal_init(conv_module_no_bias)
+    # TODO: sanity check distribution, e.g. mean, std
+
+
 def test_uniform_init():
     conv_module = nn.Conv2d(3, 16, 3)
     uniform_init(conv_module, bias=0.1)
@@ -168,6 +199,33 @@ def test_normalinit():
     assert model[2].bias.allclose(torch.tensor(res))
 
 
+def test_truncnormalinit():
+    """test TruncNormalInit class."""
+    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
+
+    func = TruncNormalInit(
+        mean=100, std=1e-5, bias=200, a=0, b=200, layer=['Conv2d', 'Linear'])
+    func(model)
+    assert model[0].weight.allclose(torch.tensor(100.))
+    assert model[2].weight.allclose(torch.tensor(100.))
+    assert model[0].bias.allclose(torch.tensor(200.))
+    assert model[2].bias.allclose(torch.tensor(200.))
+
+    func = TruncNormalInit(
+        mean=300,
+        std=1e-5,
+        a=100,
+        b=400,
+        bias_prob=0.01,
+        layer=['Conv2d', 'Linear'])
+    res = bias_init_with_prob(0.01)
+    func(model)
+    assert model[0].weight.allclose(torch.tensor(300.))
+    assert model[2].weight.allclose(torch.tensor(300.))
+    assert model[0].bias.allclose(torch.tensor(res))
+    assert model[2].bias.allclose(torch.tensor(res))
+
+
 def test_uniforminit():
     """"test UniformInit class."""
     model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
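The KS-test pattern used in test_trunc_normal_init can also be pointed directly at the low-level trunc_normal_ helper. A quick standalone sanity check; the threshold mirrors the one in the test, and the import path assumes the module layout of this patch (trunc_normal_ itself is not re-exported from mmcv.cnn):

import torch
from scipy import stats

from mmcv.cnn.utils.weight_init import trunc_normal_

mean, std, a, b = 0., 1., -2., 2.
samples = trunc_normal_(torch.empty(100000), mean, std, a, b)
# scipy's truncnorm is parametrized in standard-normal units, so
# standardize the samples and the cutoffs before testing.
z = ((samples - mean) / std).tolist()
p_value = stats.kstest(
    z, 'truncnorm', args=((a - mean) / std, (b - mean) / std))[1]
assert p_value > 0.0001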