Skip to content

Commit

Permalink
[Enhance] Refactor logger (#659)
Browse files Browse the repository at this point in the history
* [Enhance] Refactor logger

* fixed test

* make commit optional

* remove debug info

* fixed test
  • Loading branch information
xvjiarui authored Nov 23, 2020
1 parent dfa36df commit 987cb58
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 13 deletions.
7 changes: 0 additions & 7 deletions mmcv/runner/hooks/logger/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,6 @@ def get_iter(self, runner, inner_iter=False):
current_iter = runner.iter + 1
return current_iter

def get_step(self, runner):
"""Get the total training step/epoch."""
if self.get_mode(runner) == 'val' and self.by_epoch:
return self.get_epoch(runner)
else:
return self.get_iter(runner)

def get_lr_tags(self, runner):
tags = {}
lrs = runner.current_lr()
Expand Down
2 changes: 1 addition & 1 deletion mmcv/runner/hooks/logger/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def before_run(self, runner):
def log(self, runner):
tags = self.get_loggable_tags(runner)
if tags:
self.mlflow.log_metrics(tags, step=self.get_step(runner))
self.mlflow.log_metrics(tags, step=self.get_iter(runner))

@master_only
def after_run(self, runner):
Expand Down
7 changes: 7 additions & 0 deletions mmcv/runner/hooks/logger/pavi.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ def before_run(self, runner):
if self.add_graph:
self.writer.add_graph(runner.model)

def get_step(self, runner):
"""Get the total training step/epoch."""
if self.get_mode(runner) == 'val' and self.by_epoch:
return self.get_epoch(runner)
else:
return self.get_iter(runner)

@master_only
def log(self, runner):
tags = self.get_loggable_tags(runner, add_mode=False)
Expand Down
4 changes: 2 additions & 2 deletions mmcv/runner/hooks/logger/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def log(self, runner):
tags = self.get_loggable_tags(runner, allow_text=True)
for tag, val in tags.items():
if isinstance(val, str):
self.writer.add_text(tag, val, self.get_step(runner))
self.writer.add_text(tag, val, self.get_iter(runner))
else:
self.writer.add_scalar(tag, val, self.get_step(runner))
self.writer.add_scalar(tag, val, self.get_iter(runner))

@master_only
def after_run(self, runner):
Expand Down
5 changes: 4 additions & 1 deletion mmcv/runner/hooks/logger/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ def __init__(self,
interval=10,
ignore_last=True,
reset_flag=True,
commit=True,
by_epoch=True):
super(WandbLoggerHook, self).__init__(interval, ignore_last,
reset_flag, by_epoch)
self.import_wandb()
self.init_kwargs = init_kwargs
self.commit = commit

def import_wandb(self):
try:
Expand All @@ -39,7 +41,8 @@ def before_run(self, runner):
def log(self, runner):
tags = self.get_loggable_tags(runner)
if tags:
self.wandb.log(tags, step=self.get_step(runner))
self.wandb.log(
tags, step=self.get_iter(runner), commit=self.commit)

@master_only
def after_run(self, runner):
Expand Down
5 changes: 3 additions & 2 deletions tests/test_runner/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def test_mlflow_hook(log_model):
{
'learning_rate': 0.02,
'momentum': 0.95
}, step=1)
}, step=6)
if log_model:
hook.mlflow_pytorch.log_model.assert_called_with(
runner.model, 'models')
Expand All @@ -369,7 +369,8 @@ def test_wandb_hook():
'learning_rate': 0.02,
'momentum': 0.95
},
step=1)
step=6,
commit=True)
hook.wandb.join.assert_called_with()


Expand Down

0 comments on commit 987cb58

Please sign in to comment.