
Commit

Merge branch 'lr_reduce' of https://github.com/gengenkai/mmcv into lr_reduce
gengenkai committed Jul 8, 2021
2 parents db580dd + 4cdf338 commit f229de3
Showing 2 changed files with 414 additions and 0 deletions.
221 changes: 221 additions & 0 deletions mmcv/runner/hooks/lr_updater.py
@@ -1,6 +1,7 @@
# Copyright (c) Open-MMLab. All rights reserved.
import numbers
from math import cos, pi
from typing import Optional

import mmcv
from .hook import HOOKS, Hook
@@ -614,3 +615,223 @@ def format_param(name, optim, param):
if name not in param:
raise KeyError(f'{name} is not found in {param.keys()}')
return param[name]


@HOOKS.register_module()
class ReduceLrUpdateHook(LrUpdaterHook):
"""ReduceLROnPlateau Scheduler.
Reduce learning rate when a metric has stopped improving. This scheduler
reads a metrics quantity and if no improvement is seen for a 'patience'
number of epochs, the learning rate is reduced.
Args:
periods (list[int]): Epochs (or iterations) at which the monitored metric value is taken into account.
val_metric (str, optional): Metrics to be evaluated. If val_metric is
None, the metrics will be loss value. Default: None.
mode (str, optional): One of `min`, `max`. In `min` mode, lr will
be reduced when the quantity monitored has stopped
decreasing; in `max` mode it will be reduced when the
quantity monitored has stopped increasing. Default: 'min'.
factor (float, optional): Factor by which the learning rate will be
reduced. new_lr = lr * factor. Default: 0.1.
patience (int, optional): Number of epochs with no improvement after
which learning rate will be reduced. For example, if
`patience = 2`, then we will ignore the first 2 epochs
with no improvement, and will only decrease the LR after the
3rd epoch if the loss still hasn't improved then.
Default: 10.
threshold (float, optional): Threshold for measuring the new optimum,
to only focus on significant changes. Default: 1e-4.
threshold_mode (str, optional): One of `rel`, `abs`. In `rel` mode,
dynamic_threshold = best * (1 + threshold) in `max`
mode or best * (1 - threshold) in `min` mode.
In `abs` mode, dynamic_threshold = best + threshold in
`max` mode or best - threshold in `min` mode. Default: 'rel'.
cooldown (int, optional): Number of epochs to wait before resuming
normal operation after lr has been reduced. Default: 0.
min_lr (float, optional): Minimum LR value to keep. If LR after decay
is lower than `min_lr`, it will be clipped to this value.
Default: 0.
eps (float, optional): Minimal decay applied to lr. If the difference
between new and old lr is smaller than eps, the update is
ignored. Default: 1e-8.
"""

def __init__(self,
periods: list,
val_metric: Optional[str] = None,
mode: str = 'min',
factor: float = 0.1,
patience: int = 10,
threshold: float = 1e-4,
threshold_mode: str = 'rel',
cooldown: int = 0,
min_lr: float = 0.,
eps: float = 1e-8,
**kwargs):
assert isinstance(periods, list), '"periods" must be a list'
assert mmcv.is_list_of(periods, int) and all([s >= 0 for s in periods])
self.periods = periods
self.val_metric = val_metric

if mode not in ['min', 'max']:
raise ValueError(
f'mode must be one of "min" or "max", instead got {mode}')
self.mode = mode

if factor >= 1.0:
raise ValueError('Factor should be < 1.0')
self.factor = factor

self.patience = patience
self.threshold = threshold

if threshold_mode not in ['rel', 'abs']:
raise ValueError(f'threshold_mode must be one of "rel" or "abs", '
f'instead got {threshold_mode}')
self.threshold_mode = threshold_mode

self.cooldown = cooldown
self.cooldown_counter = 0
self.best = None
self.num_bad_epochs = None
self.mode_worse = None # the worse value for the chosen mode
self.min_lr = min_lr
self.eps = eps
self.last_epoch = 0
self._init_is_better(self.mode)
self._reset()
super(ReduceLrUpdateHook, self).__init__(**kwargs)

def get_lr(self, runner, regular_lr):
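# If the monitored quantity has failed to improve for more than
# `patience` checks, start a cooldown period and scale the lr by
# `factor`, clipping at `min_lr` and skipping changes smaller than `eps`.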
if self.num_bad_epochs > self.patience:
self.cooldown_counter = self.cooldown
self.num_bad_epochs = 0
if regular_lr - regular_lr * self.factor > self.eps:
new_lr = max(regular_lr * self.factor, self.min_lr)
else:
new_lr = regular_lr
return new_lr
else:
return regular_lr

def get_regular_lr(self, runner):
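# Recompute the lr for every parameter group, supporting both a single
# optimizer and a dict of optimizers.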
if not self.regular_lr:
self.regular_lr = self.base_lr
if isinstance(runner.optimizer, dict):
lr_groups = {}
for k in runner.optimizer.keys():
_lr_group = [
self.get_lr(runner, _regular_lr)
for _regular_lr in self.regular_lr[k]
]
lr_groups.update({k: _lr_group})
return lr_groups
else:
return [
self.get_lr(runner, _regular_lr)
for _regular_lr in self.regular_lr
]

def _init_is_better(self, mode):
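# Initialise `mode_worse` to a sentinel that any real observation beats,
# so the first recorded metric always becomes the new best.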
if mode == 'min':
self.mode_worse = float('inf')
else:
self.mode_worse = float('-inf')

def _reset(self):
self.best = self.mode_worse
self.cooldown_counter = 0
self.num_bad_epochs = 0

def is_better(self, a, best):
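# Decide whether value `a` is a significant improvement over `best`:
# in 'rel' mode the threshold scales with `best`, in 'abs' mode it is an
# additive margin; 'min'/'max' choose the direction of improvement.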
if self.mode == 'min' and self.threshold_mode == 'rel':
rel_epsilon = 1. - self.threshold
return a < best * rel_epsilon
elif self.mode == 'min' and self.threshold_mode == 'abs':
return a < best - self.threshold
elif self.mode == 'max' and self.threshold_mode == 'rel':
rel_epsilon = 1. + self.threshold
return a > best * rel_epsilon
else:
return a > best + self.threshold

@property
def in_cooldown(self):
return self.cooldown_counter > 0

def after_train_epoch(self, runner):
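# Epoch-based bookkeeping on the training loss: at the epochs listed in
# `periods`, compare `runner.outputs['loss']` with the best value so far
# and count consecutive checks without significant improvement; epochs
# spent in cooldown reset that counter. The iteration/validation hooks
# below follow the same pattern.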
if not self.by_epoch:
return
cur_epoch = runner.epoch
if self.warmup is not None and self.warmup_by_epoch:
if cur_epoch <= self.warmup_epochs:
return
if cur_epoch in self.periods and self.val_metric is None:
current = runner.outputs['loss']
if self.is_better(current, self.best):
self.best = current
self.num_bad_epochs = 0
else:
self.num_bad_epochs += 1

if self.in_cooldown:
self.cooldown_counter -= 1
self.num_bad_epochs = 0

def after_train_iter(self, runner):
if self.by_epoch:
return
cur_iter = runner.iter
if self.warmup_epochs is not None and cur_iter <= self.warmup_iters:
return
if cur_iter in self.periods and self.val_metric is None:
current = runner.outputs['loss']
if self.is_better(current, self.best):
self.best = current
self.num_bad_epochs = 0
else:
self.num_bad_epochs += 1

if self.in_cooldown:
self.cooldown_counter -= 1
self.num_bad_epochs = 0

def after_val_epoch(self, runner):
if not self.by_epoch:
return
cur_epoch = runner.epoch
if self.warmup is not None and self.warmup_by_epoch:
if cur_epoch <= self.warmup_epochs:
return
if cur_epoch in self.periods and self.val_metric is not None:
current = runner.outputs[self.val_metric]
if self.is_better(current, self.best):
self.best = current
self.num_bad_epochs = 0
else:
self.num_bad_epochs += 1

if self.in_cooldown:
self.cooldown_counter -= 1
self.num_bad_epochs = 0

def after_val_iter(self, runner):
if self.by_epoch:
return
cur_iter = runner.iter
if self.warmup_epochs is not None and cur_iter <= self.warmup_iters:
return
if cur_iter in self.periods and self.val_metric is not None:
current = runner.outputs[self.val_metric]
if self.is_better(current, self.best):
self.best = current
self.num_bad_epochs = 0
else:
self.num_bad_epochs += 1

if self.in_cooldown:
self.cooldown_counter -= 1
self.num_bad_epochs = 0
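
Below is a minimal, hypothetical usage sketch of the new hook, not part of this commit: the import path and the pre-built `runner` object are assumptions based on the file path above and the standard mmcv runner API, and whether the hook is also reachable through the usual `lr_config`/`policy` string depends on code outside this diff.

# Hypothetical usage sketch (not part of this commit); the import path and
# the pre-built `runner` are assumptions.
from mmcv.runner.hooks.lr_updater import ReduceLrUpdateHook

reduce_lr_hook = ReduceLrUpdateHook(
    periods=[10, 20, 30, 40],  # epochs at which the monitored value is checked
    val_metric=None,           # None -> monitor runner.outputs['loss']
    mode='min',                # reduce lr when the loss stops decreasing
    factor=0.1,                # new_lr = lr * factor
    patience=5,                # tolerate five checks without improvement
    min_lr=1e-6,
    by_epoch=True)             # consumed by the LrUpdaterHook base class

runner.register_hook(reduce_lr_hook)  # `runner` is an existing EpochBasedRunner

Because the class overrides get_regular_lr, the reduced rate is applied through the base LrUpdaterHook's usual before-epoch/before-iteration logic on subsequent steps.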
