Skip to content

Commit

Permalink
fix unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
wangg12 committed Jun 1, 2021
1 parent 86e767f commit 2adbb22
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions tests/test_runner/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,38 +388,52 @@ def test_flat_cosine_runner_hook(multi_optimziers):
assert hasattr(hook, 'writer')
if multi_optimziers:
calls = [
call('train', {
'learning_rate/model1': 0.02,
'learning_rate/model2': 0.01,
}, 1),
call('train', {
'learning_rate/model1': 0.02,
'learning_rate/model2': 0.01,
}, 6),
call(
'train', {
'learning_rate/model1': 0.02,
'learning_rate/model2': 0.01,
'momentum/model1': 0.95,
'momentum/model2': 0.9
}, 1),
call(
'train', {
'learning_rate/model1': 0.02,
'learning_rate/model2': 0.01,
'momentum/model1': 0.95,
'momentum/model2': 0.9
}, 6),
call(
'train', {
'learning_rate/model1': 0.018090169943749474,
'learning_rate/model2': 0.009045084971874737,
'momentum/model1': 0.95,
'momentum/model2': 0.9
}, 7),
call(
'train', {
'learning_rate/model1': 0.0019098300562505265,
'learning_rate/model2': 0.0009549150281252633,
'momentum/model1': 0.95,
'momentum/model2': 0.9
}, 10)
]
else:
calls = [
call('train', {
'learning_rate': 0.02,
'momentum': 0.95
}, 1),
call('train', {
'learning_rate': 0.02,
'momentum': 0.95
}, 6),
call('train', {
'learning_rate': 0.018090169943749474,
'momentum': 0.95
}, 7),
call('train', {
'learning_rate': 0.0019098300562505265,
'momentum': 0.95
}, 10)
]
hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
Expand Down

0 comments on commit 2adbb22

Please sign in to comment.