Skip to content

Commit

Permalink
add by_epoch=True unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
wangg12 committed Jun 2, 2021
1 parent 2adbb22 commit 2b6bef6
Showing 1 changed file with 108 additions and 50 deletions.
158 changes: 108 additions & 50 deletions tests/test_runner/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,12 +356,17 @@ def test_cosine_runner_hook(multi_optimziers):
hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


@pytest.mark.parametrize('multi_optimziers', (True, False))
def test_flat_cosine_runner_hook(multi_optimziers):
@pytest.mark.parametrize('multi_optimziers, by_epoch', [(False, False),
(True, False),
(False, True),
(True, True)])
def test_flat_cosine_runner_hook(multi_optimziers, by_epoch):
"""xdoctest -m tests/test_hooks.py test_flat_cosine_runner_hook."""
sys.modules['pavi'] = MagicMock()
loader = DataLoader(torch.ones((10, 2)))
runner = _build_demo_runner(multi_optimziers=multi_optimziers)
max_epochs = 10 if by_epoch else 1
runner = _build_demo_runner(
multi_optimziers=multi_optimziers, max_epochs=max_epochs)

with pytest.raises(ValueError):
# start_pct: expected float between 0 and 1
Expand All @@ -370,9 +375,10 @@ def test_flat_cosine_runner_hook(multi_optimziers):
# add LR scheduler
hook_cfg = dict(
type='FlatCosineAnnealingLrUpdaterHook',
by_epoch=False,
by_epoch=by_epoch,
min_lr_ratio=0,
warmup_iters=2,
warmup='linear',
warmup_iters=10 if by_epoch else 2,
warmup_ratio=0.9,
start_pct=0.5)
runner.register_hook_from_cfg(hook_cfg)
Expand All @@ -387,55 +393,107 @@ def test_flat_cosine_runner_hook(multi_optimziers):
# TODO: use a more elegant way to check values
assert hasattr(hook, 'writer')
if multi_optimziers:
calls = [
call(
'train', {
'learning_rate/model1': 0.02,
'learning_rate/model2': 0.01,
'momentum/model1': 0.95,
'momentum/model2': 0.9
if by_epoch:
calls = [
call(
'train', {
'learning_rate/model1': 0.018000000000000002,
'learning_rate/model2': 0.009000000000000001,
'momentum/model1': 0.95,
'momentum/model2': 0.9,
}, 1),
call(
'train', {
'learning_rate/model1': 0.02,
'learning_rate/model2': 0.01,
'momentum/model1': 0.95,
'momentum/model2': 0.9,
}, 11),
call(
'train', {
'learning_rate/model1': 0.018090169943749474,
'learning_rate/model2': 0.009045084971874737,
'momentum/model1': 0.95,
'momentum/model2': 0.9,
}, 61),
call(
'train', {
'learning_rate/model1': 0.0019098300562505265,
'learning_rate/model2': 0.0009549150281252633,
'momentum/model1': 0.95,
'momentum/model2': 0.9,
}, 100)
]
else:
calls = [
call(
'train', {
'learning_rate/model1': 0.018000000000000002,
'learning_rate/model2': 0.009000000000000001,
'momentum/model1': 0.95,
'momentum/model2': 0.9
}, 1),
call(
'train', {
'learning_rate/model1': 0.02,
'learning_rate/model2': 0.01,
'momentum/model1': 0.95,
'momentum/model2': 0.9
}, 6),
call(
'train', {
'learning_rate/model1': 0.018090169943749474,
'learning_rate/model2': 0.009045084971874737,
'momentum/model1': 0.95,
'momentum/model2': 0.9
}, 7),
call(
'train', {
'learning_rate/model1': 0.0019098300562505265,
'learning_rate/model2': 0.0009549150281252633,
'momentum/model1': 0.95,
'momentum/model2': 0.9
}, 10)
]
else:
if by_epoch:
calls = [
call('train', {
'learning_rate': 0.018000000000000002,
'momentum': 0.95
}, 1),
call(
'train', {
'learning_rate/model1': 0.02,
'learning_rate/model2': 0.01,
'momentum/model1': 0.95,
'momentum/model2': 0.9
call('train', {
'learning_rate': 0.02,
'momentum': 0.95
}, 11),
call('train', {
'learning_rate': 0.018090169943749474,
'momentum': 0.95
}, 61),
call('train', {
'learning_rate': 0.0019098300562505265,
'momentum': 0.95
}, 100)
]
else:
calls = [
call('train', {
'learning_rate': 0.018000000000000002,
'momentum': 0.95
}, 1),
call('train', {
'learning_rate': 0.02,
'momentum': 0.95
}, 6),
call(
'train', {
'learning_rate/model1': 0.018090169943749474,
'learning_rate/model2': 0.009045084971874737,
'momentum/model1': 0.95,
'momentum/model2': 0.9
call('train', {
'learning_rate': 0.018090169943749474,
'momentum': 0.95
}, 7),
call(
'train', {
'learning_rate/model1': 0.0019098300562505265,
'learning_rate/model2': 0.0009549150281252633,
'momentum/model1': 0.95,
'momentum/model2': 0.9
call('train', {
'learning_rate': 0.0019098300562505265,
'momentum': 0.95
}, 10)
]
else:
calls = [
call('train', {
'learning_rate': 0.02,
'momentum': 0.95
}, 1),
call('train', {
'learning_rate': 0.02,
'momentum': 0.95
}, 6),
call('train', {
'learning_rate': 0.018090169943749474,
'momentum': 0.95
}, 7),
call('train', {
'learning_rate': 0.0019098300562505265,
'momentum': 0.95
}, 10)
]
]
hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


Expand Down

0 comments on commit 2b6bef6

Please sign in to comment.