Skip to content

Commit

Permalink
Replace SaveToCSVCallback with PL CSVLogger (#198)
Browse files Browse the repository at this point in the history
* remove `save_to_csv` flag from config __init__

* Removed `save_to_csv` from config files

* Added csv option to logger in the config files

* Add csv tests to logger tests

* remove save to csv callback

* 🛠 Modify `get_loggers` function to also get csv logger.

* ✏️ Remove blank line.
  • Loading branch information
samet-akcay authored Apr 5, 2022
1 parent 6a23a30 commit c8ef5dd
Show file tree
Hide file tree
Showing 11 changed files with 63 additions and 105 deletions.
3 changes: 1 addition & 2 deletions anomalib/models/cflow/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ project:
seed: 0
path: ./results
log_images_to: [local]
logger: false
save_to_csv: false
logger: false # options: [tensorboard, wandb, csv] or combinations.

# PL Trainer Args. Don't add extra parameter here.
trainer:
Expand Down
3 changes: 1 addition & 2 deletions anomalib/models/dfkde/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ project:
seed: 42
path: ./results
log_images_to: []
logger: false
save_to_csv: false
logger: false # options: [tensorboard, wandb, csv] or combinations.

# PL Trainer Args. Don't add extra parameter here.
trainer:
Expand Down
3 changes: 1 addition & 2 deletions anomalib/models/dfm/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ project:
seed: 42
path: ./results
log_images_to: []
logger: false
save_to_csv: false
logger: false # options: [tensorboard, wandb, csv] or combinations.

# PL Trainer Args. Don't add extra parameter here.
trainer:
Expand Down
3 changes: 1 addition & 2 deletions anomalib/models/ganomaly/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ project:
seed: 0
path: ./results
log_images_to: []
logger: false
save_to_csv: false
logger: false # options: [tensorboard, wandb, csv] or combinations.

optimization:
openvino:
Expand Down
3 changes: 1 addition & 2 deletions anomalib/models/padim/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ project:
seed: 42
path: ./results
log_images_to: ["local"]
logger: false
save_to_csv: false
logger: false # options: [tensorboard, wandb, csv] or combinations.

optimization:
openvino:
Expand Down
3 changes: 1 addition & 2 deletions anomalib/models/patchcore/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ project:
seed: 0
path: ./results
log_images_to: [local]
logger: false
save_to_csv: false
logger: false # options: [tensorboard, wandb, csv] or combinations.

# PL Trainer Args. Don't add extra parameter here.
trainer:
Expand Down
3 changes: 1 addition & 2 deletions anomalib/models/stfpm/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ project:
seed: 0
path: ./results
log_images_to: [local]
logger: false
save_to_csv: false
logger: false # options: [tensorboard, wandb, csv] or combinations.

optimization:
openvino:
Expand Down
6 changes: 0 additions & 6 deletions anomalib/utils/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,13 @@
from .cdf_normalization import CdfNormalizationCallback
from .min_max_normalization import MinMaxNormalizationCallback
from .model_loader import LoadModelCallback
from .save_to_csv import SaveToCSVCallback
from .timer import TimerCallback
from .visualizer_callback import VisualizerCallback

__all__ = [
"LoadModelCallback",
"TimerCallback",
"VisualizerCallback",
"SaveToCSVCallback",
]


Expand Down Expand Up @@ -107,8 +105,4 @@ def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]:
)
)

if "save_to_csv" in config.project.keys():
if config.project.save_to_csv:
callbacks.append(SaveToCSVCallback())

return callbacks
51 changes: 0 additions & 51 deletions anomalib/utils/callbacks/save_to_csv.py

This file was deleted.

73 changes: 41 additions & 32 deletions anomalib/utils/loggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,26 @@


import os
from typing import Union
from typing import Iterable, List, Union

from omegaconf.dictconfig import DictConfig
from omegaconf.listconfig import ListConfig
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase

from .tensorboard import AnomalibTensorBoardLogger
from .wandb import AnomalibWandbLogger

__all__ = ["AnomalibTensorBoardLogger", "get_logger", "AnomalibWandbLogger"]
AVAILABLE_LOGGERS = ["tensorboard", "wandb"]
AVAILABLE_LOGGERS = ["tensorboard", "wandb", "csv"]


