diff --git a/CHANGELOG.md b/CHANGELOG.md index 242200f85..e0db9265c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,7 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Make native prediction intervals for DeepAR ([#761](https://github.com/tinkoff-ai/etna/pull/761)) - Make native prediction intervals for TFTModel ([#770](https://github.com/tinkoff-ai/etna/pull/770)) - Test cases for testing inference of models ([#794](https://github.com/tinkoff-ai/etna/pull/794)) -- +- Wandb.log to WandbLogger ([#816](https://github.com/tinkoff-ai/etna/pull/816)) ### Fixed - - Fix missing prophet in docker images ([#767](https://github.com/tinkoff-ai/etna/pull/767)) diff --git a/etna/loggers/wandb_logger.py b/etna/loggers/wandb_logger.py index 42ee464db..e2a2476fc 100644 --- a/etna/loggers/wandb_logger.py +++ b/etna/loggers/wandb_logger.py @@ -101,10 +101,10 @@ def log(self, msg: Union[str, Dict[str, Any]], **kwargs): Notes ----- - We log nothing via current method in wandb case. - Currently you could call ``wandb.log`` by hand if you need this. + We log dictionary to wandb only. """ - pass + if isinstance(msg, dict): + self.experiment.log(msg) def log_backtest_metrics( self, ts: "TSDataset", metrics_df: pd.DataFrame, forecast_df: pd.DataFrame, fold_info_df: pd.DataFrame diff --git a/tests/test_loggers/test_wandb_logger.py b/tests/test_loggers/test_wandb_logger.py new file mode 100644 index 000000000..3b8da2dac --- /dev/null +++ b/tests/test_loggers/test_wandb_logger.py @@ -0,0 +1,29 @@ +from unittest.mock import call +from unittest.mock import patch + +import pytest + +from etna.loggers import WandbLogger +from etna.loggers import tslogger as _tslogger + + +@pytest.fixture() +def tslogger(): + _tslogger.loggers = [] + yield _tslogger + _tslogger.loggers = [] + + +@patch("etna.loggers.wandb_logger.wandb") +def test_wandb_logger_log(wandb, tslogger): + wandb_logger = WandbLogger() + tslogger.add(wandb_logger) + tslogger.log("test") + tslogger.log({"MAE": 0}) + tslogger.log({"MAPE": 1.5}) + calls = [ + call({"MAE": 0}), + call({"MAPE": 1.5}), + ] + assert wandb.init.return_value.log.call_count == 2 + wandb.init.return_value.log.assert_has_calls(calls)