Skip to content

NEW: wandb log to logger #816

Merged
merged 4 commits into from
Jul 25, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Make native prediction intervals for DeepAR ([#761](https://github.com/tinkoff-ai/etna/pull/761))
- Make native prediction intervals for TFTModel ([#770](https://github.com/tinkoff-ai/etna/pull/770))
- Test cases for testing inference of models ([#794](https://github.com/tinkoff-ai/etna/pull/794))
-
- Wandb.log to WandbLogger ([#816](https://github.com/tinkoff-ai/etna/pull/816))
### Fixed
-
- Fix missing prophet in docker images ([#767](https://github.com/tinkoff-ai/etna/pull/767))
Expand Down
6 changes: 3 additions & 3 deletions etna/loggers/wandb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ def log(self, msg: Union[str, Dict[str, Any]], **kwargs):

Notes
-----
We log nothing via current method in wandb case.
Currently you could call ``wandb.log`` by hand if you need this.
We log dictionary to wandb only.
"""
pass
if isinstance(msg, dict):
martins0n marked this conversation as resolved.
Show resolved Hide resolved
self.experiment.log(msg)

def log_backtest_metrics(
self, ts: "TSDataset", metrics_df: pd.DataFrame, forecast_df: pd.DataFrame, fold_info_df: pd.DataFrame
Expand Down
27 changes: 27 additions & 0 deletions tests/test_loggers/test_wandb_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from unittest.mock import call
from unittest.mock import patch

import pytest

from etna.loggers import WandbLogger
from etna.loggers import tslogger as _tslogger


@pytest.fixture()
def tslogger():
_tslogger.loggers = []
return _tslogger


@patch("etna.loggers.wandb_logger.wandb")
def test_wandb_logger_log(wandb, tslogger):
wandb_logger = WandbLogger()
tslogger.add(wandb_logger)
tslogger.log("test")
tslogger.log({"MAE": 0})
tslogger.log({"MAPE": 1.5})
calls = [
call({"MAE": 0}),
call({"MAPE": 1.5}),
]
wandb.init.return_value.log.assert_has_calls(calls)