class UnknownLogger(Exception):
"""This is raised when the logger option in `config.yaml` file is set incorrectly."""


def get_logger(config: Union[DictConfig, ListConfig]) -> Union[LightningLoggerBase, bool]:
def get_logger(
config: Union[DictConfig, ListConfig]
) -> Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool]:
"""Return a logger based on the choice of logger in the config file.
Args:
Expand All @@ -45,32 +47,39 @@ def get_logger(config: Union[DictConfig, ListConfig]) -> Union[LightningLoggerBa
Returns:
Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool]: Logger
"""
logger: Union[LightningLoggerBase, bool]

if config.project.logger in [None, False]:
logger = False

elif config.project.logger == "tensorboard":
logger = AnomalibTensorBoardLogger(
name="Tensorboard Logs",
save_dir=os.path.join(config.project.path, "logs"),
)

elif config.project.logger == "wandb":
wandb_logdir = os.path.join(config.project.path, "logs")
os.makedirs(wandb_logdir, exist_ok=True)
logger = AnomalibWandbLogger(
project=config.dataset.name,
name=f"{config.dataset.category} {config.model.name}",
save_dir=wandb_logdir,
)

else:
raise UnknownLogger(
f"Unknown logger type: {config.project.logger}. "
f"Available loggers are: {AVAILABLE_LOGGERS}.\n"
f"To enable the logger, set `project.logger` to `true` or use one of available loggers in config.yaml\n"
f"To disable the logger, set `project.logger` to `false`."
)

return logger
return False

logger_list: List[LightningLoggerBase] = []
if isinstance(config.project.logger, str):
config.project.logger = [config.project.logger]

for logger in config.project.logger:
if logger == "tensorboard":
logger_list.append(
AnomalibTensorBoardLogger(
name="Tensorboard Logs",
save_dir=os.path.join(config.project.path, "logs"),
)
)
elif logger == "wandb":
wandb_logdir = os.path.join(config.project.path, "logs")
os.makedirs(wandb_logdir, exist_ok=True)
logger_list.append(
AnomalibWandbLogger(
project=config.dataset.name,
name=f"{config.dataset.category} {config.model.name}",
save_dir=wandb_logdir,
)
)
elif logger == "csv":
logger_list.append(CSVLogger(save_dir=os.path.join(config.project.path, "logs")))
else:
raise UnknownLogger(
f"Unknown logger type: {config.project.logger}. "
f"Available loggers are: {AVAILABLE_LOGGERS}.\n"
f"To enable the logger, set `project.logger` to `true` or use one of available loggers in config.yaml\n"
f"To disable the logger, set `project.logger` to `false`."
)

return logger_list
17 changes: 15 additions & 2 deletions tests/pre_merge/utils/loggers/test_get_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import pytest
from omegaconf import OmegaConf
from pytorch_lightning.loggers import CSVLogger

from anomalib.utils.loggers import (
AnomalibTensorBoardLogger,
Expand Down Expand Up @@ -46,12 +47,24 @@ def test_get_logger():
# get tensorboard
config.project.logger = "tensorboard"
logger = get_logger(config=config)
assert isinstance(logger, AnomalibTensorBoardLogger)
assert isinstance(logger[0], AnomalibTensorBoardLogger)

# get wandb logger
config.project.logger = "wandb"
logger = get_logger(config=config)
assert isinstance(logger, AnomalibWandbLogger)
assert isinstance(logger[0], AnomalibWandbLogger)

# get csv logger.
config.project.logger = "csv"
logger = get_logger(config=config)
assert isinstance(logger[0], CSVLogger)

# get multiple loggers
config.project.logger = ["tensorboard", "wandb", "csv"]
logger = get_logger(config=config)
assert isinstance(logger[0], AnomalibTensorBoardLogger)
assert isinstance(logger[1], AnomalibWandbLogger)
assert isinstance(logger[2], CSVLogger)

# raise unknown
with pytest.raises(UnknownLogger):
Expand Down

0 comments on commit c8ef5dd

Please sign in to comment.