Skip to content

Commit

Permalink
NEW: wandb log to logger (#816)
Browse files Browse the repository at this point in the history
  • Loading branch information
martins0n authored Jul 25, 2022
1 parent 66027b9 commit 17a7b73
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 4 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Make native prediction intervals for DeepAR ([#761](https://github.com/tinkoff-ai/etna/pull/761))
- Make native prediction intervals for TFTModel ([#770](https://github.com/tinkoff-ai/etna/pull/770))
- Test cases for testing inference of models ([#794](https://github.com/tinkoff-ai/etna/pull/794))
-
- Wandb.log to WandbLogger ([#816](https://github.com/tinkoff-ai/etna/pull/816))
### Fixed
-
- Fix missing prophet in docker images ([#767](https://github.com/tinkoff-ai/etna/pull/767))
Expand Down
6 changes: 3 additions & 3 deletions etna/loggers/wandb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ def log(self, msg: Union[str, Dict[str, Any]], **kwargs):
Notes
-----
We log nothing via current method in wandb case.
Currently you could call ``wandb.log`` by hand if you need this.
We log dictionary to wandb only.
"""
pass
if isinstance(msg, dict):
self.experiment.log(msg)

def log_backtest_metrics(
self, ts: "TSDataset", metrics_df: pd.DataFrame, forecast_df: pd.DataFrame, fold_info_df: pd.DataFrame
Expand Down
29 changes: 29 additions & 0 deletions tests/test_loggers/test_wandb_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from unittest.mock import call
from unittest.mock import patch

import pytest

from etna.loggers import WandbLogger
from etna.loggers import tslogger as _tslogger


@pytest.fixture()
def tslogger():
_tslogger.loggers = []
yield _tslogger
_tslogger.loggers = []


@patch("etna.loggers.wandb_logger.wandb")
def test_wandb_logger_log(wandb, tslogger):
wandb_logger = WandbLogger()
tslogger.add(wandb_logger)
tslogger.log("test")
tslogger.log({"MAE": 0})
tslogger.log({"MAPE": 1.5})
calls = [
call({"MAE": 0}),
call({"MAPE": 1.5}),
]
assert wandb.init.return_value.log.call_count == 2
wandb.init.return_value.log.assert_has_calls(calls)

1 comment on commit 17a7b73

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.