diff --git a/mmcls/core/__init__.py b/mmcls/core/__init__.py
index 3a1df4d03f7..1588502a337 100644
--- a/mmcls/core/__init__.py
+++ b/mmcls/core/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .evaluation import *  # noqa: F401, F403
 from .fp16 import *  # noqa: F401, F403
+from .optimizers import *  # noqa: F401, F403
 from .utils import *  # noqa: F401, F403
diff --git a/mmcls/core/optimizers/__init__.py b/mmcls/core/optimizers/__init__.py
new file mode 100644
index 00000000000..937da98f54a
--- /dev/null
+++ b/mmcls/core/optimizers/__init__.py
@@ -0,0 +1,5 @@
+from .lamb import Lamb
+
+__all__ = [
+    'Lamb',
+]
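Re-exporting the new subpackage from mmcls.core is what makes the optimizer reachable from configs: importing mmcls.core pulls in lamb.py, whose @OPTIMIZERS.register_module() call registers the class under the name 'Lamb' in MMCV's optimizer registry. A minimal sketch of how a config would then build it; the toy module and the hyperparameter values are illustrative only and not part of this patch:

import torch.nn as nn
from mmcv.runner import build_optimizer

import mmcls.core  # noqa: F401  (importing mmcls.core registers Lamb with OPTIMIZERS)

model = nn.Linear(8, 2)  # any nn.Module; used here only for illustration
optimizer = build_optimizer(model, dict(type='Lamb', lr=5e-3, weight_decay=0.02))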
diff --git a/mmcls/core/optimizers/lamb.py b/mmcls/core/optimizers/lamb.py
new file mode 100644
index 00000000000..490320e6f61
--- /dev/null
+++ b/mmcls/core/optimizers/lamb.py
@@ -0,0 +1,231 @@
+"""PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb.
+
+This optimizer code was adapted from the following (starting with latest)
+* https://github.com/HabanaAI/Model-References/blob/
+2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py
+* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/
+LanguageModeling/Transformer-XL/pytorch/lamb.py
+* https://github.com/cybertronai/pytorch-lamb
+
+Use FusedLamb if you can (GPU). The reason for including this variant of Lamb
+is to have a version that is similar in behaviour to APEX FusedLamb if you
+aren't using NVIDIA GPUs or cannot install/use APEX.
+
+In addition to some cleanup, this Lamb impl has been modified to support
+PyTorch XLA and has been tested on TPU.
+
+Original copyrights for above sources are below.
+
+Modifications Copyright 2021 Ross Wightman
+"""
+# Copyright (c) 2021, Habana Labs Ltd. All rights reserved.
+
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# MIT License
+#
+# Copyright (c) 2019 cybertronai
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+import math
+
+import torch
+from mmcv.runner import OPTIMIZERS
+from torch.optim import Optimizer
+
+
+@OPTIMIZERS.register_module()
+class Lamb(Optimizer):
+    """Implements a pure pytorch variant of FusedLAMB (NvLamb variant)
+    optimizer from apex.optimizers.FusedLAMB.
+    reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/
+    PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
+
+    LAMB was proposed in `Large Batch Optimization for Deep Learning - Training
+    BERT in 76 minutes`_.
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups.
+        lr (float, optional): learning rate. (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its norm. (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability. (default: 1e-6)
+        weight_decay (float, optional): weight decay (L2 penalty)
+            (default: 0.01)
+        grad_averaging (bool, optional): whether to apply (1 - beta1) to the
+            gradient when calculating running averages of the gradient.
+            (default: True)
+        max_grad_norm (float, optional): value used to clip global grad norm
+            (default: 1.0)
+        trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
+        always_adapt (bool, optional): Apply adaptive learning rate to 0.0
+            weight decay parameter (default: False)
+
+    .. _Large Batch Optimization for Deep Learning - Training BERT in 76
+     minutes:
+        https://arxiv.org/abs/1904.00962
+    .. _On the Convergence of Adam and Beyond:
+        https://openreview.net/forum?id=ryQu7f-RZ
+    """
+
+    def __init__(self,
+                 params,
+                 lr=1e-3,
+                 bias_correction=True,
+                 betas=(0.9, 0.999),
+                 eps=1e-6,
+                 weight_decay=0.01,
+                 grad_averaging=True,
+                 max_grad_norm=1.0,
+                 trust_clip=False,
+                 always_adapt=False):
+        defaults = dict(
+            lr=lr,
+            bias_correction=bias_correction,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            grad_averaging=grad_averaging,
+            max_grad_norm=max_grad_norm,
+            trust_clip=trust_clip,
+            always_adapt=always_adapt)
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + device = self.param_groups[0]['params'][0].device + one_tensor = torch.tensor( + 1.0, device=device + ) # because torch.where doesn't handle scalars correctly + global_grad_norm = torch.zeros(1, device=device) + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + 'Lamb does not support sparse gradients, consider ' + 'SparseAdam instead.') + global_grad_norm.add_(grad.pow(2).sum()) + + global_grad_norm = torch.sqrt(global_grad_norm) + # FIXME it'd be nice to remove explicit tensor conversion of scalars + # when torch.where promotes + # scalar types properly https://github.com/pytorch/pytorch/issues/9190 + max_grad_norm = torch.tensor( + self.defaults['max_grad_norm'], device=device) + clip_global_grad_norm = torch.where(global_grad_norm > max_grad_norm, + global_grad_norm / max_grad_norm, + one_tensor) + + for group in self.param_groups: + bias_correction = 1 if group['bias_correction'] else 0 + beta1, beta2 = group['betas'] + grad_averaging = 1 if group['grad_averaging'] else 0 + beta3 = 1 - beta1 if grad_averaging else 1.0 + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or + # pass list into kernel + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + if bias_correction: + bias_correction1 = 1 - beta1**group['step'] + bias_correction2 = 1 - beta2**group['step'] + else: + bias_correction1, bias_correction2 = 1.0, 1.0 + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.div_(clip_global_grad_norm) + state = self.state[p] + + # State initialization + if len(state) == 0: + # Exponential moving average of gradient valuesa + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t + exp_avg_sq.mul_(beta2).addcmul_( + grad, grad, value=1 - beta2) # v_t + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( + group['eps']) + update = (exp_avg / bias_correction1).div_(denom) + + weight_decay = group['weight_decay'] + if weight_decay != 0: + update.add_(p, alpha=weight_decay) + + if weight_decay != 0 or group['always_adapt']: + # Layer-wise LR adaptation. By default, skip adaptation on + # parameters that are + # excluded from weight decay, unless always_adapt == True, + # then always enabled. 
diff --git a/tests/test_runtime/test_optimizer.py b/tests/test_runtime/test_optimizer.py
new file mode 100644
index 00000000000..e501eea7baa
--- /dev/null
+++ b/tests/test_runtime/test_optimizer.py
@@ -0,0 +1,308 @@
+import functools
+from collections import OrderedDict
+from copy import deepcopy
+from typing import Iterable
+
+import torch
+import torch.nn as nn
+from mmcv.runner import build_optimizer
+from mmcv.runner.optimizer.builder import OPTIMIZERS
+from mmcv.utils.registry import build_from_cfg
+from torch.autograd import Variable
+from torch.optim.optimizer import Optimizer
+
+import mmcls.core  # noqa: F401
+
+base_lr = 0.01
+base_wd = 0.0001
+
+
+def assert_equal(x, y):
+    if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
+        torch.testing.assert_allclose(x, y.to(x.device))
+    elif isinstance(x, OrderedDict) and isinstance(y, OrderedDict):
+        for x_value, y_value in zip(x.values(), y.values()):
+            assert_equal(x_value, y_value)
+    elif isinstance(x, dict) and isinstance(y, dict):
+        assert x.keys() == y.keys()
+        for key in x.keys():
+            assert_equal(x[key], y[key])
+    elif isinstance(x, str) and isinstance(y, str):
+        assert x == y
+    elif isinstance(x, Iterable) and isinstance(y, Iterable):
+        assert len(x) == len(y)
+        for x_item, y_item in zip(x, y):
+            assert_equal(x_item, y_item)
+    else:
+        assert x == y
+
+
+class SubModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(2, 2, kernel_size=1, groups=2)
+        self.gn = nn.GroupNorm(2, 2)
+        self.fc = nn.Linear(2, 2)
+        self.param1 = nn.Parameter(torch.ones(1))
+
+    def forward(self, x):
+        return x
+
+
+class ExampleModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.param1 = nn.Parameter(torch.ones(1))
+        self.conv1 = nn.Conv2d(3, 4, kernel_size=1, bias=False)
+        self.conv2 = nn.Conv2d(4, 2, kernel_size=1)
+        self.bn = nn.BatchNorm2d(2)
+        self.sub = SubModel()
+        self.fc = nn.Linear(2, 1)
+
+    def forward(self, x):
+        return x
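The index-based assertions in check_lamb_optimizer below assume that mmcv's DefaultOptimizerConstructor, when a paramwise_cfg is given, emits one param group per parameter in the order of ExampleModel().named_parameters(): param1, conv1.weight, conv2.weight, conv2.bias, bn.weight, bn.bias, sub.param1, sub.conv1.weight, sub.conv1.bias, sub.gn.weight, sub.gn.bias, sub.fc.weight, sub.fc.bias, fc.weight, fc.bias. A quick way to confirm that ordering (not part of the test itself):

for i, (name, _) in enumerate(ExampleModel().named_parameters()):
    print(i, name)  # 0 -> param1, 1 -> conv1.weight, ..., 14 -> fc.bias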
+
+
+def check_lamb_optimizer(optimizer,
+                         model,
+                         bias_lr_mult=1,
+                         bias_decay_mult=1,
+                         norm_decay_mult=1,
+                         dwconv_decay_mult=1):
+    param_groups = optimizer.param_groups
+    assert isinstance(optimizer, Optimizer)
+    assert optimizer.defaults['lr'] == base_lr
+    assert optimizer.defaults['weight_decay'] == base_wd
+    model_parameters = list(model.parameters())
+    assert len(param_groups) == len(model_parameters)
+    for i, param in enumerate(model_parameters):
+        param_group = param_groups[i]
+        assert torch.equal(param_group['params'][0], param)
+    # param1
+    param1 = param_groups[0]
+    assert param1['lr'] == base_lr
+    assert param1['weight_decay'] == base_wd
+    # conv1.weight
+    conv1_weight = param_groups[1]
+    assert conv1_weight['lr'] == base_lr
+    assert conv1_weight['weight_decay'] == base_wd
+    # conv2.weight
+    conv2_weight = param_groups[2]
+    assert conv2_weight['lr'] == base_lr
+    assert conv2_weight['weight_decay'] == base_wd
+    # conv2.bias
+    conv2_bias = param_groups[3]
+    assert conv2_bias['lr'] == base_lr * bias_lr_mult
+    assert conv2_bias['weight_decay'] == base_wd * bias_decay_mult
+    # bn.weight
+    bn_weight = param_groups[4]
+    assert bn_weight['lr'] == base_lr
+    assert bn_weight['weight_decay'] == base_wd * norm_decay_mult
+    # bn.bias
+    bn_bias = param_groups[5]
+    assert bn_bias['lr'] == base_lr
+    assert bn_bias['weight_decay'] == base_wd * norm_decay_mult
+    # sub.param1
+    sub_param1 = param_groups[6]
+    assert sub_param1['lr'] == base_lr
+    assert sub_param1['weight_decay'] == base_wd
+    # sub.conv1.weight
+    sub_conv1_weight = param_groups[7]
+    assert sub_conv1_weight['lr'] == base_lr
+    assert sub_conv1_weight['weight_decay'] == base_wd * dwconv_decay_mult
+    # sub.conv1.bias
+    sub_conv1_bias = param_groups[8]
+    assert sub_conv1_bias['lr'] == base_lr * bias_lr_mult
+    assert sub_conv1_bias['weight_decay'] == base_wd * dwconv_decay_mult
+    # sub.gn.weight
+    sub_gn_weight = param_groups[9]
+    assert sub_gn_weight['lr'] == base_lr
+    assert sub_gn_weight['weight_decay'] == base_wd * norm_decay_mult
+    # sub.gn.bias
+    sub_gn_bias = param_groups[10]
+    assert sub_gn_bias['lr'] == base_lr
+    assert sub_gn_bias['weight_decay'] == base_wd * norm_decay_mult
+    # sub.fc.weight
+    sub_fc_weight = param_groups[11]
+    assert sub_fc_weight['lr'] == base_lr
+    assert sub_fc_weight['weight_decay'] == base_wd
+    # sub.fc.bias
+    sub_fc_bias = param_groups[12]
+    assert sub_fc_bias['lr'] == base_lr * bias_lr_mult
+    assert sub_fc_bias['weight_decay'] == base_wd * bias_decay_mult
+    # fc.weight
+    fc_weight = param_groups[13]
+    assert fc_weight['lr'] == base_lr
+    assert fc_weight['weight_decay'] == base_wd
+    # fc.bias
+    fc_bias = param_groups[14]
+    assert fc_bias['lr'] == base_lr * bias_lr_mult
+    assert fc_bias['weight_decay'] == base_wd * bias_decay_mult
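The helpers below are adapted from PyTorch's own optimizer tests and drive the optimizer exclusively through the closure form of step(): each call passes a callable that zeroes the gradients, recomputes the loss, runs backward() and returns the loss. A minimal sketch of that calling convention with Lamb (toy objective, illustrative only; it assumes the imports above plus a direct import of the class):

from mmcls.core import Lamb

w = nn.Parameter(torch.randn(5))
optimizer = Lamb([w], lr=1e-2)

def closure():
    optimizer.zero_grad()
    loss = (w ** 2).sum()
    loss.backward()
    return loss

for _ in range(10):
    optimizer.step(closure)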
+
+
+def _test_state_dict(weight, bias, input, constructor):
+    weight = Variable(weight, requires_grad=True)
+    bias = Variable(bias, requires_grad=True)
+    inputs = Variable(input)
+
+    def fn_base(optimizer, weight, bias):
+        optimizer.zero_grad()
+        i = input_cuda if weight.is_cuda else inputs
+        loss = (weight.mv(i) + bias).pow(2).sum()
+        loss.backward()
+        return loss
+
+    optimizer = constructor(weight, bias)
+    fn = functools.partial(fn_base, optimizer, weight, bias)
+
+    # Prime the optimizer
+    for _ in range(20):
+        optimizer.step(fn)
+    # Clone the weights and construct a new optimizer for them
+    weight_c = Variable(weight.data.clone(), requires_grad=True)
+    bias_c = Variable(bias.data.clone(), requires_grad=True)
+    optimizer_c = constructor(weight_c, bias_c)
+    fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
+    # Load the state dict
+    state_dict = deepcopy(optimizer.state_dict())
+    state_dict_c = deepcopy(optimizer.state_dict())
+    optimizer_c.load_state_dict(state_dict_c)
+    # Run both optimizations in parallel
+    for _ in range(20):
+        optimizer.step(fn)
+        optimizer_c.step(fn_c)
+        assert_equal(weight, weight_c)
+        assert_equal(bias, bias_c)
+    # Make sure state dict wasn't modified
+    assert_equal(state_dict, state_dict_c)
+    # Make sure state dict is deterministic with equal
+    # but not identical parameters
+    # NOTE: The state_dict of optimizers in PyTorch 1.5 has random keys
+    state_dict = deepcopy(optimizer.state_dict())
+    state_dict_c = deepcopy(optimizer_c.state_dict())
+    keys = state_dict['param_groups'][-1]['params']
+    keys_c = state_dict_c['param_groups'][-1]['params']
+    for key, key_c in zip(keys, keys_c):
+        assert_equal(optimizer.state_dict()['state'][key],
+                     optimizer_c.state_dict()['state'][key_c])
+    # Make sure repeated parameters have identical representation in state dict
+    optimizer_c.param_groups.extend(optimizer_c.param_groups)
+    assert_equal(optimizer_c.state_dict()['param_groups'][0],
+                 optimizer_c.state_dict()['param_groups'][1])
+
+    # Check that state dict can be loaded even when we cast parameters
+    # to a different type and move to a different device.
+    if not torch.cuda.is_available():
+        return
+
+    input_cuda = Variable(inputs.data.float().cuda())
+    weight_cuda = Variable(weight.data.float().cuda(), requires_grad=True)
+    bias_cuda = Variable(bias.data.float().cuda(), requires_grad=True)
+    optimizer_cuda = constructor(weight_cuda, bias_cuda)
+    fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda,
+                                bias_cuda)
+
+    state_dict = deepcopy(optimizer.state_dict())
+    state_dict_c = deepcopy(optimizer.state_dict())
+    optimizer_cuda.load_state_dict(state_dict_c)
+
+    # Make sure state dict wasn't modified
+    assert_equal(state_dict, state_dict_c)
+
+    for _ in range(20):
+        optimizer.step(fn)
+        optimizer_cuda.step(fn_cuda)
+        assert_equal(weight, weight_cuda)
+        assert_equal(bias, bias_cuda)
+
+    # validate deepcopy() copies all public attributes
+    def getPublicAttr(obj):
+        return set(k for k in obj.__dict__ if not k.startswith('_'))
+
+    assert_equal(getPublicAttr(optimizer), getPublicAttr(deepcopy(optimizer)))
+
+
+def _test_basic_cases_template(weight, bias, inputs, constructor,
+                               scheduler_constructors):
+    """Copied from PyTorch."""
+    weight = Variable(weight, requires_grad=True)
+    bias = Variable(bias, requires_grad=True)
+    inputs = Variable(inputs)
+    optimizer = constructor(weight, bias)
+    schedulers = []
+    for scheduler_constructor in scheduler_constructors:
+        schedulers.append(scheduler_constructor(optimizer))
+
+    # to check if the optimizer can be printed as a string
+    optimizer.__repr__()
+
+    def fn():
+        optimizer.zero_grad()
+        y = weight.mv(inputs)
+        if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
+            y = y.cuda(bias.get_device())
+        loss = (y + bias).pow(2).sum()
+        loss.backward()
+        return loss
+
+    initial_value = fn().item()
+    for _ in range(200):
+        for scheduler in schedulers:
+            scheduler.step()
+        optimizer.step(fn)
+
+    assert fn().item() < initial_value
+
+
+def _test_basic_cases(constructor,
+                      scheduler_constructors=None,
+                      ignore_multidevice=False):
+    """Copied from PyTorch."""
+    if scheduler_constructors is None:
+        scheduler_constructors = []
+    _test_state_dict(
+        torch.randn(10, 5), torch.randn(10), torch.randn(5), constructor)
+    _test_basic_cases_template(
+        torch.randn(10, 5), torch.randn(10), torch.randn(5), constructor,
+        scheduler_constructors)
+    # non-contiguous parameters
+    _test_basic_cases_template(
+        torch.randn(10, 5, 2)[..., 0],
+        torch.randn(10, 2)[..., 0], torch.randn(5), constructor,
+        scheduler_constructors)
+    # CUDA
+    if not torch.cuda.is_available():
+        return
+    _test_basic_cases_template(
+        torch.randn(10, 5).cuda(),
+        torch.randn(10).cuda(),
+        torch.randn(5).cuda(), constructor, scheduler_constructors)
+    # Multi-GPU
+    if not torch.cuda.device_count() > 1 or ignore_multidevice:
+        return
+    _test_basic_cases_template(
+        torch.randn(10, 5).cuda(0),
+        torch.randn(10).cuda(1),
+        torch.randn(5).cuda(0), constructor, scheduler_constructors)
+
+
+def test_lamb_optimizer():
+    model = ExampleModel()
+    optimizer_cfg = dict(
+        type='Lamb',
+        lr=base_lr,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=base_wd,
+        paramwise_cfg=dict(
+            bias_lr_mult=2,
+            bias_decay_mult=0.5,
+            norm_decay_mult=0,
+            dwconv_decay_mult=0.1))
+    optimizer = build_optimizer(model, optimizer_cfg)
+    check_lamb_optimizer(optimizer, model, **optimizer_cfg['paramwise_cfg'])
+
+    _test_basic_cases(lambda weight, bias: build_from_cfg(
+        dict(type='Lamb', params=[weight, bias], lr=base_lr), OPTIMIZERS))
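To exercise just the new test (assuming mmcls, mmcv-full and pytest are installed in the environment), pytest can target it directly, e.g. run pytest tests/test_runtime/test_optimizer.py::test_lamb_optimizer from the repository root, or equivalently from Python:

import pytest
pytest.main(['tests/test_runtime/test_optimizer.py::test_lamb_optimizer', '-q'])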