Commit

[Feature] Add CustomDataset. (open-mmlab#738)

* Add custom dataset and refactor ImageNet dataset

* Add default CLASSES for CIFAR dataset

* Add unit tests

* Improve according to comments
mzr1996 authored Mar 30, 2022
1 parent 6722bcc commit 7a022df
Showing 7 changed files with 1,159 additions and 562 deletions.
3 changes: 2 additions & 1 deletion mmcls/datasets/__init__.py
@@ -4,6 +4,7 @@
                      build_dataset, build_sampler)
from .cifar import CIFAR10, CIFAR100
from .cub import CUB
from .custom import CustomDataset
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
                               KFoldDataset, RepeatDataset)
from .imagenet import ImageNet
@@ -18,5 +19,5 @@
    'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset',
    'DistributedSampler', 'ConcatDataset', 'RepeatDataset',
    'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'ImageNet21k', 'SAMPLERS',
-   'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB'
+   'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB', 'CustomDataset'
]
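
Because the new class is exported here and registered in DATASETS (see mmcls/datasets/custom.py below), it can be built from a plain config dict just like the other datasets. A minimal, hypothetical sketch; the data_prefix path is a placeholder:

from mmcls.datasets import build_dataset

# Builds a CustomDataset through the registry, the same way a config file would.
dataset = build_dataset(
    dict(type='CustomDataset', data_prefix='data/my_dataset/train', pipeline=[]))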
17 changes: 13 additions & 4 deletions mmcls/datasets/base_dataset.py
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from abc import ABCMeta, abstractmethod
from os import PathLike
from typing import List

import mmcv
@@ -12,6 +14,13 @@
from .pipelines import Compose


def expanduser(path):
    if isinstance(path, (str, PathLike)):
        return osp.expanduser(path)
    else:
        return path


class BaseDataset(Dataset, metaclass=ABCMeta):
    """Base dataset.
@@ -34,11 +43,11 @@ def __init__(self,
                 ann_file=None,
                 test_mode=False):
        super(BaseDataset, self).__init__()
-       self.ann_file = ann_file
-       self.data_prefix = data_prefix
-       self.test_mode = test_mode
+       self.data_prefix = expanduser(data_prefix)
        self.pipeline = Compose(pipeline)
        self.CLASSES = self.get_classes(classes)
+       self.ann_file = expanduser(ann_file)
+       self.test_mode = test_mode
        self.data_infos = self.load_annotations()

    @abstractmethod
@@ -106,7 +115,7 @@ def get_classes(cls, classes=None):

        if isinstance(classes, str):
            # take it as a file path
-           class_names = mmcv.list_from_file(classes)
+           class_names = mmcv.list_from_file(expanduser(classes))
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
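
For context on the change above: data_prefix, ann_file and a string classes path now pass through the new expanduser helper, so '~'-style paths work out of the box while non-path values such as None are left untouched. A small usage sketch; the path is a placeholder and the output depends on the local home directory:

from mmcls.datasets.base_dataset import expanduser

print(expanduser('~/data/imagenet/meta/train.txt'))  # '~' expanded to the user home
print(expanduser(None))  # non-path values such as None pass through unchanged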
22 changes: 22 additions & 0 deletions mmcls/datasets/cifar.py
@@ -40,6 +40,10 @@ class CIFAR10(BaseDataset):
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }
    CLASSES = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck'
    ]

    def load_annotations(self):

@@ -131,3 +135,21 @@ class CIFAR100(CIFAR10):
        'key': 'fine_label_names',
        'md5': '7973b15100ade9c7d40fb424638fde48',
    }
    CLASSES = [
        'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee',
        'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus',
        'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle',
        'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab',
        'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish',
        'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard',
        'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man',
        'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom',
        'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
        'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
        'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea',
        'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
        'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper',
        'table', 'tank', 'telephone', 'television', 'tiger', 'tractor',
        'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale',
        'willow_tree', 'wolf', 'woman', 'worm'
    ]
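
The practical effect of the new defaults: CIFAR datasets constructed without an explicit classes argument now report human-readable class names instead of None. A short sketch; the class attribute can be read without any data on disk:

from mmcls.datasets import CIFAR10, CIFAR100

print(CIFAR10.CLASSES[:3])    # ['airplane', 'automobile', 'bird']
print(len(CIFAR100.CLASSES))  # 100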
229 changes: 229 additions & 0 deletions mmcls/datasets/custom.py
@@ -0,0 +1,229 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import mmcv
import numpy as np
from mmcv import FileClient

from .base_dataset import BaseDataset
from .builder import DATASETS


def find_folders(root: str,
                 file_client: FileClient) -> Tuple[List[str], Dict[str, int]]:
    """Find classes by folders under a root.

    Args:
        root (string): root directory of folders

    Returns:
        Tuple[List[str], Dict[str, int]]:

        - folders: The name of sub folders under the root.
        - folder_to_idx: The map from folder name to class idx.
    """
    folders = list(
        file_client.list_dir_or_file(
            root,
            list_dir=True,
            list_file=False,
            recursive=False,
        ))
    folders.sort()
    folder_to_idx = {folders[i]: i for i in range(len(folders))}
    return folders, folder_to_idx


def get_samples(root: str, folder_to_idx: Dict[str, int],
                is_valid_file: Callable, file_client: FileClient):
    """Make dataset by walking all images under a root.

    Args:
        root (string): root directory of folders
        folder_to_idx (dict): the map from class name to class idx
        is_valid_file (Callable): A function that takes path of a file
            and check if the file is a valid sample file.

    Returns:
        Tuple[list, set]:

        - samples: a list of tuple where each element is (image, class_idx)
        - empty_folders: The folders don't have any valid files.
    """
    samples = []
    available_classes = set()

    for folder_name in sorted(list(folder_to_idx.keys())):
        _dir = file_client.join_path(root, folder_name)
        files = list(
            file_client.list_dir_or_file(
                _dir,
                list_dir=False,
                list_file=True,
                recursive=True,
            ))
        for file in sorted(list(files)):
            if is_valid_file(file):
                path = file_client.join_path(folder_name, file)
                item = (path, folder_to_idx[folder_name])
                samples.append(item)
                available_classes.add(folder_name)

    empty_folders = set(folder_to_idx.keys()) - available_classes

    return samples, empty_folders


@DATASETS.register_module()
class CustomDataset(BaseDataset):
    """Custom dataset for classification.

    The dataset supports two kinds of annotation format.

    1. An annotation file is provided, and each line indicates a sample:

       The sample files: ::

           data_prefix/
           ├── folder_1
           │   ├── xxx.png
           │   ├── xxy.png
           │   └── ...
           └── folder_2
               ├── 123.png
               ├── nsdf3.png
               └── ...

       The annotation file (the first column is the image path and the second
       column is the index of category): ::

           folder_1/xxx.png 0
           folder_1/xxy.png 1
           folder_2/123.png 5
           folder_2/nsdf3.png 3
           ...

       Please specify the name of categories by the argument ``classes``.

    2. The samples are arranged in the specific way: ::

           data_prefix/
           ├── class_x
           │   ├── xxx.png
           │   ├── xxy.png
           │   └── ...
           │       └── xxz.png
           └── class_y
               ├── 123.png
               ├── nsdf3.png
               ├── ...
               └── asd932_.png

    If the ``ann_file`` is specified, the dataset will be generated by the
    first way, otherwise, try the second way.

    Args:
        data_prefix (str): The path of data directory.
        pipeline (Sequence[dict]): A list of dict, where each element
            represents a operation defined in :mod:`mmcls.datasets.pipelines`.
            Defaults to an empty tuple.
        classes (str | Sequence[str], optional): Specify names of classes.

            - If is string, it should be a file path, and the every line of
              the file is a name of a class.
            - If is a sequence of string, every item is a name of class.
            - If is None, use ``cls.CLASSES`` or the names of sub folders
              (If use the second way to arrange samples).

            Defaults to None.
        ann_file (str, optional): The annotation file. If is string, read
            samples paths from the ann_file. If is None, find samples in
            ``data_prefix``. Defaults to None.
        extensions (Sequence[str]): A sequence of allowed extensions. Defaults
            to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif').
        test_mode (bool): In train mode or test mode. It's only a mark and
            won't be used in this class. Defaults to False.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            If None, automatically inference from the specified path.
            Defaults to None.
    """

    def __init__(self,
                 data_prefix: str,
                 pipeline: Sequence = (),
                 classes: Union[str, Sequence[str], None] = None,
                 ann_file: Optional[str] = None,
                 extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm',
                                              '.bmp', '.pgm', '.tif'),
                 test_mode: bool = False,
                 file_client_args: Optional[dict] = None):
        self.extensions = tuple(set([i.lower() for i in extensions]))
        self.file_client_args = file_client_args

        super().__init__(
            data_prefix=data_prefix,
            pipeline=pipeline,
            classes=classes,
            ann_file=ann_file,
            test_mode=test_mode)

    def _find_samples(self):
        """find samples from ``data_prefix``."""
        file_client = FileClient.infer_client(self.file_client_args,
                                              self.data_prefix)
        classes, folder_to_idx = find_folders(self.data_prefix, file_client)
        samples, empty_classes = get_samples(
            self.data_prefix,
            folder_to_idx,
            is_valid_file=self.is_valid_file,
            file_client=file_client,
        )

        if len(samples) == 0:
            raise RuntimeError(
                f'Found 0 files in subfolders of: {self.data_prefix}. '
                f'Supported extensions are: {",".join(self.extensions)}')

        if self.CLASSES is not None:
            assert len(self.CLASSES) == len(classes), \
                f"The number of subfolders ({len(classes)}) doesn't match " \
                f'the number of specified classes ({len(self.CLASSES)}). ' \
                'Please check the data folder.'
        else:
            self.CLASSES = classes

        if empty_classes:
            warnings.warn(
                'Found no valid file in the folder '
                f'{", ".join(empty_classes)}. '
                f"Supported extensions are: {', '.join(self.extensions)}",
                UserWarning)

        self.folder_to_idx = folder_to_idx

        return samples

    def load_annotations(self):
        """Load image paths and gt_labels."""
        if self.ann_file is None:
            samples = self._find_samples()
        elif isinstance(self.ann_file, str):
            lines = mmcv.list_from_file(
                self.ann_file, file_client_args=self.file_client_args)
            samples = [x.strip().rsplit(' ', 1) for x in lines]
        else:
            raise TypeError('ann_file must be a str or None')

        data_infos = []
        for filename, gt_label in samples:
            info = {'img_prefix': self.data_prefix}
            info['img_info'] = {'filename': filename}
            info['gt_label'] = np.array(gt_label, dtype=np.int64)
            data_infos.append(info)
        return data_infos

    def is_valid_file(self, filename: str) -> bool:
        """Check if a file is a valid sample."""
        return filename.lower().endswith(self.extensions)
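
To tie together the two annotation formats described in the class docstring, a hedged usage sketch; all paths and class names below are placeholders and the data is assumed to already exist on disk:

from mmcls.datasets import CustomDataset

# Format 2: no ann_file, classes are inferred from the sorted sub-folder
# names under data_prefix.
train_set = CustomDataset(data_prefix='data/my_dataset/train', pipeline=[])

# Format 1: an annotation file with one "relative/path label_index" pair per
# line; class names should then be given via `classes`.
val_set = CustomDataset(
    data_prefix='data/my_dataset/val',
    ann_file='data/my_dataset/meta/val.txt',
    classes=['cat', 'dog'],
    pipeline=[])

print(train_set.CLASSES)  # sub-folder names, sorted alphabetically
print(len(val_set))       # one sample per line in val.txt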