Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add feature : config file dump to wandb server #1471

Closed
wants to merge 12 commits into from
45 changes: 43 additions & 2 deletions mmcv/runner/hooks/logger/wandb.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import shutil
from distutils.dir_util import copy_tree

from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
Expand All @@ -14,13 +18,28 @@ def __init__(self,
reset_flag=False,
commit=True,
by_epoch=True,
with_step=True):
with_step=True,
config_path=None):
hyun06000 marked this conversation as resolved.
Show resolved Hide resolved

super(WandbLoggerHook, self).__init__(interval, ignore_last,
reset_flag, by_epoch)

"""
Args:
with_step (bool): whether making a log in each step or not.
Default: True.
config_path (str, optional): The path of the final config of each
experiment. It can be either a path of the final config file or
a directory of config files. The config is uploaded to wandb
server if it is not None. Default: None.
`New in version 1.3.18.`
"""
hyun06000 marked this conversation as resolved.
Show resolved Hide resolved

self.import_wandb()
self.init_kwargs = init_kwargs
self.commit = commit
self.with_step = with_step
self.config_path = config_path

def import_wandb(self):
try:
Expand All @@ -39,7 +58,29 @@ def before_run(self, runner):
self.wandb.init(**self.init_kwargs)
else:
self.wandb.init()


if self.config_path is not None:
if os.path.isdir(self.config_path):
copy_tree(self.config_path, self.wandb.run.dir)
hyun06000 marked this conversation as resolved.
Show resolved Hide resolved
for path_under_wandb, _, _ in os.walk(self.wandb.run.dir):
self.wandb.save(
glob_str=os.path.join(path_under_wandb,'*'),
base_path=self.wandb.run.dir,
policy='now'
)
else:
if os.path.isfile(self.config_path):
shutil.copy2(self.config_path, self.wandb.run.dir)
hyun06000 marked this conversation as resolved.
Show resolved Hide resolved
self.wandb.save(
glob_str=os.path.join(self.wandb.run.dir,'*'),
base_path=self.wandb.run.dir,
policy='now'
)
else:
raise FileNotFoundError(
"No such file or directory: " + self.config_path
)
hyun06000 marked this conversation as resolved.
Show resolved Hide resolved

@master_only
def log(self, runner):
tags = self.get_loggable_tags(runner)
Expand Down
12 changes: 10 additions & 2 deletions tests/test_runner/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,9 +1118,17 @@ def test_wandb_hook():

runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)])
shutil.rmtree(runner.work_dir)


hook.wandb.init.assert_called_with()

hook.wandb.run.dir.return_value = runner.work_dir
hook.wandb.save.assert_called_with(
glob_str=hook.wandb.run.dir + '/*',
base_path=hook.wandb.run.dir,
policy='now'
)
shutil.rmtree(runner.work_dir)

hook.wandb.log.assert_called_with({
'learning_rate': 0.02,
'momentum': 0.95
Expand Down