Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Datamodules] Merge feature branch #822

Merged
merged 39 commits into from
Jan 6, 2023
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
b21045b
New datamodules design (#572)
djdameln Oct 31, 2022
f66c84e
Video Datamodules (#676)
djdameln Nov 18, 2022
aac5a47
Update lightning_inference.py
djdameln Nov 29, 2022
ab6cb57
merge main
djdameln Dec 5, 2022
cb06714
Make `val split ratio` configurable (#760)
djdameln Dec 5, 2022
045d77f
Add support for Detection task type (#732)
djdameln Dec 6, 2022
ccec2f6
[Datamodules] Update deprecation messages (#764)
djdameln Dec 7, 2022
d210feb
Improve image source parsing for Folder dataset (#784)
djdameln Dec 13, 2022
67462e7
Synthetic anomaly for testing and validation (#634)
djdameln Dec 14, 2022
8141f2f
merge main
djdameln Dec 16, 2022
8601330
Bugfixes for Datamodules feature branch (#800)
djdameln Dec 19, 2022
663692e
Deprecate PreProcessor (#795)
djdameln Dec 19, 2022
57d3b4e
[Datamodules] Fix bug in bbox score to image score conversion (#803)
djdameln Dec 20, 2022
192ba94
Improve handling of `test_split_mode='none'` and `val_split_mode='non…
djdameln Dec 20, 2022
b21f12c
fix to float transform
djdameln Dec 27, 2022
ced7bc9
Detection improvements (#820)
djdameln Dec 29, 2022
690cb1b
merge main
djdameln Dec 29, 2022
4cf8577
update changelog
djdameln Dec 29, 2022
89661ba
update csflow config to new format
djdameln Dec 29, 2022
9114c7d
remove unused imports
djdameln Dec 29, 2022
1c903f4
line length
djdameln Dec 29, 2022
dbba22a
suppress bandit warnings
djdameln Jan 2, 2023
451caf4
use torch rng in augmenter
djdameln Jan 2, 2023
e457b5e
use tuple instead of list
djdameln Jan 5, 2023
68ef76b
add missing params to dosctring
djdameln Jan 5, 2023
1d23942
add missing licence information
djdameln Jan 5, 2023
d2fda44
COLS -> COLUMNS
djdameln Jan 5, 2023
0139d62
typing and variable naming
djdameln Jan 5, 2023
8914fa1
remove duplicate parameter in docstring
djdameln Jan 5, 2023
6b2dcc5
im_dir -> image_dir
djdameln Jan 5, 2023
2d24ed5
typing and docstring
djdameln Jan 5, 2023
6d39434
typing
djdameln Jan 5, 2023
6e3816f
ValSplitMode -> ValidationSplitMode
djdameln Jan 5, 2023
fad21b1
add missing licence
djdameln Jan 5, 2023
8df22c6
rename variable
djdameln Jan 5, 2023
ced0342
remove empty comment
djdameln Jan 5, 2023
96f9b5e
remove unused class attribute
djdameln Jan 5, 2023
0904529
[Detection] Compute box score when generating boxes from masks (#828)
djdameln Jan 6, 2023
a67c21b
revert val_split_mode -> validation_split_mode
djdameln Jan 6, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Added

- Add Synthetic anomalous dataset for validation and testing (https://github.com/openvinotoolkit/anomalib/pull/822)
- Add Detection task type support (https://github.com/openvinotoolkit/anomalib/pull/822)
- Add UCSDped and Avenue dataset implementation (https://github.com/openvinotoolkit/anomalib/pull/822)
- Add base classes for video dataset and video datamodule (https://github.com/openvinotoolkit/anomalib/pull/822)
- Add base classes for image dataset and image dataModule (https://github.com/openvinotoolkit/anomalib/pull/822)
- ✨ Add CSFlow model (<https://github.com/openvinotoolkit/anomalib/pull/657>)
- Log loss for existing trainable models (<https://github.com/openvinotoolkit/anomalib/pull/804>)
- Add section for community project (<https://github.com/openvinotoolkit/anomalib/pull/768>)
Expand All @@ -19,6 +24,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Changed

- Make input image normalization and center cropping configurable from config (https://github.com/openvinotoolkit/anomalib/pull/822)
- Improve flexibility and configurability of subset splitting (https://github.com/openvinotoolkit/anomalib/pull/822)
- Switch to new datamodules design (https://github.com/openvinotoolkit/anomalib/pull/822)
- Make normalization and center cropping configurable through config (<https://github.com/openvinotoolkit/anomalib/pull/795>)
- Switch to new [changelog format](https://keepachangelog.com/en/1.0.0/). (<https://github.com/openvinotoolkit/anomalib/pull/777>)
- Rename feature to task (<https://github.com/openvinotoolkit/anomalib/pull/769>)
- make device configurable in OpenVINO inference (<https://github.com/openvinotoolkit/anomalib/pull/755>)
Expand All @@ -32,6 +41,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Deprecated

- Deprecated PreProcessor class (<https://github.com/openvinotoolkit/anomalib/pull/795>)
- Deprecate OptimalF1 metric in favor of AnomalyScoreThreshold and F1Score (<https://github.com/openvinotoolkit/anomalib/pull/796>)

### Fixed
Expand Down
137 changes: 125 additions & 12 deletions anomalib/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from omegaconf import DictConfig, ListConfig, OmegaConf

from anomalib.data.utils import TestSplitMode, ValSplitMode


def _get_now_str(timestamp: float) -> str:
"""Standard format for datetimes is defined here."""
Expand All @@ -32,11 +34,27 @@ def update_input_size_config(config: Union[DictConfig, ListConfig]) -> Union[Dic
Returns:
Union[DictConfig, ListConfig]: Configurable parameters with updated values
"""
# handle image size
if isinstance(config.dataset.image_size, int):
config.dataset.image_size = (config.dataset.image_size,) * 2

config.model.input_size = config.dataset.image_size
# Image size: Ensure value is in the form [height, width]
image_size = config.dataset.get("image_size")
if isinstance(image_size, int):
config.dataset.image_size = (image_size,) * 2
elif isinstance(image_size, ListConfig):
assert len(image_size) == 2, "image_size must be a single integer or tuple of length 2 for width and height."
else:
raise ValueError(f"image_size must be either int or ListConfig, got {type(image_size)}")

# Center crop: Ensure value is in the form [height, width], and update input_size
center_crop = config.dataset.get("center_crop")
if center_crop is None:
config.model.input_size = config.dataset.image_size
elif isinstance(center_crop, int):
config.dataset.center_crop = (center_crop,) * 2
config.model.input_size = config.dataset.center_crop
elif isinstance(center_crop, ListConfig):
assert len(center_crop) == 2, "center_crop must be a single integer or tuple of length 2 for width and height."
config.model.input_size = center_crop
else:
raise ValueError(f"center_crop must be either int or ListConfig, got {type(center_crop)}")

if "tiling" in config.dataset.keys() and config.dataset.tiling.apply:
if isinstance(config.dataset.tiling.tile_size, int):
Expand Down Expand Up @@ -109,6 +127,78 @@ def update_multi_gpu_training_config(config: Union[DictConfig, ListConfig]) -> U
return config


def update_datasets_config(config: Union[DictConfig, ListConfig]) -> Union[DictConfig, ListConfig]:
"""Updates the dataset section of the config.

Args:
config (Union[DictConfig, ListConfig]): Configurable parameters for the current run.

Returns:
Union[DictConfig, ListConfig]: Updated config
"""
if "format" not in config.dataset.keys():
config.dataset.format = "mvtec"

if "create_validation_set" in config.dataset.keys():
warn(
DeprecationWarning(
"The 'create_validation_set' parameter is deprecated and will be removed in a future release. Please "
"use 'validation_split_mode' instead."
)
)
config.dataset.val_split_mode = "from_test" if config.dataset.create_validation_set else "same_as_test"

if "test_batch_size" in config.dataset.keys():
warn(
DeprecationWarning(
"The 'test_batch_size' parameter is deprecated and will be removed in a future release. Please use "
"'eval_batch_size' instead."
)
)
config.dataset.eval_batch_size = config.dataset.test_batch_size

if "transform_config" in config.dataset.keys() and "val" in config.dataset.transform_config.keys():
warn(
DeprecationWarning(
"The 'transform_config.val' parameter is deprecated and will be removed in a future release. Please "
"use 'transform_config.eval' instead."
)
)
config.dataset.transform_config.eval = config.dataset.transform_config.val

config = update_input_size_config(config)

if "clip_length_in_frames" in config.dataset.keys() and config.dataset.clip_length_in_frames > 1:
warn(
"Anomalib's models and visualizer are currently not compatible with video datasets with a clip length > 1. "
"Custom changes to these modules will be needed to prevent errors and/or unpredictable behaviour."
)

if config.dataset.format == "folder" and "split_ratio" in config.dataset.keys():
warn(
DeprecationWarning(
"The 'split_ratio' parameter is deprecated and will be removed in a future release. Please use "
"'test_split_ratio' instead."
)
)
config.dataset.test_split_ratio = config.dataset.split_ratio

if config.dataset.get("test_split_mode") == TestSplitMode.NONE and config.dataset.get("val_split_mode") in [
ValSplitMode.SAME_AS_TEST,
ValSplitMode.FROM_TEST,
djdameln marked this conversation as resolved.
Show resolved Hide resolved
]:
warn(
f"val_split_mode {config.dataset.val_split_mode} not allowed for test_split_mode = 'none'. "
"Setting val_split_mode to 'none'."
)
config.dataset.val_split_mode = ValSplitMode.NONE

if config.dataset.get("val_split_mode") == ValSplitMode.NONE and config.trainer.limit_val_batches != 0.0:
warn("Running without validation set. Setting trainer.limit_val_batches to 0.")
config.trainer.limit_val_batches = 0.0
return config


def get_configurable_parameters(
model_name: Optional[str] = None,
config_path: Optional[Union[Path, str]] = None,
Expand Down Expand Up @@ -151,15 +241,24 @@ def get_configurable_parameters(
"(`null` in the YAML file) or remove the `seed` key from the YAML file."
)

# Dataset Configs
if "format" not in config.dataset.keys():
config.dataset.format = "mvtec"

config = update_datasets_config(config)
config = update_input_size_config(config)

# Project Configs
project_path = Path(config.project.path) / config.model.name / config.dataset.name

if config.dataset.format == "folder":
if "mask" in config.dataset:
warn(
DeprecationWarning(
"mask will be deprecated in favor of mask_dir in config.dataset in a future release."
)
)
config.dataset.mask_dir = config.dataset.mask
if "path" in config.dataset:
warn(DeprecationWarning("path will be deprecated in favor of root in config.dataset in a future release."))
config.dataset.root = config.dataset.path

# add category subfolder if needed
if config.dataset.format.lower() in ("btech", "mvtec"):
project_path = project_path / config.dataset.category
Expand Down Expand Up @@ -196,15 +295,29 @@ def get_configurable_parameters(
if "metrics" in config.keys():
# NOTE: Deprecate this after v0.4.0.
if "adaptive" in config.metrics.threshold.keys():
warn("adaptive will be deprecated in favor of method in config.metrics.threshold in v0.4.0.")
warn(
DeprecationWarning(
"adaptive will be deprecated in favor of method in config.metrics.threshold in a future release"
)
)
config.metrics.threshold.method = "adaptive" if config.metrics.threshold.adaptive else "manual"
if "image_default" in config.metrics.threshold.keys():
warn("image_default will be deprecated in favor of manual_image in config.metrics.threshold in v0.4.0.")
warn(
DeprecationWarning(
"image_default will be deprecated in favor of manual_image in config.metrics.threshold in a future "
"release."
)
)
config.metrics.threshold.manual_image = (
None if config.metrics.threshold.adaptive else config.metrics.threshold.image_default
)
if "pixel_default" in config.metrics.threshold.keys():
warn("pixel_default will be deprecated in favor of manual_pixel in config.metrics.threshold in v0.4.0.")
warn(
DeprecationWarning(
"pixel_default will be deprecated in favor of manual_pixel in config.metrics.threshold in a future "
"release."
)
)
config.metrics.threshold.manual_pixel = (
None if config.metrics.threshold.adaptive else config.metrics.threshold.pixel_default
)
Expand Down
98 changes: 78 additions & 20 deletions anomalib/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,20 @@
from typing import Union

from omegaconf import DictConfig, ListConfig
from pytorch_lightning import LightningDataModule

from .avenue import Avenue
from .base import AnomalibDataModule, AnomalibDataset
from .btech import BTech
from .folder import Folder
from .inference import InferenceDataset
from .mvtec import MVTec
from .task_type import TaskType
from .ucsd_ped import UCSDped

logger = logging.getLogger(__name__)


def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule:
def get_datamodule(config: Union[DictConfig, ListConfig]) -> AnomalibDataModule:
"""Get Anomaly Datamodule.

Args:
Expand All @@ -28,56 +31,106 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule
"""
logger.info("Loading the datamodule")

datamodule: LightningDataModule
datamodule: AnomalibDataModule

# convert center crop to tuple
center_crop = config.dataset.get("center_crop")
if center_crop is not None:
center_crop = (center_crop[0], center_crop[1])

if config.dataset.format.lower() == "mvtec":
datamodule = MVTec(
# TODO: Remove config values. IAAALD-211
root=config.dataset.path,
category=config.dataset.category,
image_size=(config.dataset.image_size[0], config.dataset.image_size[1]),
center_crop=center_crop,
normalization=config.dataset.normalization,
train_batch_size=config.dataset.train_batch_size,
test_batch_size=config.dataset.test_batch_size,
eval_batch_size=config.dataset.eval_batch_size,
num_workers=config.dataset.num_workers,
seed=config.project.seed,
task=config.dataset.task,
transform_config_train=config.dataset.transform_config.train,
transform_config_val=config.dataset.transform_config.val,
create_validation_set=config.dataset.create_validation_set,
transform_config_eval=config.dataset.transform_config.eval,
test_split_mode=config.dataset.test_split_mode,
test_split_ratio=config.dataset.test_split_ratio,
val_split_mode=config.dataset.val_split_mode,
val_split_ratio=config.dataset.val_split_ratio,
)
elif config.dataset.format.lower() == "btech":
datamodule = BTech(
# TODO: Remove config values. IAAALD-211
root=config.dataset.path,
category=config.dataset.category,
image_size=(config.dataset.image_size[0], config.dataset.image_size[1]),
center_crop=center_crop,
normalization=config.dataset.normalization,
train_batch_size=config.dataset.train_batch_size,
test_batch_size=config.dataset.test_batch_size,
eval_batch_size=config.dataset.eval_batch_size,
num_workers=config.dataset.num_workers,
seed=config.project.seed,
task=config.dataset.task,
transform_config_train=config.dataset.transform_config.train,
transform_config_val=config.dataset.transform_config.val,
create_validation_set=config.dataset.create_validation_set,
transform_config_eval=config.dataset.transform_config.eval,
test_split_mode=config.dataset.test_split_mode,
test_split_ratio=config.dataset.test_split_ratio,
val_split_mode=config.dataset.val_split_mode,
val_split_ratio=config.dataset.val_split_ratio,
)
elif config.dataset.format.lower() == "folder":
datamodule = Folder(
root=config.dataset.path,
root=config.dataset.root,
normal_dir=config.dataset.normal_dir,
abnormal_dir=config.dataset.abnormal_dir,
task=config.dataset.task,
normal_test_dir=config.dataset.normal_test_dir,
mask_dir=config.dataset.mask,
mask_dir=config.dataset.mask_dir,
extensions=config.dataset.extensions,
split_ratio=config.dataset.split_ratio,
seed=config.project.seed,
image_size=(config.dataset.image_size[0], config.dataset.image_size[1]),
center_crop=center_crop,
normalization=config.dataset.normalization,
train_batch_size=config.dataset.train_batch_size,
test_batch_size=config.dataset.test_batch_size,
eval_batch_size=config.dataset.eval_batch_size,
num_workers=config.dataset.num_workers,
transform_config_train=config.dataset.transform_config.train,
transform_config_val=config.dataset.transform_config.val,
create_validation_set=config.dataset.create_validation_set,
transform_config_eval=config.dataset.transform_config.eval,
test_split_mode=config.dataset.test_split_mode,
test_split_ratio=config.dataset.test_split_ratio,
val_split_mode=config.dataset.val_split_mode,
val_split_ratio=config.dataset.val_split_ratio,
)
elif config.dataset.format.lower() == "ucsdped":
datamodule = UCSDped(
root=config.dataset.path,
category=config.dataset.category,
task=config.dataset.task,
clip_length_in_frames=config.dataset.clip_length_in_frames,
frames_between_clips=config.dataset.frames_between_clips,
image_size=(config.dataset.image_size[0], config.dataset.image_size[1]),
center_crop=center_crop,
normalization=config.dataset.normalization,
transform_config_train=config.dataset.transform_config.train,
transform_config_eval=config.dataset.transform_config.eval,
train_batch_size=config.dataset.train_batch_size,
eval_batch_size=config.dataset.eval_batch_size,
num_workers=config.dataset.num_workers,
val_split_mode=config.dataset.val_split_mode,
val_split_ratio=config.dataset.val_split_ratio,
)
elif config.dataset.format.lower() == "avenue":
datamodule = Avenue(
root=config.dataset.path,
gt_dir=config.dataset.gt_dir,
task=config.dataset.task,
clip_length_in_frames=config.dataset.clip_length_in_frames,
frames_between_clips=config.dataset.frames_between_clips,
image_size=(config.dataset.image_size[0], config.dataset.image_size[1]),
center_crop=center_crop,
normalization=config.dataset.normalization,
transform_config_train=config.dataset.transform_config.train,
transform_config_eval=config.dataset.transform_config.eval,
train_batch_size=config.dataset.train_batch_size,
eval_batch_size=config.dataset.eval_batch_size,
num_workers=config.dataset.num_workers,
val_split_mode=config.dataset.val_split_mode,
val_split_ratio=config.dataset.val_split_ratio,
)
else:
raise ValueError(
Expand All @@ -90,9 +143,14 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule


__all__ = [
"AnomalibDataset",
"AnomalibDataModule",
"get_datamodule",
"BTech",
"Folder",
"InferenceDataset",
"MVTec",
"Avenue",
"UCSDped",
"TaskType",
]
Loading