Skip to content

Commit

Permalink
Add DvcliveLoggerHook (#1075)
Browse files Browse the repository at this point in the history
* Add dvclive logger hook

* Move docstring to class

* docstring updates
  • Loading branch information
daavoo authored Jun 8, 2021
1 parent d212bd5 commit bdd7022
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 19 deletions.
20 changes: 10 additions & 10 deletions mmcv/runner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from .epoch_based_runner import EpochBasedRunner, Runner
from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model
from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook,
DistSamplerSeedHook, EMAHook, EvalHook, Fp16OptimizerHook,
Hook, IterTimerHook, LoggerHook, LrUpdaterHook,
MlflowLoggerHook, NeptuneLoggerHook, OptimizerHook,
PaviLoggerHook, SyncBuffersHook, TensorboardLoggerHook,
TextLoggerHook, WandbLoggerHook)
DistSamplerSeedHook, DvcliveLoggerHook, EMAHook, EvalHook,
Fp16OptimizerHook, Hook, IterTimerHook, LoggerHook,
LrUpdaterHook, MlflowLoggerHook, NeptuneLoggerHook,
OptimizerHook, PaviLoggerHook, SyncBuffersHook,
TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook)
from .iter_based_runner import IterBasedRunner, IterLoader
from .log_buffer import LogBuffer
from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
Expand All @@ -29,11 +29,11 @@
'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook',
'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
'NeptuneLoggerHook', 'WandbLoggerHook', 'MlflowLoggerHook',
'_load_checkpoint', 'load_state_dict', 'load_checkpoint', 'weights_to_cpu',
'save_checkpoint', 'Priority', 'get_priority', 'get_host_info',
'get_time_str', 'obj_from_dict', 'init_dist', 'get_dist_info',
'master_only', 'OPTIMIZER_BUILDERS', 'OPTIMIZERS',
'DefaultOptimizerConstructor', 'build_optimizer',
'DvcliveLoggerHook', '_load_checkpoint', 'load_state_dict',
'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority',
'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict',
'init_dist', 'get_dist_info', 'master_only', 'OPTIMIZER_BUILDERS',
'OPTIMIZERS', 'DefaultOptimizerConstructor', 'build_optimizer',
'build_optimizer_constructor', 'IterLoader', 'set_random_seed',
'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'Fp16OptimizerHook',
'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads',
Expand Down
11 changes: 6 additions & 5 deletions mmcv/runner/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from .evaluation import DistEvalHook, EvalHook
from .hook import HOOKS, Hook
from .iter_timer import IterTimerHook
from .logger import (LoggerHook, MlflowLoggerHook, NeptuneLoggerHook,
PaviLoggerHook, TensorboardLoggerHook, TextLoggerHook,
WandbLoggerHook)
from .logger import (DvcliveLoggerHook, LoggerHook, MlflowLoggerHook,
NeptuneLoggerHook, PaviLoggerHook, TensorboardLoggerHook,
TextLoggerHook, WandbLoggerHook)
from .lr_updater import LrUpdaterHook
from .memory import EmptyCacheHook
from .momentum_updater import MomentumUpdaterHook
Expand All @@ -21,6 +21,7 @@
'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
'NeptuneLoggerHook', 'WandbLoggerHook', 'MomentumUpdaterHook',
'SyncBuffersHook', 'EMAHook', 'EvalHook', 'DistEvalHook', 'ProfilerHook'
'NeptuneLoggerHook', 'WandbLoggerHook', 'DvcliveLoggerHook',
'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', 'EvalHook',
'DistEvalHook', 'ProfilerHook'
]
3 changes: 2 additions & 1 deletion mmcv/runner/hooks/logger/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Open-MMLab. All rights reserved.
from .base import LoggerHook
from .dvclive import DvcliveLoggerHook
from .mlflow import MlflowLoggerHook
from .neptune import NeptuneLoggerHook
from .pavi import PaviLoggerHook
Expand All @@ -10,5 +11,5 @@
__all__ = [
'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook',
'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook',
'NeptuneLoggerHook'
'NeptuneLoggerHook', 'DvcliveLoggerHook'
]
58 changes: 58 additions & 0 deletions mmcv/runner/hooks/logger/dvclive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) Open-MMLab. All rights reserved.
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook


@HOOKS.register_module()
class DvcliveLoggerHook(LoggerHook):
"""Class to log metrics with dvclive.
It requires `dvclive`_ to be installed.
Args:
path (str): Directory where dvclive will write TSV log files.
interval (int): Logging interval (every k iterations).
Default 10.
ignore_last (bool): Ignore the log of last iterations in each epoch
if less than `interval`.
Default: True.
reset_flag (bool): Whether to clear the output buffer after logging.
Default: True.
by_epoch (bool): Whether EpochBasedRunner is used.
Default: True.
.. _dvclive:
https://dvc.org/doc/dvclive
"""

def __init__(self,
path,
interval=10,
ignore_last=True,
reset_flag=True,
by_epoch=True):

super(DvcliveLoggerHook, self).__init__(interval, ignore_last,
reset_flag, by_epoch)
self.path = path
self.import_dvclive()

def import_dvclive(self):
try:
import dvclive
except ImportError:
raise ImportError(
'Please run "pip install dvclive" to install dvclive')
self.dvclive = dvclive

@master_only
def before_run(self, runner):
self.dvclive.init(self.path)

@master_only
def log(self, runner):
tags = self.get_loggable_tags(runner)
if tags:
for k, v in tags.items():
self.dvclive.log(k, v, step=self.get_iter(runner))
24 changes: 21 additions & 3 deletions tests/test_runner/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from torch.nn.init import constant_
from torch.utils.data import DataLoader

from mmcv.runner import (CheckpointHook, EMAHook, IterTimerHook,
MlflowLoggerHook, NeptuneLoggerHook, PaviLoggerHook,
WandbLoggerHook, build_runner)
from mmcv.runner import (CheckpointHook, DvcliveLoggerHook, EMAHook,
IterTimerHook, MlflowLoggerHook, NeptuneLoggerHook,
PaviLoggerHook, WandbLoggerHook, build_runner)
from mmcv.runner.hooks.hook import HOOKS, Hook
from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook,
CyclicLrUpdaterHook,
Expand Down Expand Up @@ -920,6 +920,7 @@ def test_neptune_hook():
sys.modules['neptune.new'] = MagicMock()
runner = _build_demo_runner()
hook = NeptuneLoggerHook()

loader = DataLoader(torch.ones((5, 2)))

runner.register_hook(hook)
Expand All @@ -931,6 +932,23 @@ def test_neptune_hook():
hook.run.stop.assert_called_with()


def test_dvclive_hook(tmp_path):
sys.modules['dvclive'] = MagicMock()
runner = _build_demo_runner()

(tmp_path / 'dvclive').mkdir()
hook = DvcliveLoggerHook(str(tmp_path / 'dvclive'))
loader = DataLoader(torch.ones((5, 2)))

runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)])
shutil.rmtree(runner.work_dir)

hook.dvclive.init.assert_called_with(str(tmp_path / 'dvclive'))
hook.dvclive.log.assert_called_with('momentum', 0.95, step=6)
hook.dvclive.log.assert_any_call('learning_rate', 0.02, step=6)


def _build_demo_runner_without_hook(runner_type='EpochBasedRunner',
max_epochs=1,
max_iters=None,
Expand Down

0 comments on commit bdd7022

Please sign in to comment.