Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add CustomDataset #738

Merged
merged 4 commits into from
Mar 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mmcls/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
build_dataset, build_sampler)
from .cifar import CIFAR10, CIFAR100
from .cub import CUB
from .custom import CustomDataset
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
KFoldDataset, RepeatDataset)
from .imagenet import ImageNet
Expand All @@ -18,5 +19,5 @@
'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset',
'DistributedSampler', 'ConcatDataset', 'RepeatDataset',
'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'ImageNet21k', 'SAMPLERS',
'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB'
'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB', 'CustomDataset'
]
17 changes: 13 additions & 4 deletions mmcls/datasets/base_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from abc import ABCMeta, abstractmethod
from os import PathLike
from typing import List

import mmcv
Expand All @@ -12,6 +14,13 @@
from .pipelines import Compose


def expanduser(path):
if isinstance(path, (str, PathLike)):
return osp.expanduser(path)
else:
return path


class BaseDataset(Dataset, metaclass=ABCMeta):
"""Base dataset.

Expand All @@ -34,11 +43,11 @@ def __init__(self,
ann_file=None,
test_mode=False):
super(BaseDataset, self).__init__()
self.ann_file = ann_file
self.data_prefix = data_prefix
self.test_mode = test_mode
self.data_prefix = expanduser(data_prefix)
self.pipeline = Compose(pipeline)
self.CLASSES = self.get_classes(classes)
self.ann_file = expanduser(ann_file)
self.test_mode = test_mode
self.data_infos = self.load_annotations()

@abstractmethod
Expand Down Expand Up @@ -106,7 +115,7 @@ def get_classes(cls, classes=None):

if isinstance(classes, str):
# take it as a file path
class_names = mmcv.list_from_file(classes)
class_names = mmcv.list_from_file(expanduser(classes))
elif isinstance(classes, (tuple, list)):
class_names = classes
else:
Expand Down
22 changes: 22 additions & 0 deletions mmcls/datasets/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ class CIFAR10(BaseDataset):
'key': 'label_names',
'md5': '5ff9c542aee3614f3951f8cda6e48888',
}
CLASSES = [
'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
'horse', 'ship', 'truck'
]

def load_annotations(self):

Expand Down Expand Up @@ -131,3 +135,21 @@ class CIFAR100(CIFAR10):
'key': 'fine_label_names',
'md5': '7973b15100ade9c7d40fb424638fde48',
}
CLASSES = [
'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee',
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the meta info of CLASSES is in our code, it is so long, especially in 'imagenet.py'. It may be better to create a metafile.bin that saves all the CLASSES info. In that way, the code will be purer and users may read our code easily.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I think so too, but move the categories info to another file may cause unexpected problem, especially for deployment.
The code of ImageNet is short, which is only a CustomDataset with preset attributes. I think we can keep it.

'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus',
'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle',
'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab',
'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish',
'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard',
'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man',
'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom',
'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea',
'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper',
'table', 'tank', 'telephone', 'television', 'tiger', 'tractor',
'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale',
'willow_tree', 'wolf', 'woman', 'worm'
]
229 changes: 229 additions & 0 deletions mmcls/datasets/custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import mmcv
import numpy as np
from mmcv import FileClient

from .base_dataset import BaseDataset
from .builder import DATASETS


def find_folders(root: str,
file_client: FileClient) -> Tuple[List[str], Dict[str, int]]:
"""Find classes by folders under a root.

Args:
root (string): root directory of folders

Returns:
Tuple[List[str], Dict[str, int]]:

- folders: The name of sub folders under the root.
- folder_to_idx: The map from folder name to class idx.
"""
folders = list(
file_client.list_dir_or_file(
root,
list_dir=True,
list_file=False,
recursive=False,
))
folders.sort()
folder_to_idx = {folders[i]: i for i in range(len(folders))}
return folders, folder_to_idx


def get_samples(root: str, folder_to_idx: Dict[str, int],
is_valid_file: Callable, file_client: FileClient):
"""Make dataset by walking all images under a root.

Args:
root (string): root directory of folders
folder_to_idx (dict): the map from class name to class idx
is_valid_file (Callable): A function that takes path of a file
and check if the file is a valid sample file.

Returns:
Tuple[list, set]:

- samples: a list of tuple where each element is (image, class_idx)
- empty_folders: The folders don't have any valid files.
"""
samples = []
available_classes = set()

for folder_name in sorted(list(folder_to_idx.keys())):
_dir = file_client.join_path(root, folder_name)
files = list(
file_client.list_dir_or_file(
_dir,
list_dir=False,
list_file=True,
recursive=True,
))
for file in sorted(list(files)):
if is_valid_file(file):
path = file_client.join_path(folder_name, file)
item = (path, folder_to_idx[folder_name])
samples.append(item)
available_classes.add(folder_name)

empty_folders = set(folder_to_idx.keys()) - available_classes

return samples, empty_folders


@DATASETS.register_module()
class CustomDataset(BaseDataset):
"""Custom dataset for classification.

