Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Support LR_ReduceOnPlateau #1033

Closed
wants to merge 82 commits into from
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
82 commits
Select commit Hold shift + click to select a range
7eb4b0a
lr_reduce 0519
May 19, 2021
05d8f8d
add lr_reduce 0519
gengenkai May 19, 2021
64baa50
lr_reduce 0519
gengenkai May 19, 2021
1a66977
[fix]: Fix a bug where logs are missing when two or more loggers were…
ritosonn May 20, 2021
17c8e0f
[Feature] lr_reduce
gengenkai May 21, 2021
7c1fe21
[Feature] lr_reduce
gengenkai May 21, 2021
5be9593
[Fix] fix generalized attention fp16 (#1036)
AronLin May 23, 2021
d3bbfdb
fix parrots ci (#1032)
zhouzaida May 23, 2021
e9f2a02
add neptune.ai logger hook (#1025)
fcakyon May 23, 2021
4bd3b50
[Fix] Support amp (pytorch >= 1.6.0) on DCN and DCNv2/ Add unit tests…
AronLin May 23, 2021
55b4847
[Feature] Add truncated normal weight init (#935)
zhouzaida May 23, 2021
9d1436f
[Feature] add cummax/cummin tensorrt plugin (#1031)
May 24, 2021
13888df
Fix typos (#1041)
innerlee May 24, 2021
8fcc1ff
lr_reduce 0525
gengenkai May 25, 2021
a637724
[Fix] Delete commit id (#1045)
zhouzaida May 25, 2021
732ff50
Add ms_deformable_attn in parrots (#1042)
luopeichao May 25, 2021
6c7d6c3
Supports cuda version BorderAlign module (#1021)
v-qjqs May 25, 2021
4d42365
[Feature]: add TensorRT InstanceNormalization plugin (#1034)
RunningLeon May 25, 2021
e728608
Bump version to v1.3.5 (#1050)
zhouzaida May 25, 2021
61c656e
[WIP]test lr_reduce
gengenkai May 26, 2021
65fa5e9
test ReduceLR 0528
gengenkai May 28, 2021
373f712
Delete test_pt.py
gengenkai May 28, 2021
8983e7f
test ReduceLR 0528
gengenkai May 28, 2021
8111525
test ReduceLR 0528
gengenkai May 28, 2021
5c754e7
Merge branch 'lr_reduce' of github.com:gengenkai/mmcv into lr_reduce
gengenkai May 28, 2021
717d157
Add segmentation keys for greater_keys. (#1060)
yinchimaoliang May 31, 2021
bf2c9fa
[Feature] NMS update (#957)
SemyonBevzuk May 31, 2021
69a4316
fix format 0601
gengenkai Jun 1, 2021
19c8017
fix format 0601
gengenkai Jun 1, 2021
3188757
fix format 0601
gengenkai Jun 1, 2021
50c255b
[Feature] Support to use name of the base classes in init_cfg (#1057)
MeowZheng Jun 1, 2021
b028a1f
Fix mmcls link (#1067)
LXXXXR Jun 1, 2021
50537dd
Imporve windows support for list_from_file (#1043)
innerlee Jun 1, 2021
d212bd5
[Fix] Fix the docstring for initializers (#1071)
MeowZheng Jun 1, 2021
f88c0b9
fix docstring 0603
gengenkai Jun 3, 2021
bdd7022
Add DvcliveLoggerHook (#1075)
daavoo Jun 8, 2021
1076958
[Docs] Add runner tutorial (#1082)
LXXXXR Jun 9, 2021
a88d1d2
[Feature] enable exporting to onnx for PointRend (#953)
Jun 11, 2021
69146fe
add load_ext warning (#1089)
luopeichao Jun 11, 2021
11629d5
Support resume for fp16 training (#1013)
ycxioooong Jun 11, 2021
e05fb56
Refactor the baseclass related to transformer (#978)
jshilong Jun 11, 2021
3d7bcc8
add border_align support in parrots (#1086)
luopeichao Jun 11, 2021
c9b009f
fix typo (#1094)
LXXXXR Jun 11, 2021
6cb534b
bump version to v1.3.6 (#1095)
zhouzaida Jun 11, 2021
a5d4c65
add a return value in TextLoggerHook.log() (#1040)
cww97 Jun 14, 2021
088fde3
Avoid bc-breaking of importing `MultiScaleDeformableAttention` (#1100)
jshilong Jun 16, 2021
1b59409
bump version to v1.3.7 (#1103)
zhouzaida Jun 16, 2021
004c006
[Feature]: add modulated deformable conv TensorRT support (#1078)
Jun 16, 2021
6fd6ada
[Fix] Delete warning report (#1126)
MeowZheng Jun 22, 2021
f71e47c
fix typos (#1124)
Junjun2016 Jun 22, 2021
f7caa80
[Enhancement] Add to_ntuple (#1125)
Junjun2016 Jun 23, 2021
59ed0dd
Fix the dimension (#1117)
jaemin93 Jun 23, 2021
9c26a10
empty tensor inference backward compatible (#1131)
dreamerlin Jun 24, 2021
303aa7f
fix parrots cpu compile bug (#1129)
luopeichao Jun 24, 2021
49a1d34
[Enhancement] Support resize or rescale to multiple (#1121)
Junjun2016 Jun 24, 2021
560719d
EvalHook uses case-insensitive key indicator matching and configurabl…
ly015 Jun 24, 2021
eb08835
[Fix] Fix the permission denied error on windows. (#1077)
fjfzlzj Jun 24, 2021
d9effbd
Support variables in base files for configs (#1083)
innerlee Jun 25, 2021
6fe3722
Refine default hooks and custom hooks priority rank. (#1120)
mzr1996 Jun 25, 2021
7b150fa
[Feature] Optimize the PyTorch CUDA implementation for Criss Cross At…
Jun 25, 2021
94818ad
update ca_forward_kernel (#1144)
luopeichao Jun 25, 2021
227e7a7
Support image reading while ignoring EXIF orientation info (#1091)
gaotongxiao Jun 25, 2021
1b15f02
support print hooks before running. (#1123)
mzr1996 Jun 25, 2021
db097bd
bump version to v1.3.8 (#1148)
zhouzaida Jun 25, 2021
797ef57
[Fix] Fix SyncBN build in PyTorch 1.9 (#1138)
xvjiarui Jun 27, 2021
76d9bf1
[Docs] Fix error when cv2 is mocked (#1152)
zhouzaida Jun 28, 2021
21845db
[Fix]: fix missing check of directory in scandir (#1110)
achaiah Jun 29, 2021
1d5ee6e
use LooseVersion for version checking (#1158)
kennymckormick Jun 29, 2021
b035fe9
Change dict update order (#1108)
antoniolanza1996 Jun 29, 2021
7e285d3
[Fix] Fix saconv (#1147)
kennymckormick Jul 2, 2021
cdcbc03
add citation in readme.md (#1160)
zhouzaida Jul 2, 2021
66cefaf
[Docs] Refactor docs (#1102)
zhouzaida Jul 2, 2021
6c63621
[Docs] Fix typo (#1166)
zhouzaida Jul 3, 2021
4a9f834
[Fix] Fix unittest in pt1.9 (#1146)
zhouzaida Jul 3, 2021
23d7fc8
add CI for pt1.8 pt1.9 (#1141)
zhouzaida Jul 6, 2021
31b8829
Add pretrained mmdet/mobilenet_v2 (#1177)
RangiLyu Jul 7, 2021
0cbe5f4
[Fix] Fix sphinx compile error (#1176)
zhouzaida Jul 7, 2021
db580dd
[Docs] Build Chinese docs (#1073)
zhouzaida Jul 7, 2021
1728a11
fix bugs in lr_reduce test 0707
gengenkai Jul 7, 2021
287c1ea
fix test 0708
gengenkai Jul 8, 2021
4cdf338
fix test 0708
gengenkai Jul 8, 2021
f229de3
Merge branch 'lr_reduce' of https://github.com/gengenkai/mmcv into lr…
gengenkai Jul 8, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 181 additions & 0 deletions mmcv/runner/hooks/lr_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,3 +614,184 @@ def format_param(name, optim, param):
if name not in param:
raise KeyError(f'{name} is not found in {param.keys()}')
return param[name]


@HOOKS.register_module()
class ReduceLrUpdateHook(LrUpdaterHook):

zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self,
periods,
Copy link
Collaborator

@zhouzaida zhouzaida May 27, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add type hint for arguments

val_metric=None,
mode='min',
factor=0.1,
patience=10,
threshold=1e-4,
threshold_mode='rel',
cooldown=0,
min_lr=0.,
eps=1e-8,
**kwargs):
if isinstance(periods, list):
assert mmcv.is_list_of(periods, int)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should add arguments like assert expression [, arguments]

assert all([s > 0 for s in periods])
else:
raise TypeError('"periods" must be a list')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line #L676 was not covered by tests.

self.periods = periods
self.val_metric = val_metric
if mode not in ['min', 'max']:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is readable to add a blank line between different logical code

raise ValueError(
'mode must be one of "min" or "max", instead got {mode}')
self.mode = mode
if factor >= 1.0:
raise ValueError('Factor should be < 1.0')
self.factor = factor
self.patience = patience
self.threshold = threshold
if threshold_mode not in ['rel', 'abs']:
raise ValueError('thresh_mode must be one of "rel" or "abs",\
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line #L689 was not covered by tests

instead got {threshold_mode}')
self.threshold_mode = threshold_mode
self.cooldown = cooldown
self.cooldown_counter = 0
self.best = None
self.num_bad_epochs = None
self.mode_worse = None # the worse value for the chosen mode
self.min_lr = min_lr
self.eps = eps
self.last_epoch = 0
self._init_is_better(self.mode)
self._reset()
super(ReduceLrUpdateHook, self).__init__(**kwargs)

def get_lr(self, runner, regular_lr):
if self.num_bad_epochs > self.patience:
self.cooldown_counter = self.cooldown
self.num_bad_epochs = 0
if regular_lr - regular_lr * self.factor > self.eps:
new_lr = max(regular_lr * self.factor, self.min_lr)
else:
new_lr = regular_lr
return new_lr
else:
return regular_lr

def get_regular_lr(self, runner):
if isinstance(runner.optimizer, dict):
lr_groups = {}
for k in runner.optimizer.keys():
_lr_group = [
self.get_lr(runner, _regular_lr)
for _regular_lr in self.regular_lr[k]
]
lr_groups.update({k: _lr_group})
# self.regular_lr.update({k: _lr_group})
return lr_groups
else:
return [
self.get_lr(runner, _regular_lr)
for _regular_lr in self.regular_lr
]

def _init_is_better(self, mode):
if mode == 'min':
self.mode_worse = float('inf')
else:
self.mode_worse = float('-inf')

def _reset(self):
self.best = self.mode_worse
self.cooldown_counter = 0
self.num_bad_epochs = 0

def is_better(self, a, best):
if self.mode == 'min' and self.threshold_mode == 'rel':
rel_epsilon = 1. - self.threshold
return a < best * rel_epsilon

elif self.mode == 'min' and self.threshold_mode == 'abs':
return a < best - self.threshold

elif self.mode == 'max' and self.threshold_mode == 'rel':
rel_epsilon = 1. + self.threshold
return a > best * rel_epsilon

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove the blank line

else:
return a > best + self.threshold

@property
def in_cooldown(self):
return self.cooldown_counter > 0

def after_train_epoch(self, runner):
if not self.by_epoch:
return
cur_epoch = runner.epoch
if self.warmup is not None and self.warmup_by_epoch:
if cur_epoch <= self.warmup_epochs:
return
if cur_epoch in self.periods and self.val_metric is None:
current = runner.outputs.loss
if self.is_better(current, self.best):
self.best = current
self.num_bad_epochs = 0
else:
self.num_bad_epochs += 1

if self.in_cooldown:
self.cooldown_counter -= 1
self.num_bad_epochs = 0

def after_train_iter(self, runner):
if self.by_epoch:
return
cur_iter = runner.iter
if self.warmup_epochs is not None and cur_iter <= self.warmup_iters:
return
if cur_iter in self.periods and self.val_metric is None:
current = runner.outputs.loss
if self.is_better(current, self.best):
self.best = current
self.num_bad_epochs = 0
else:
self.num_bad_epochs += 1

if self.in_cooldown:
self.cooldown_counter -= 1
self.num_bad_epochs = 0

def after_val_epoch(self, runner):
if not self.by_epoch:
return
cur_epoch = runner.epoch
if self.warmup is not None and self.warmup_by_epoch:
if cur_epoch <= self.warmup_epochs:
return
if cur_epoch in self.periods and self.val_metric is not None:
current = runner.outputs[self.val_metric]
if self.is_better(current, self.best):
self.best = current
self.num_bad_epochs = 0
else:
self.num_bad_epochs += 1

if self.in_cooldown:
self.cooldown_counter -= 1
self.num_bad_epochs = 0

def after_val_iter(self, runner):
if self.by_epoch:
return
cur_iter = runner.iter
if self.warmup_epochs is not None and cur_iter <= self.warmup_iters:
return
if cur_iter in self.periods and self.val_metric is not None:
current = runner.outputs[self.val_metric]
if self.is_better(current, self.best):
self.best = current
self.num_bad_epochs = 0
else:
self.num_bad_epochs += 1

if self.in_cooldown:
self.cooldown_counter -= 1
self.num_bad_epochs = 0
48 changes: 48 additions & 0 deletions tests/test_runner/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import torch
import torch.nn as nn
from torch.nn.init import constant_
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

from mmcv.runner import (CheckpointHook, EMAHook, IterTimerHook,
Expand All @@ -25,6 +26,7 @@
from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook,
CyclicLrUpdaterHook,
OneCycleLrUpdaterHook,
ReduceLrUpdateHook,
StepLrUpdaterHook)


Expand Down Expand Up @@ -869,6 +871,52 @@ def test_cyclic_lr_update_hook(multi_optimizers, max_iters):
hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


@pytest.mark.parametrize('multi_optimziers', (True, False))
def test_reduce_lr_update_hook(multi_optimziers):
"""Test ReduceLrUpdateHook."""
with pytest.raises(TypeError):
# periods should be specified
ReduceLrUpdateHook()

with pytest.raises(AssertionError):
# periods should all be positive
ReduceLrUpdateHook(periods=[1, 2, -2])

with pytest.raises(ValueError):
# mode should be either 'min' or 'max'
ReduceLrUpdateHook(periods=[0, 1], mode='sum')

with pytest.raises(ValueError):
# factor should be < 1.0
ReduceLrUpdateHook(periods=[0, 1], mode='min', factor=1.0)

with pytest.raises(ValueError):
# threshold_mode should be 'rel' or 'abs'
ReduceLrUpdateHook(
periods=[0, 1], mode='min', factor=1.0, threshold_mode='sum')

sys.modules['pavi'] = MagicMock()
loader = DataLoader(torch.ones((10, 2)))
runner = _build_demo_runner(multi_optimziers=multi_optimziers)

hook = ReduceLROnPlateau(
periods=list(range(20)), mode='min', factor=0.1, patience=2)
runner.register_hook(hook)
runner.register_hook_from_cfg(dict(type='IterTimerHook'))
runner.register_hook(IterTimerHook())
# add pavi hook
hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
runner.register_hook(hook)
runner.run([loader], [('train', 1)])
shutil.rmtree(runner.work_dir)

assert hasattr(hook, 'writer')
if multi_optimziers:
pass
else:
pass


@pytest.mark.parametrize('log_model', (True, False))
def test_mlflow_hook(log_model):
sys.modules['mlflow'] = MagicMock()
Expand Down