Add Custom Dataset Training Support (#154)
* renamed download-progress-bar as download

* added new download functions to init

* Added Btech data module

* Added btech tests

* Move split functions into a util module

* Modified mvtec

* added btech to get_datamodule

* fix typo in btech docstring

* update docstring

* cleaned up dataset download utils

* Address mypy

* modify config files and update readme.md

* Fix dataset path

* WiP: Created make_dataset function

* Renamed folder dataset into custom

* Added custom dataset tests

* updated config.yaml file to show custom dataset is available

* Added custom dataset to get_datamodule

* Address PR comments

* fix dataset path

* Debugging the ci

* Fixed folder dataset tests

* Added code quality checks back to the ci

* Added code coverage back to pre-merge tests
samet-akcay authored Mar 24, 2022
1 parent 2dfa0a7 commit b03fb32
Showing 12 changed files with 604 additions and 14 deletions.
28 changes: 28 additions & 0 deletions README.md
@@ -102,6 +102,34 @@ where the currently available models are:
- [DFKDE](anomalib/models/dfkde)
- [GANomaly](anomalib/models/ganomaly)

### Custom Dataset
It is also possible to train on a custom folder dataset. To do so, the `dataset` section of `config.yaml` needs to be modified as follows:
```yaml
dataset:
  name: <name-of-the-dataset>
  format: folder
  path: <path/to/folder/dataset>
  normal: normal # name of the folder containing normal images.
  abnormal: abnormal # name of the folder containing abnormal images.
  task: segmentation # classification or segmentation
  mask: <path/to/mask/annotations> # optional
  extensions: null
  split_ratio: 0.2 # ratio of the normal images that will be used to create a test split
  seed: 0
  image_size: 256
  train_batch_size: 32
  test_batch_size: 32
  num_workers: 8
  transform_config: null
  create_validation_set: true
  tiling:
    apply: false
    tile_size: null
    stride: null
    remove_border_count: 0
    use_random_tiling: false
    random_tile_count: 16
```
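
For orientation, the snippet below is a minimal sketch of how these settings map onto the `FolderDataModule` added in this commit when driving it from Python rather than from `config.yaml`. The paths are placeholders, not part of the commit.

```python
# Minimal sketch only: the dataset and mask paths below are placeholders.
from anomalib.data import FolderDataModule

datamodule = FolderDataModule(
    root="./datasets/my_dataset",            # path
    normal="normal",                         # folder containing normal images
    abnormal="abnormal",                     # folder containing abnormal images
    task="segmentation",                     # "classification" or "segmentation"
    mask_dir="./datasets/my_dataset/masks",  # optional mask annotations
    extensions=None,
    split_ratio=0.2,                         # share of normal images used for the test split
    seed=0,
    image_size=(256, 256),
    train_batch_size=32,
    test_batch_size=32,
    num_workers=8,
    transform_config=None,
    create_validation_set=True,
)
datamodule.setup()  # standard LightningDataModule hook; builds the train/val/test splits
```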
## Inference
Anomalib contains several tools that can be used to perform inference with a trained model. The script in [`tools/inference`](tools/inference.py) contains an example of how the inference tools can be used to generate a prediction for an input image.
5 changes: 4 additions & 1 deletion anomalib/config/config.py
@@ -177,7 +177,10 @@ def get_configurable_parameters(
     config = update_input_size_config(config)

     # Project Configs
-    project_path = Path(config.project.path) / config.model.name / config.dataset.name / config.dataset.category
+    project_path = Path(config.project.path) / config.model.name / config.dataset.name
+    if config.dataset.format.lower() in ("btech", "mvtec"):
+        project_path = project_path / config.dataset.category

     (project_path / "weights").mkdir(parents=True, exist_ok=True)
     (project_path / "images").mkdir(parents=True, exist_ok=True)
     config.project.path = str(project_path)
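
The effect of this change: BTech/MVTec results keep their per-category subfolder, while a custom folder dataset writes directly under its name. A small illustration (the results root and names are placeholders; `ganomaly` is one of the models listed in the README above):

```python
# Illustration only; the root path and names are placeholders.
from pathlib import Path

results_root = Path("./results")
# MVTec/BTech formats still append the category level:
mvtec_project = results_root / "ganomaly" / "mvtec" / "bottle"  # results/ganomaly/mvtec/bottle
# A folder-format dataset has no category, so the path stops at the dataset name:
folder_project = results_root / "ganomaly" / "my_dataset"       # results/ganomaly/my_dataset
```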
32 changes: 28 additions & 4 deletions anomalib/data/__init__.py
@@ -20,6 +20,7 @@
 from pytorch_lightning import LightningDataModule

 from .btech import BTechDataModule
+from .folder import FolderDataModule
 from .inference import InferenceDataset
 from .mvtec import MVTecDataModule

@@ -35,7 +36,7 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule
     """
     datamodule: LightningDataModule

-    if config.dataset.name.lower() == "mvtec":
+    if config.dataset.format.lower() == "mvtec":
         datamodule = MVTecDataModule(
             # TODO: Remove config values. IAAALD-211
             root=config.dataset.path,
@@ -48,19 +49,36 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule
             transform_config=config.dataset.transform_config,
             create_validation_set=config.dataset.create_validation_set,
         )
-    elif config.dataset.name.lower() == "btech":
+    elif config.dataset.format.lower() == "btech":
         datamodule = BTechDataModule(
             # TODO: Remove config values. IAAALD-211
             root=config.dataset.path,
             category=config.dataset.category,
-            image_size=(config.dataset.image_size[0], config.dataset.image_size[0]),
+            image_size=(config.dataset.image_size[0], config.dataset.image_size[1]),
             train_batch_size=config.dataset.train_batch_size,
             test_batch_size=config.dataset.test_batch_size,
             num_workers=config.dataset.num_workers,
             seed=config.project.seed,
             transform_config=config.dataset.transform_config,
             create_validation_set=config.dataset.create_validation_set,
         )
+    elif config.dataset.format.lower() == "folder":
+        datamodule = FolderDataModule(
+            root=config.dataset.path,
+            normal=config.dataset.normal,
+            abnormal=config.dataset.abnormal,
+            task=config.dataset.task,
+            mask_dir=config.dataset.mask,
+            extensions=config.dataset.extensions,
+            split_ratio=config.dataset.split_ratio,
+            seed=config.dataset.seed,
+            image_size=(config.dataset.image_size[0], config.dataset.image_size[1]),
+            train_batch_size=config.dataset.train_batch_size,
+            test_batch_size=config.dataset.test_batch_size,
+            num_workers=config.dataset.num_workers,
+            transform_config=config.dataset.transform_config,
+            create_validation_set=config.dataset.create_validation_set,
+        )
     else:
         raise ValueError(
             "Unknown dataset! \n"
@@ -71,4 +89,10 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule
     return datamodule


-__all__ = ["get_datamodule", "InferenceDataset"]
+__all__ = [
+    "get_datamodule",
+    "BTechDataModule",
+    "FolderDataModule",
+    "InferenceDataset",
+    "MVTecDataModule",
+]
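
Because dispatch now keys on `dataset.format` rather than `dataset.name`, a custom dataset can carry any name and still resolve to the new datamodule. A hedged usage sketch — the config file name is a placeholder and is assumed to contain the `dataset` section shown in the README above:

```python
# Sketch only: "custom_config.yaml" is a placeholder holding the README's
# dataset section, with format set to "folder" and the placeholder paths filled in.
from omegaconf import OmegaConf

from anomalib.data import FolderDataModule, get_datamodule

config = OmegaConf.load("custom_config.yaml")
assert config.dataset.format.lower() == "folder"

datamodule = get_datamodule(config)  # dispatches on dataset.format, not dataset.name
assert isinstance(datamodule, FolderDataModule)
```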