-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Support LR_ReduceOnPlateau #1033
Changes from 12 commits
7eb4b0a
05d8f8d
64baa50
1a66977
17c8e0f
7c1fe21
5be9593
d3bbfdb
e9f2a02
4bd3b50
55b4847
9d1436f
13888df
8fcc1ff
a637724
732ff50
6c7d6c3
4d42365
e728608
61c656e
65fa5e9
373f712
8983e7f
8111525
5c754e7
717d157
bf2c9fa
69a4316
19c8017
3188757
50c255b
b028a1f
50537dd
d212bd5
f88c0b9
bdd7022
1076958
a88d1d2
69146fe
11629d5
e05fb56
3d7bcc8
c9b009f
6cb534b
a5d4c65
088fde3
1b59409
004c006
6fd6ada
f71e47c
f7caa80
59ed0dd
9c26a10
303aa7f
49a1d34
560719d
eb08835
d9effbd
6fe3722
7b150fa
94818ad
227e7a7
1b15f02
db097bd
797ef57
76d9bf1
21845db
1d5ee6e
b035fe9
7e285d3
cdcbc03
66cefaf
6c63621
4a9f834
23d7fc8
31b8829
0cbe5f4
db580dd
1728a11
287c1ea
4cdf338
f229de3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -614,3 +614,221 @@ def format_param(name, optim, param): | |
if name not in param: | ||
raise KeyError(f'{name} is not found in {param.keys()}') | ||
return param[name] | ||
|
||
|
||
@HOOKS.register_module()
class ReduceLrUpdateHook(LrUpdaterHook):
    """ReduceLROnPlateau scheduler.

    Reduce learning rate when a metric has stopped improving. This scheduler
    reads a metric quantity and, if no improvement is seen for a ``patience``
    number of epochs (or iterations when ``by_epoch=False``), the learning
    rate is reduced by ``factor``.

    Args:
        periods (list[int]): Epochs/iterations at which the monitored metric
            is taken into account. Must be non-negative ints.
        val_metric (str, optional): Key of the validation metric in
            ``runner.outputs``. If None, the training loss
            (``runner.outputs['loss']``) is monitored instead. Default: None.
        mode (str): One of ``'min'``, ``'max'``. In ``'min'`` mode, lr is
            reduced when the monitored quantity has stopped decreasing; in
            ``'max'`` mode, when it has stopped increasing. Default: 'min'.
        factor (float): Factor by which the learning rate is reduced:
            ``new_lr = lr * factor``. Must be < 1.0. Default: 0.1.
        patience (int): Number of epochs with no improvement after which the
            learning rate is reduced. E.g. with ``patience=2`` the first two
            non-improving epochs are ignored and lr drops on the 3rd.
            Default: 10.
        threshold (float): Threshold for measuring a new optimum, to only
            react to significant changes. Default: 1e-4.
        threshold_mode (str): One of ``'rel'``, ``'abs'``. In ``'rel'`` mode,
            dynamic_threshold = best * (1 + threshold) in 'max' mode or
            best * (1 - threshold) in 'min' mode. In ``'abs'`` mode,
            dynamic_threshold = best + threshold in 'max' mode or
            best - threshold in 'min' mode. Default: 'rel'.
        cooldown (int): Number of epochs to wait before resuming normal
            operation after lr has been reduced. Default: 0.
        min_lr (float): Minimum lr to keep; the decayed lr is clipped to this
            value. Default: 0.
        eps (float): Minimal decay applied to lr. If the difference between
            new and old lr is smaller than eps, the update is ignored.
            Default: 1e-8.
    """

    def __init__(self,
                 periods: list,
                 val_metric: str = None,
                 mode: str = 'min',
                 factor: float = 0.1,
                 patience: int = 10,
                 threshold: float = 1e-4,
                 threshold_mode: str = 'rel',
                 cooldown: int = 0,
                 min_lr: float = 0.,
                 eps: float = 1e-8,
                 **kwargs):
        if not isinstance(periods, list):
            raise TypeError('"periods" must be a list')
        assert mmcv.is_list_of(periods, int)
        assert all(s >= 0 for s in periods)
        self.periods = periods
        self.val_metric = val_metric

        if mode not in ['min', 'max']:
            # f-string: interpolate the offending value into the message
            raise ValueError(
                f'mode must be one of "min" or "max", instead got {mode}')
        self.mode = mode

        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0')
        self.factor = factor

        self.patience = patience
        self.threshold = threshold

        if threshold_mode not in ['rel', 'abs']:
            raise ValueError('thresh_mode must be one of "rel" or "abs", '
                             f'instead got {threshold_mode}')
        self.threshold_mode = threshold_mode

        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.min_lr = min_lr
        self.eps = eps
        self.last_epoch = 0
        self._init_is_better(self.mode)
        self._reset()
        super().__init__(**kwargs)

    def get_lr(self, runner, regular_lr):
        """Return the (possibly reduced) lr for a single param group.

        Pure computation: it does not mutate scheduler state, so it can be
        safely called once per param group / per optimizer. The transition
        into cooldown is handled once in :meth:`get_regular_lr`.
        """
        if self.num_bad_epochs > self.patience:
            # Only apply the decay when it is larger than ``eps``.
            if regular_lr - regular_lr * self.factor > self.eps:
                return max(regular_lr * self.factor, self.min_lr)
        return regular_lr

    def get_regular_lr(self, runner):
        """Compute lrs for all param groups (and optimizers, if multiple)."""
        if not self.regular_lr:
            self.regular_lr = self.base_lr
        if isinstance(runner.optimizer, dict):
            lr_groups = {
                k: [
                    self.get_lr(runner, _regular_lr)
                    for _regular_lr in self.regular_lr[k]
                ]
                for k in runner.optimizer.keys()
            }
        else:
            lr_groups = [
                self.get_lr(runner, _regular_lr)
                for _regular_lr in self.regular_lr
            ]
        # Enter cooldown exactly once, after *all* groups were updated.
        # (Mutating this state inside get_lr would reset num_bad_epochs after
        # the first group and skip the reduction for the remaining ones.)
        if self.num_bad_epochs > self.patience:
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0
        return lr_groups

    def _init_is_better(self, mode):
        """Initialize the sentinel "worst" value for the chosen mode."""
        if mode == 'min':
            self.mode_worse = float('inf')
        else:
            self.mode_worse = float('-inf')

    def _reset(self):
        """Reset tracking state (best value, cooldown, bad-epoch count)."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def is_better(self, a, best):
        """Return True if metric ``a`` improves on ``best`` beyond threshold."""
        if self.mode == 'min' and self.threshold_mode == 'rel':
            rel_epsilon = 1. - self.threshold
            return a < best * rel_epsilon
        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return a < best - self.threshold
        elif self.mode == 'max' and self.threshold_mode == 'rel':
            rel_epsilon = 1. + self.threshold
            return a > best * rel_epsilon
        else:
            return a > best + self.threshold

    @property
    def in_cooldown(self):
        """Whether the scheduler is currently in its post-reduction cooldown."""
        return self.cooldown_counter > 0

    def _record_metric(self, current):
        """Update best/num_bad_epochs/cooldown bookkeeping for one reading.

        Shared by the four ``after_*`` hooks; previously this logic was
        copy-pasted in each of them.
        """
        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore bad epochs during cooldown

    def after_train_epoch(self, runner):
        """Record the training loss at the end of each scheduled epoch."""
        if not self.by_epoch:
            return
        cur_epoch = runner.epoch
        if self.warmup is not None and self.warmup_by_epoch:
            if cur_epoch <= self.warmup_epochs:
                return
        if cur_epoch in self.periods and self.val_metric is None:
            self._record_metric(runner.outputs['loss'])

    def after_train_iter(self, runner):
        """Record the training loss at the end of each scheduled iteration."""
        if self.by_epoch:
            return
        cur_iter = runner.iter
        # NOTE(review): the epoch-based hooks gate on ``self.warmup`` instead
        # of ``self.warmup_epochs`` — confirm which attribute is intended.
        if self.warmup_epochs is not None and cur_iter <= self.warmup_iters:
            return
        if cur_iter in self.periods and self.val_metric is None:
            self._record_metric(runner.outputs['loss'])

    def after_val_epoch(self, runner):
        """Record the validation metric at the end of each scheduled epoch."""
        if not self.by_epoch:
            return
        cur_epoch = runner.epoch
        if self.warmup is not None and self.warmup_by_epoch:
            if cur_epoch <= self.warmup_epochs:
                return
        if cur_epoch in self.periods and self.val_metric is not None:
            self._record_metric(runner.outputs[self.val_metric])

    def after_val_iter(self, runner):
        """Record the validation metric at the end of each scheduled iter."""
        if self.by_epoch:
            return
        cur_iter = runner.iter
        # NOTE(review): see after_train_iter — warmup gating attribute.
        if self.warmup_epochs is not None and cur_iter <= self.warmup_iters:
            return
        if cur_iter in self.periods and self.val_metric is not None:
            self._record_metric(runner.outputs[self.val_metric])
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and the learning rate will be reduced if no improvement is observed for a `patience` number of epochs.