Skip to content

NEW: wandb log to logger #816

Merged
merged 4 commits into from
Jul 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Make native prediction intervals for DeepAR ([#761](https://github.com/tinkoff-ai/etna/pull/761))
- Make native prediction intervals for TFTModel ([#770](https://github.com/tinkoff-ai/etna/pull/770))
- Test cases for testing inference of models ([#794](https://github.com/tinkoff-ai/etna/pull/794))
-
- Wandb.log to WandbLogger ([#816](https://github.com/tinkoff-ai/etna/pull/816))
### Fixed
-
- Fix missing prophet in docker images ([#767](https://github.com/tinkoff-ai/etna/pull/767))
Expand Down
6 changes: 3 additions & 3 deletions etna/loggers/wandb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ def log(self, msg: Union[str, Dict[str, Any]], **kwargs):

Notes
-----
We log nothing via current method in wandb case.
Currently you could call ``wandb.log`` by hand if you need this.
We log dictionary to wandb only.
"""
pass
if isinstance(msg, dict):
martins0n marked this conversation as resolved.
Show resolved Hide resolved
self.experiment.log(msg)

def log_backtest_metrics(
self, ts: "TSDataset", metrics_df: pd.DataFrame, forecast_df: pd.DataFrame, fold_info_df: pd.DataFrame
Expand Down
29 changes: 29 additions & 0 deletions tests/test_loggers/test_wandb_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from unittest.mock import call
from unittest.mock import patch

import pytest

from etna.loggers import WandbLogger
from etna.loggers import tslogger as _tslogger


@pytest.fixture()
def tslogger():
_tslogger.loggers = []
yield _tslogger
_tslogger.loggers = []


@patch("etna.loggers.wandb_logger.wandb")
def test_wandb_logger_log(wandb, tslogger):
wandb_logger = WandbLogger()
tslogger.add(wandb_logger)
tslogger.log("test")
tslogger.log({"MAE": 0})
tslogger.log({"MAPE": 1.5})
calls = [
call({"MAE": 0}),
call({"MAPE": 1.5}),
]
assert wandb.init.return_value.log.call_count == 2
wandb.init.return_value.log.assert_has_calls(calls)