The dataset supports two kinds of annotation format.

1. An annotation file is provided, and each line indicates a sample:

The sample files: ::

data_prefix/
├── folder_1
│ ├── xxx.png
│ ├── xxy.png
│ └── ...
└── folder_2
├── 123.png
├── nsdf3.png
└── ...

The annotation file (the first column is the image path and the second
column is the index of category): ::

folder_1/xxx.png 0
folder_1/xxy.png 1
folder_2/123.png 5
mzr1996 marked this conversation as resolved.
Show resolved Hide resolved
folder_2/nsdf3.png 3
...

Please specify the name of categories by the argument ``classes``.

2. The samples are arranged in the specific way: ::

data_prefix/
├── class_x
│ ├── xxx.png
│ ├── xxy.png
│ └── ...
│ └── xxz.png
└── class_y
├── 123.png
├── nsdf3.png
├── ...
└── asd932_.png

If the ``ann_file`` is specified, the dataset will be generated by the
first way, otherwise, try the second way.

Args:
data_prefix (str): The path of data directory.
pipeline (Sequence[dict]): A list of dict, where each element
represents a operation defined in :mod:`mmcls.datasets.pipelines`.
Defaults to an empty tuple.
classes (str | Sequence[str], optional): Specify names of classes.

- If is string, it should be a file path, and the every line of
the file is a name of a class.
- If is a sequence of string, every item is a name of class.
- If is None, use ``cls.CLASSES`` or the names of sub folders
(If use the second way to arrange samples).

Defaults to None.
ann_file (str, optional): The annotation file. If is string, read
samples paths from the ann_file. If is None, find samples in
``data_prefix``. Defaults to None.
extensions (Sequence[str]): A sequence of allowed extensions. Defaults
to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif').
test_mode (bool): In train mode or test mode. It's only a mark and
won't be used in this class. Defaults to False.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details.
If None, automatically inference from the specified path.
Defaults to None.
"""

def __init__(self,
data_prefix: str,
pipeline: Sequence = (),
classes: Union[str, Sequence[str], None] = None,
ann_file: Optional[str] = None,
extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm',
'.bmp', '.pgm', '.tif'),
test_mode: bool = False,
file_client_args: Optional[dict] = None):
self.extensions = tuple(set([i.lower() for i in extensions]))
self.file_client_args = file_client_args

super().__init__(
data_prefix=data_prefix,
pipeline=pipeline,
classes=classes,
ann_file=ann_file,
test_mode=test_mode)

def _find_samples(self):
"""find samples from ``data_prefix``."""
file_client = FileClient.infer_client(self.file_client_args,
self.data_prefix)
classes, folder_to_idx = find_folders(self.data_prefix, file_client)
samples, empty_classes = get_samples(
self.data_prefix,
folder_to_idx,
is_valid_file=self.is_valid_file,
file_client=file_client,
)

if len(samples) == 0:
raise RuntimeError(
f'Found 0 files in subfolders of: {self.data_prefix}. '
f'Supported extensions are: {",".join(self.extensions)}')

if self.CLASSES is not None:
assert len(self.CLASSES) == len(classes), \
f"The number of subfolders ({len(classes)}) doesn't match " \
f'the number of specified classes ({len(self.CLASSES)}). ' \
'Please check the data folder.'
else:
self.CLASSES = classes

if empty_classes:
warnings.warn(
'Found no valid file in the folder '
f'{", ".join(empty_classes)}. '
f"Supported extensions are: {', '.join(self.extensions)}",
UserWarning)

self.folder_to_idx = folder_to_idx

return samples

def load_annotations(self):
"""Load image paths and gt_labels."""
if self.ann_file is None:
samples = self._find_samples()
elif isinstance(self.ann_file, str):
lines = mmcv.list_from_file(
self.ann_file, file_client_args=self.file_client_args)
samples = [x.strip().rsplit(' ', 1) for x in lines]
else:
raise TypeError('ann_file must be a str or None')

data_infos = []
for filename, gt_label in samples:
info = {'img_prefix': self.data_prefix}
info['img_info'] = {'filename': filename}
info['gt_label'] = np.array(gt_label, dtype=np.int64)
data_infos.append(info)
return data_infos

def is_valid_file(self, filename: str) -> bool:
"""Check if a file is a valid sample."""
return filename.lower().endswith(self.extensions)
Loading