diff --git a/mmcv/runner/hooks/lr_updater.py b/mmcv/runner/hooks/lr_updater.py index 4f9a65905b..917c58c9bc 100644 --- a/mmcv/runner/hooks/lr_updater.py +++ b/mmcv/runner/hooks/lr_updater.py @@ -1,6 +1,7 @@ # Copyright (c) Open-MMLab. All rights reserved. import numbers from math import cos, pi +from typing import Optional import mmcv from .hook import HOOKS, Hook @@ -614,3 +615,223 @@ def format_param(name, optim, param): if name not in param: raise KeyError(f'{name} is not found in {param.keys()}') return param[name] + + +@HOOKS.register_module() +class ReduceLrUpdateHook(LrUpdaterHook): + """ReduceLROnPlateau Scheduler. + + Reduce learning rate when a metric has stopped improving. This scheduler + reads a metrics quantity and if no improvement is seen for a 'patience' + number of epochs, the learning rate is reduced. + + Args: + periods (list[int]): Periods that taking the metric value in count. + val_metric (str, optional): Metrics to be evaluated. If val_metric is + None, the metrics will be loss value. Default: None. + mode (str, optional): One of `min`, `max`. In `min` mode, lr will + be reduced when the quantity monitored has stopped + decreasing; in `max` mode it will be reduced when the + quantity monitored has stopped increasing. Default: 'min'. + factor (float, optional): Factor by which the learning rate will be + reduced. new_lr = lr * factor. Default: 0.1. + patience (int, optional): Number of epochs with no improvement after + which learning rate will be reduced. For example, if + `patience = 2`, then we will ignore the first 2 epochs + with no improvement, and will only decrease the LR after the + 3rd epoch if the loss still hasn't improved then. + Default: 10. + threshold (float, optional): Threshold for measuring the new optimum, + to only focus on significant changes. Default: 1e-4. + threshold_mode (str, optional): One of `rel`, `abs`. In `rel` mode, + dynamic_threshold = best * ( 1 + threshold ) in 'max' + mode or best * ( 1 - threshold ) in `min` mode. + In `abs` mode, dynamic_threshold = best + threshold in + `max` mode or best - threshold in `min` mode. Default: 'rel'. + cooldown (int, optional): Number of epochs to wait before resuming + normal operation after lr has been reduced. Default: 0. + min_lr (float, optional): Minimum LR value to keep. If LR after decay + is lower than `min_lr`, it will be clipped to this value. + Default: 0. + eps (float, optional): Minimal decay applied to lr. If the difference + between new and old lr is smaller than eps, the update is + ignored. Default: 1e-8. + """ + + def __init__(self, + periods: list, + val_metric: Optional[str] = None, + mode: str = 'min', + factor: float = 0.1, + patience: int = 10, + threshold: float = 1e-4, + threshold_mode: str = 'rel', + cooldown: int = 0, + min_lr: float = 0., + eps: float = 1e-8, + **kwargs): + assert isinstance(periods, list), '"periods" must be a list' + assert mmcv.is_list_of(periods, int) and all([s >= 0 for s in periods]) + self.periods = periods + self.val_metric = val_metric + + if mode not in ['min', 'max']: + raise ValueError( + 'mode must be one of "min" or "max", instead got {mode}') + self.mode = mode + + if factor >= 1.0: + raise ValueError('Factor should be < 1.0') + self.factor = factor + + self.patience = patience + self.threshold = threshold + + if threshold_mode not in ['rel', 'abs']: + raise ValueError('thresh_mode must be one of "rel" or "abs",\ + instead got {threshold_mode}') + self.threshold_mode = threshold_mode + + self.cooldown = cooldown + self.cooldown_counter = 0 + self.best = None + self.num_bad_epochs = None + self.mode_worse = None # the worse value for the chosen mode + self.min_lr = min_lr + self.eps = eps + self.last_epoch = 0 + self._init_is_better(self.mode) + self._reset() + super(ReduceLrUpdateHook, self).__init__(**kwargs) + + def get_lr(self, runner, regular_lr): + if self.num_bad_epochs > self.patience: + self.cooldown_counter = self.cooldown + self.num_bad_epochs = 0 + if regular_lr - regular_lr * self.factor > self.eps: + new_lr = max(regular_lr * self.factor, self.min_lr) + else: + new_lr = regular_lr + return new_lr + else: + return regular_lr + + def get_regular_lr(self, runner): + if not self.regular_lr: + self.regular_lr = self.base_lr + if isinstance(runner.optimizer, dict): + lr_groups = {} + for k in runner.optimizer.keys(): + _lr_group = [ + self.get_lr(runner, _regular_lr) + for _regular_lr in self.regular_lr[k] + ] + lr_groups.update({k: _lr_group}) + return lr_groups + else: + return [ + self.get_lr(runner, _regular_lr) + for _regular_lr in self.regular_lr + ] + + def _init_is_better(self, mode): + if mode == 'min': + self.mode_worse = float('inf') + else: + self.mode_worse = float('-inf') + + def _reset(self): + self.best = self.mode_worse + self.cooldown_counter = 0 + self.num_bad_epochs = 0 + + def is_better(self, a, best): + if self.mode == 'min' and self.threshold_mode == 'rel': + rel_epsilon = 1. - self.threshold + return a < best * rel_epsilon + elif self.mode == 'min' and self.threshold_mode == 'abs': + return a < best - self.threshold + elif self.mode == 'max' and self.threshold_mode == 'rel': + rel_epsilon = 1. + self.threshold + return a > best * rel_epsilon + else: + return a > best + self.threshold + + @property + def in_cooldown(self): + return self.cooldown_counter > 0 + + def after_train_epoch(self, runner): + if not self.by_epoch: + return + cur_epoch = runner.epoch + if self.warmup is not None and self.warmup_by_epoch: + if cur_epoch <= self.warmup_epochs: + return + if cur_epoch in self.periods and self.val_metric is None: + current = runner.outputs['loss'] + if self.is_better(current, self.best): + self.best = current + self.num_bad_epochs = 0 + else: + self.num_bad_epochs += 1 + + if self.in_cooldown: + self.cooldown_counter -= 1 + self.num_bad_epochs = 0 + print('epoch--', cur_epoch, ' lr:', self.regular_lr) + + def after_train_iter(self, runner): + if self.by_epoch: + return + cur_iter = runner.iter + if self.warmup_epochs is not None and cur_iter <= self.warmup_iters: + return + if cur_iter in self.periods and self.val_metric is None: + current = runner.outputs['loss'] + if self.is_better(current, self.best): + self.best = current + self.num_bad_epochs = 0 + else: + self.num_bad_epochs += 1 + + if self.in_cooldown: + self.cooldown_counter -= 1 + self.num_bad_epochs = 0 + + def after_val_epoch(self, runner): + if not self.by_epoch: + return + cur_epoch = runner.epoch + if self.warmup is not None and self.warmup_by_epoch: + if cur_epoch <= self.warmup_epochs: + return + if cur_epoch in self.periods and self.val_metric is not None: + current = runner.outputs[self.val_metric] + if self.is_better(current, self.best): + self.best = current + self.num_bad_epochs = 0 + else: + self.num_bad_epochs += 1 + + if self.in_cooldown: + self.cooldown_counter -= 1 + self.num_bad_epochs = 0 + + def after_val_iter(self, runner): + if self.by_epoch: + return + cur_iter = runner.iter + if self.warmup_epochs is not None and cur_iter <= self.warmup_iters: + return + if cur_iter in self.periods and self.val_metric is not None: + current = runner.outputs[self.val_metric] + if self.is_better(current, self.best): + self.best = current + self.num_bad_epochs = 0 + else: + self.num_bad_epochs += 1 + + if self.in_cooldown: + self.cooldown_counter -= 1 + self.num_bad_epochs = 0 diff --git a/tests/test_runner/test_hooks.py b/tests/test_runner/test_hooks.py index 2cc010617b..e084ba5921 100644 --- a/tests/test_runner/test_hooks.py +++ b/tests/test_runner/test_hooks.py @@ -16,6 +16,7 @@ import pytest import torch import torch.nn as nn +import torch.utils.data as Data from torch.nn.init import constant_ from torch.utils.data import DataLoader @@ -26,6 +27,7 @@ from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook, CyclicLrUpdaterHook, OneCycleLrUpdaterHook, + ReduceLrUpdateHook, StepLrUpdaterHook) @@ -1087,3 +1089,194 @@ def after_epoch(): # stages output have order, so here is list instead of set. expected_stages = ['before_run', 'after_train_epoch', 'after_val_epoch'] assert hook.get_triggered_stages() == expected_stages + + +@pytest.mark.parametrize('multi_optimziers', (True, False)) +def test_reduce_lr_update_hook(multi_optimziers): + """Test ReduceLrUpdateHook.""" + with pytest.raises(TypeError): + # periods should be specified + ReduceLrUpdateHook() + + with pytest.raises(AssertionError): + # periods should be list + ReduceLrUpdateHook(periods=1) + + with pytest.raises(AssertionError): + # periods should all be positive + ReduceLrUpdateHook(periods=[1, 2, -2]) + + with pytest.raises(ValueError): + # mode should be either 'min' or 'max' + ReduceLrUpdateHook(periods=[0, 1], mode='sum') + + with pytest.raises(ValueError): + # factor should be < 1.0 + ReduceLrUpdateHook(periods=[0, 1], mode='min', factor=1.0) + + with pytest.raises(ValueError): + # threshold_mode should be 'rel' or 'abs' + ReduceLrUpdateHook( + periods=[0, 1], mode='min', factor=0.1, threshold_mode='sum') + + sys.modules['pavi'] = MagicMock() + x = torch.ones((30, 1)) + y = torch.ones((30, 1)) * 5 + loader = DataLoader(Data.TensorDataset(x, y)) + runner = _build_reduceLR_runner( + runner_type='IterBasedRunner', + multi_optimziers=multi_optimziers, + max_iters=30, + max_epochs=None) + + hook = ReduceLrUpdateHook( + periods=list(range(30)), + mode='min', + factor=0.1, + patience=2, + threshold=1e-4, + threshold_mode='rel', + by_epoch=False, + eps=1e-4) + runner.register_hook(hook) + runner.register_hook_from_cfg(dict(type='IterTimerHook')) + runner.register_hook(IterTimerHook()) + # add pavi hook + hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True) + runner.register_hook(hook) + runner.run([loader], [('train', 1)]) + shutil.rmtree(runner.work_dir) + + assert hasattr(hook, 'writer') + if multi_optimziers: + calls = [ + call( + 'train', { + 'learning_rate/model1': 0.5, + 'learning_rate/model2': 0.01, + 'momentum/model1': 0.9, + 'momentum/model2': 0.95, + }, 1), + call( + 'train', { + 'learning_rate/model1': 0.05, + 'learning_rate/model2': 0.01, + 'momentum/model1': 0.9, + 'momentum/model2': 0.95, + }, 19), + call( + 'train', { + 'learning_rate/model1': 0.005000000000000001, + 'learning_rate/model2': 0.01, + 'momentum/model1': 0.9, + 'momentum/model2': 0.95, + }, 22), + call( + 'train', { + 'learning_rate/model1': 5.0000000000000016e-05, + 'learning_rate/model2': 0.01, + 'momentum/model1': 0.9, + 'momentum/model2': 0.95, + }, 28) + ] + else: + calls = [ + call('train', { + 'learning_rate': 0.5, + 'momentum': 0.9 + }, 1), + call('train', { + 'learning_rate': 0.05, + 'momentum': 0.9 + }, 19), + call('train', { + 'learning_rate': 0.005000000000000001, + 'momentum': 0.9 + }, 22), + call('train', { + 'learning_rate': 5.0000000000000016e-05, + 'momentum': 0.9 + }, 28) + ] + hook.writer.add_scalars.assert_has_calls(calls, any_order=True) + + +def _build_reduceLR_runner_without_hook(runner_type='EpochBasedRunner', + max_epochs=1, + max_iters=None, + multi_optimziers=False): + + class Model(nn.Module): + + def __init__(self): + super().__init__() + self.linear = nn.Linear(1, 1) + self.conv = nn.Conv2d(3, 3, 3) + torch.nn.init.constant_(self.linear.weight, 1) + torch.nn.init.constant_(self.linear.bias, 1) + + def forward(self, x): + return self.linear(x) + + def train_step(self, x, optimizer, **kwargs): + if isinstance(optimizer, dict): + for name, optim in optimizer.items(): + optim.zero_grad() + else: + optimizer.zero_grad() + loss_fn = torch.nn.MSELoss() + pred = self.forward(x[0]) + loss_ = loss_fn(pred, x[1]) + loss_.backward() + if isinstance(optimizer, dict): + for name, optim in optimizer.items(): + optim.step() + else: + optimizer.step() + return dict(loss=loss_) + + def val_step(self, x, optimizer, **kwargs): + loss_fn = torch.nn.MSELoss() + return dict(loss=loss_fn(self.forward(x[0]), x[1])) + + model = Model() + + if multi_optimziers: + optimizer = { + 'model1': + torch.optim.SGD(model.linear.parameters(), lr=0.5, momentum=0.9), + 'model2': + torch.optim.SGD(model.conv.parameters(), lr=0.01, momentum=0.95), + } + else: + optimizer = torch.optim.SGD(model.parameters(), lr=0.5, momentum=0.9) + + tmp_dir = tempfile.mkdtemp() + runner = build_runner( + dict(type=runner_type), + default_args=dict( + model=model, + work_dir=tmp_dir, + optimizer=optimizer, + logger=logging.getLogger(), + max_epochs=max_epochs, + max_iters=max_iters)) + return runner + + +def _build_reduceLR_runner(runner_type='EpochBasedRunner', + max_epochs=1, + max_iters=None, + multi_optimziers=False): + + log_config = dict( + interval=1, hooks=[ + dict(type='TextLoggerHook'), + ]) + + runner = _build_reduceLR_runner_without_hook(runner_type, max_epochs, + max_iters, multi_optimziers) + + runner.register_checkpoint_hook(dict(interval=1)) + runner.register_logger_hooks(log_config) + return runner