Skip to content

Commit

Permalink
Allow bool input for loggers (#897)
Browse files Browse the repository at this point in the history
* Allow bool input for loggers

* Convert earlier on

* Fix test case
  • Loading branch information
ngcgarcia authored Jan 23, 2024
1 parent 36fcb5e commit 4961436
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 13 deletions.
15 changes: 5 additions & 10 deletions llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,21 +219,16 @@ def build_callback(


def build_logger(name: str, kwargs: Dict[str, Any]) -> LoggerDestination:
kwargs_dict = {
k: v if isinstance(v, str) else om.to_container(v, resolve=True)
for k, v in kwargs.items()
}

if name == 'wandb':
return WandBLogger(**kwargs_dict)
return WandBLogger(**kwargs)
elif name == 'tensorboard':
return TensorboardLogger(**kwargs_dict)
return TensorboardLogger(**kwargs)
elif name == 'in_memory_logger':
return InMemoryLogger(**kwargs_dict)
return InMemoryLogger(**kwargs)
elif name == 'mlflow':
return MLFlowLogger(**kwargs_dict)
return MLFlowLogger(**kwargs)
elif name == 'inmemory':
return InMemoryLogger(**kwargs_dict)
return InMemoryLogger(**kwargs)
else:
raise ValueError(f'Not sure how to build logger: {name}')

Expand Down
3 changes: 2 additions & 1 deletion scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,8 @@ def main(cfg: DictConfig) -> Trainer:
logger_configs: Optional[DictConfig] = pop_config(cfg,
'loggers',
must_exist=False,
default_value=None)
default_value=None,
convert=True)
callback_configs: Optional[DictConfig] = pop_config(cfg,
'callbacks',
must_exist=False,
Expand Down
4 changes: 2 additions & 2 deletions tests/utils/test_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,14 @@ def test_build_logger():
with pytest.raises(ValueError):
_ = build_logger('unknown', {})

logger_cfg = DictConfig({
logger_cfg = {
'project': 'foobar',
'init_kwargs': {
'config': {
'foo': 'bar',
}
}
})
}
wandb_logger = build_logger('wandb', logger_cfg) # type: ignore
assert isinstance(wandb_logger, WandBLogger)
assert wandb_logger.project == 'foobar'
Expand Down

0 comments on commit 4961436

Please sign in to comment.