diff --git a/mmcls/datasets/__init__.py b/mmcls/datasets/__init__.py index 510e262cfe8..c71dd50a201 100644 --- a/mmcls/datasets/__init__.py +++ b/mmcls/datasets/__init__.py @@ -4,6 +4,7 @@ build_dataset, build_sampler) from .cifar import CIFAR10, CIFAR100 from .cub import CUB +from .custom import CustomDataset from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset, KFoldDataset, RepeatDataset) from .imagenet import ImageNet @@ -18,5 +19,5 @@ 'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset', 'DistributedSampler', 'ConcatDataset', 'RepeatDataset', 'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'ImageNet21k', 'SAMPLERS', - 'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB' + 'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB', 'CustomDataset' ] diff --git a/mmcls/datasets/base_dataset.py b/mmcls/datasets/base_dataset.py index 7924b4065da..fb6578ab181 100644 --- a/mmcls/datasets/base_dataset.py +++ b/mmcls/datasets/base_dataset.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy +import os.path as osp from abc import ABCMeta, abstractmethod +from os import PathLike from typing import List import mmcv @@ -12,6 +14,13 @@ from .pipelines import Compose +def expanduser(path): + if isinstance(path, (str, PathLike)): + return osp.expanduser(path) + else: + return path + + class BaseDataset(Dataset, metaclass=ABCMeta): """Base dataset. @@ -34,11 +43,11 @@ def __init__(self, ann_file=None, test_mode=False): super(BaseDataset, self).__init__() - self.ann_file = ann_file - self.data_prefix = data_prefix - self.test_mode = test_mode + self.data_prefix = expanduser(data_prefix) self.pipeline = Compose(pipeline) self.CLASSES = self.get_classes(classes) + self.ann_file = expanduser(ann_file) + self.test_mode = test_mode self.data_infos = self.load_annotations() @abstractmethod @@ -106,7 +115,7 @@ def get_classes(cls, classes=None): if isinstance(classes, str): # take it as a file path - class_names = mmcv.list_from_file(classes) + class_names = mmcv.list_from_file(expanduser(classes)) elif isinstance(classes, (tuple, list)): class_names = classes else: diff --git a/mmcls/datasets/cifar.py b/mmcls/datasets/cifar.py index 31440247b86..453b8d9d95f 100644 --- a/mmcls/datasets/cifar.py +++ b/mmcls/datasets/cifar.py @@ -40,6 +40,10 @@ class CIFAR10(BaseDataset): 'key': 'label_names', 'md5': '5ff9c542aee3614f3951f8cda6e48888', } + CLASSES = [ + 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', + 'horse', 'ship', 'truck' + ] def load_annotations(self): @@ -131,3 +135,21 @@ class CIFAR100(CIFAR10): 'key': 'fine_label_names', 'md5': '7973b15100ade9c7d40fb424638fde48', } + CLASSES = [ + 'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', + 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', + 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', + 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', + 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', + 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', + 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', + 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', + 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', + 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', + 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', + 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', + 'spider', 'squirrel', 'streetcar', 
'sunflower', 'sweet_pepper', + 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', + 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', + 'willow_tree', 'wolf', 'woman', 'worm' + ] diff --git a/mmcls/datasets/custom.py b/mmcls/datasets/custom.py new file mode 100644 index 00000000000..61458f63bac --- /dev/null +++ b/mmcls/datasets/custom.py @@ -0,0 +1,229 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union + +import mmcv +import numpy as np +from mmcv import FileClient + +from .base_dataset import BaseDataset +from .builder import DATASETS + + +def find_folders(root: str, + file_client: FileClient) -> Tuple[List[str], Dict[str, int]]: + """Find classes by folders under a root. + + Args: + root (string): root directory of folders + + Returns: + Tuple[List[str], Dict[str, int]]: + + - folders: The name of sub folders under the root. + - folder_to_idx: The map from folder name to class idx. + """ + folders = list( + file_client.list_dir_or_file( + root, + list_dir=True, + list_file=False, + recursive=False, + )) + folders.sort() + folder_to_idx = {folders[i]: i for i in range(len(folders))} + return folders, folder_to_idx + + +def get_samples(root: str, folder_to_idx: Dict[str, int], + is_valid_file: Callable, file_client: FileClient): + """Make dataset by walking all images under a root. + + Args: + root (string): root directory of folders + folder_to_idx (dict): the map from class name to class idx + is_valid_file (Callable): A function that takes path of a file + and check if the file is a valid sample file. + + Returns: + Tuple[list, set]: + + - samples: a list of tuple where each element is (image, class_idx) + - empty_folders: The folders don't have any valid files. + """ + samples = [] + available_classes = set() + + for folder_name in sorted(list(folder_to_idx.keys())): + _dir = file_client.join_path(root, folder_name) + files = list( + file_client.list_dir_or_file( + _dir, + list_dir=False, + list_file=True, + recursive=True, + )) + for file in sorted(list(files)): + if is_valid_file(file): + path = file_client.join_path(folder_name, file) + item = (path, folder_to_idx[folder_name]) + samples.append(item) + available_classes.add(folder_name) + + empty_folders = set(folder_to_idx.keys()) - available_classes + + return samples, empty_folders + + +@DATASETS.register_module() +class CustomDataset(BaseDataset): + """Custom dataset for classification. + + The dataset supports two kinds of annotation format. + + 1. An annotation file is provided, and each line indicates a sample: + + The sample files: :: + + data_prefix/ + ├── folder_1 + │ ├── xxx.png + │ ├── xxy.png + │ └── ... + └── folder_2 + ├── 123.png + ├── nsdf3.png + └── ... + + The annotation file (the first column is the image path and the second + column is the index of category): :: + + folder_1/xxx.png 0 + folder_1/xxy.png 1 + folder_2/123.png 5 + folder_2/nsdf3.png 3 + ... + + Please specify the name of categories by the argument ``classes``. + + 2. The samples are arranged in the specific way: :: + + data_prefix/ + ├── class_x + │ ├── xxx.png + │ ├── xxy.png + │ └── ... + │ └── xxz.png + └── class_y + ├── 123.png + ├── nsdf3.png + ├── ... + └── asd932_.png + + If the ``ann_file`` is specified, the dataset will be generated by the + first way, otherwise, try the second way. + + Args: + data_prefix (str): The path of data directory. 
+ pipeline (Sequence[dict]): A list of dict, where each element + represents a operation defined in :mod:`mmcls.datasets.pipelines`. + Defaults to an empty tuple. + classes (str | Sequence[str], optional): Specify names of classes. + + - If is string, it should be a file path, and the every line of + the file is a name of a class. + - If is a sequence of string, every item is a name of class. + - If is None, use ``cls.CLASSES`` or the names of sub folders + (If use the second way to arrange samples). + + Defaults to None. + ann_file (str, optional): The annotation file. If is string, read + samples paths from the ann_file. If is None, find samples in + ``data_prefix``. Defaults to None. + extensions (Sequence[str]): A sequence of allowed extensions. Defaults + to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'). + test_mode (bool): In train mode or test mode. It's only a mark and + won't be used in this class. Defaults to False. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + If None, automatically inference from the specified path. + Defaults to None. + """ + + def __init__(self, + data_prefix: str, + pipeline: Sequence = (), + classes: Union[str, Sequence[str], None] = None, + ann_file: Optional[str] = None, + extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm', + '.bmp', '.pgm', '.tif'), + test_mode: bool = False, + file_client_args: Optional[dict] = None): + self.extensions = tuple(set([i.lower() for i in extensions])) + self.file_client_args = file_client_args + + super().__init__( + data_prefix=data_prefix, + pipeline=pipeline, + classes=classes, + ann_file=ann_file, + test_mode=test_mode) + + def _find_samples(self): + """find samples from ``data_prefix``.""" + file_client = FileClient.infer_client(self.file_client_args, + self.data_prefix) + classes, folder_to_idx = find_folders(self.data_prefix, file_client) + samples, empty_classes = get_samples( + self.data_prefix, + folder_to_idx, + is_valid_file=self.is_valid_file, + file_client=file_client, + ) + + if len(samples) == 0: + raise RuntimeError( + f'Found 0 files in subfolders of: {self.data_prefix}. ' + f'Supported extensions are: {",".join(self.extensions)}') + + if self.CLASSES is not None: + assert len(self.CLASSES) == len(classes), \ + f"The number of subfolders ({len(classes)}) doesn't match " \ + f'the number of specified classes ({len(self.CLASSES)}). ' \ + 'Please check the data folder.' + else: + self.CLASSES = classes + + if empty_classes: + warnings.warn( + 'Found no valid file in the folder ' + f'{", ".join(empty_classes)}. 
' + f"Supported extensions are: {', '.join(self.extensions)}", + UserWarning) + + self.folder_to_idx = folder_to_idx + + return samples + + def load_annotations(self): + """Load image paths and gt_labels.""" + if self.ann_file is None: + samples = self._find_samples() + elif isinstance(self.ann_file, str): + lines = mmcv.list_from_file( + self.ann_file, file_client_args=self.file_client_args) + samples = [x.strip().rsplit(' ', 1) for x in lines] + else: + raise TypeError('ann_file must be a str or None') + + data_infos = [] + for filename, gt_label in samples: + info = {'img_prefix': self.data_prefix} + info['img_info'] = {'filename': filename} + info['gt_label'] = np.array(gt_label, dtype=np.int64) + data_infos.append(info) + return data_infos + + def is_valid_file(self, filename: str) -> bool: + """Check if a file is a valid sample.""" + return filename.lower().endswith(self.extensions) diff --git a/mmcls/datasets/imagenet.py b/mmcls/datasets/imagenet.py index 9bfd31b0794..20483b6dcfb 100644 --- a/mmcls/datasets/imagenet.py +++ b/mmcls/datasets/imagenet.py @@ -1,68 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. -import os +from typing import Optional, Sequence, Union -import numpy as np - -from .base_dataset import BaseDataset from .builder import DATASETS - - -def has_file_allowed_extension(filename, extensions): - """Checks if a file is an allowed extension. - - Args: - filename (string): path to a file - - Returns: - bool: True if the filename ends with a known image extension - """ - filename_lower = filename.lower() - return any(filename_lower.endswith(ext) for ext in extensions) - - -def find_folders(root): - """Find classes by folders under a root. - - Args: - root (string): root directory of folders - - Returns: - folder_to_idx (dict): the map from folder name to class idx - """ - folders = [ - d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d)) - ] - folders.sort() - folder_to_idx = {folders[i]: i for i in range(len(folders))} - return folder_to_idx - - -def get_samples(root, folder_to_idx, extensions): - """Make dataset by walking all images under a root. - - Args: - root (string): root directory of folders - folder_to_idx (dict): the map from class name to class idx - extensions (tuple): allowed extensions - - Returns: - samples (list): a list of tuple where each element is (image, label) - """ - samples = [] - root = os.path.expanduser(root) - for folder_name in sorted(list(folder_to_idx.keys())): - _dir = os.path.join(root, folder_name) - for _, _, fns in sorted(os.walk(_dir)): - for fn in sorted(fns): - if has_file_allowed_extension(fn, extensions): - path = os.path.join(folder_name, fn) - item = (path, folder_to_idx[folder_name]) - samples.append(item) - return samples +from .custom import CustomDataset @DATASETS.register_module() -class ImageNet(BaseDataset): +class ImageNet(CustomDataset): """`ImageNet `_ Dataset. This implementation is modified from @@ -1073,31 +1017,18 @@ class ImageNet(BaseDataset): 'toilet tissue, toilet paper, bathroom tissue' ] - def load_annotations(self): - if self.ann_file is None: - folder_to_idx = find_folders(self.data_prefix) - samples = get_samples( - self.data_prefix, - folder_to_idx, - extensions=self.IMG_EXTENSIONS) - if len(samples) == 0: - raise (RuntimeError('Found 0 files in subfolders of: ' - f'{self.data_prefix}. 
' - 'Supported extensions are: ' - f'{",".join(self.IMG_EXTENSIONS)}')) - - self.folder_to_idx = folder_to_idx - elif isinstance(self.ann_file, str): - with open(self.ann_file) as f: - samples = [x.strip().rsplit(' ', 1) for x in f.readlines()] - else: - raise TypeError('ann_file must be a str or None') - self.samples = samples - - data_infos = [] - for filename, gt_label in self.samples: - info = {'img_prefix': self.data_prefix} - info['img_info'] = {'filename': filename} - info['gt_label'] = np.array(gt_label, dtype=np.int64) - data_infos.append(info) - return data_infos + def __init__(self, + data_prefix: str, + pipeline: Sequence = (), + classes: Union[str, Sequence[str], None] = None, + ann_file: Optional[str] = None, + test_mode: bool = False, + file_client_args: Optional[dict] = None): + super().__init__( + data_prefix=data_prefix, + pipeline=pipeline, + classes=classes, + ann_file=ann_file, + extensions=self.IMG_EXTENSIONS, + test_mode=test_mode, + file_client_args=file_client_args) diff --git a/mmcls/datasets/imagenet21k.py b/mmcls/datasets/imagenet21k.py index 6fc2eccf646..cbae98cf119 100644 --- a/mmcls/datasets/imagenet21k.py +++ b/mmcls/datasets/imagenet21k.py @@ -1,72 +1,106 @@ # Copyright (c) OpenMMLab. All rights reserved. -import os +import gc +import pickle import warnings -from typing import List +from typing import List, Optional, Sequence, Tuple, Union import numpy as np -from mmcv.utils import scandir -from .base_dataset import BaseDataset from .builder import DATASETS -from .imagenet import find_folders - - -class ImageInfo(): - """class to store image info, using slots will save memory than using - dict.""" - __slots__ = ['path', 'gt_label'] - - def __init__(self, path, gt_label): - self.path = path - self.gt_label = gt_label +from .custom import CustomDataset @DATASETS.register_module() -class ImageNet21k(BaseDataset): +class ImageNet21k(CustomDataset): """ImageNet21k Dataset. Since the dataset ImageNet21k is extremely big, cantains 21k+ classes and 1.4B files. This class has improved the following points on the - basis of the class ``ImageNet``, in order to save memory usage and time - required : - - - Delete the samples attribute - - using 'slots' create a Data_item tp replace dict - - Modify setting ``info`` dict from function ``load_annotations`` to - function ``prepare_data`` - - using int instead of np.array(..., np.int64) + basis of the class ``ImageNet``, in order to save memory, we enable the + ``serialize_data`` optional by default. With this option, the annotation + won't be stored in the list ``data_infos``, but be serialized as an + array. Args: - data_prefix (str): the prefix of data path - pipeline (list): a list of dict, where each element represents - a operation defined in ``mmcls.datasets.pipelines`` - ann_file (str | None): the annotation file. When ann_file is str, - the subclass is expected to read from the ann_file. When ann_file - is None, the subclass is expected to read according to data_prefix - test_mode (bool): in train mode or test mode - multi_label (bool): use multi label or not. - recursion_subdir(bool): whether to use sub-directory pictures, which - are meet the conditions in the folder under category directory. + data_prefix (str): The path of data directory. + pipeline (Sequence[dict]): A list of dict, where each element + represents a operation defined in :mod:`mmcls.datasets.pipelines`. + Defaults to an empty tuple. + classes (str | Sequence[str], optional): Specify names of classes. 
+ + - If is string, it should be a file path, and the every line of + the file is a name of a class. + - If is a sequence of string, every item is a name of class. + - If is None, use ``cls.CLASSES`` or the names of sub folders + (If use the second way to arrange samples). + + Defaults to None. + ann_file (str, optional): The annotation file. If is string, read + samples paths from the ann_file. If is None, find samples in + ``data_prefix``. Defaults to None. + serialize_data (bool): Whether to hold memory using serialized objects, + when enabled, data loader workers can use shared RAM from master + process instead of making a copy. Defaults to True. + multi_label (bool): Not implement by now. Use multi label or not. + Defaults to False. + recursion_subdir(bool): Deprecated, and the dataset will recursively + get all images now. + test_mode (bool): In train mode or test mode. It's only a mark and + won't be used in this class. Defaults to False. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + If None, automatically inference from the specified path. + Defaults to None. """ - IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', - '.JPEG', '.JPG') + IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif') CLASSES = None def __init__(self, - data_prefix, - pipeline, - classes=None, - ann_file=None, - multi_label=False, - recursion_subdir=False, - test_mode=False): - self.recursion_subdir = recursion_subdir + data_prefix: str, + pipeline: Sequence = (), + classes: Union[str, Sequence[str], None] = None, + ann_file: Optional[str] = None, + serialize_data: bool = True, + multi_label: bool = False, + recursion_subdir: bool = True, + test_mode=False, + file_client_args: Optional[dict] = None): + assert recursion_subdir, 'The `recursion_subdir` option is ' \ + 'deprecated. Now the dataset will recursively get all images.' if multi_label: - raise NotImplementedError('Multi_label have not be implemented.') - self.multi_lable = multi_label - super(ImageNet21k, self).__init__(data_prefix, pipeline, classes, - ann_file, test_mode) + raise NotImplementedError( + 'The `multi_label` option is not supported by now.') + self.multi_label = multi_label + self.serialize_data = serialize_data + + if ann_file is None: + warnings.warn( + 'The ImageNet21k dataset is large, and scanning directory may ' + 'consume long time. Considering to specify the `ann_file` to ' + 'accelerate the initialization.', UserWarning) + + if classes is None: + warnings.warn( + 'The CLASSES is not stored in the `ImageNet21k` class. ' + 'Considering to specify the `classes` argument if you need ' + 'do inference on the ImageNet-21k dataset', UserWarning) + + super().__init__( + data_prefix=data_prefix, + pipeline=pipeline, + classes=classes, + ann_file=ann_file, + extensions=self.IMG_EXTENSIONS, + test_mode=test_mode, + file_client_args=file_client_args) + + if self.serialize_data: + self.data_infos_bytes, self.data_address = self._serialize_data() + # Empty cache for preventing making multiple copies of + # `self.data_infos` when loading data multi-processes. + self.data_infos.clear() + gc.collect() def get_cat_ids(self, idx: int) -> List[int]: """Get category id by index. @@ -78,77 +112,63 @@ def get_cat_ids(self, idx: int) -> List[int]: cat_ids (List[int]): Image category of specified index. 
""" - return [self.data_infos[idx].gt_label] + return [int(self.get_data_info(idx)['gt_label'])] + + def get_data_info(self, idx: int) -> dict: + """Get annotation by index. + + Args: + idx (int): The index of data. + + Returns: + dict: The idx-th annotation of the dataset. + """ + if self.serialize_data: + start_addr = 0 if idx == 0 else self.data_address[idx - 1].item() + end_addr = self.data_address[idx].item() + bytes = memoryview(self.data_infos_bytes[start_addr:end_addr]) + data_info = pickle.loads(bytes) + else: + data_info = self.data_infos[idx] + + return data_info def prepare_data(self, idx): - info = self.data_infos[idx] - results = { - 'img_prefix': self.data_prefix, - 'img_info': dict(filename=info.path), - 'gt_label': np.array(info.gt_label, dtype=np.int64) - } - return self.pipeline(results) - - def load_annotations(self): - """load dataset annotations.""" - if self.ann_file is None: - data_infos = self._load_annotations_from_dir() - elif isinstance(self.ann_file, str): - data_infos = self._load_annotations_from_file() + data_info = self.get_data_info(idx) + return self.pipeline(data_info) + + def _serialize_data(self) -> Tuple[np.ndarray, np.ndarray]: + """Serialize ``self.data_infos`` to save memory when launching multiple + workers in data loading. This function will be called in ``full_init``. + + Hold memory using serialized objects, and data loader workers can use + shared RAM from master process instead of making a copy. + + Returns: + Tuple[np.ndarray, np.ndarray]: serialize result and corresponding + address. + """ + + def _serialize(data): + buffer = pickle.dumps(data, protocol=4) + return np.frombuffer(buffer, dtype=np.uint8) + + serialized_data_infos_list = [_serialize(x) for x in self.data_infos] + address_list = np.asarray([len(x) for x in serialized_data_infos_list], + dtype=np.int64) + data_address: np.ndarray = np.cumsum(address_list) + serialized_data_infos = np.concatenate(serialized_data_infos_list) + + return serialized_data_infos, data_address + + def __len__(self) -> int: + """Get the length of filtered dataset and automatically call + ``full_init`` if the dataset has not been fully init. + + Returns: + int: The length of filtered dataset. + """ + if self.serialize_data: + return len(self.data_address) else: - raise TypeError('ann_file must be a str or None') - - if len(data_infos) == 0: - msg = 'Found no valid file in ' - msg += f'{self.ann_file}. ' if self.ann_file \ - else f'{self.data_prefix}. 
' - msg += 'Supported extensions are: ' + \ - ', '.join(self.IMG_EXTENSIONS) - raise RuntimeError(msg) - - return data_infos - - def _find_allowed_files(self, root, folder_name): - """find all the allowed files in a folder, including sub folder if - recursion_subdir is true.""" - _dir = os.path.join(root, folder_name) - infos_pre_class = [] - for path in scandir(_dir, self.IMG_EXTENSIONS, self.recursion_subdir): - path = os.path.join(folder_name, path) - item = ImageInfo(path, self.folder_to_idx[folder_name]) - infos_pre_class.append(item) - return infos_pre_class - - def _load_annotations_from_dir(self): - """load annotations from self.data_prefix directory.""" - data_infos, empty_classes = [], [] - folder_to_idx = find_folders(self.data_prefix) - self.folder_to_idx = folder_to_idx - root = os.path.expanduser(self.data_prefix) - for folder_name in folder_to_idx.keys(): - infos_pre_class = self._find_allowed_files(root, folder_name) - if len(infos_pre_class) == 0: - empty_classes.append(folder_name) - data_infos.extend(infos_pre_class) - - if len(empty_classes) != 0: - msg = 'Found no valid file for the classes ' + \ - f"{', '.join(sorted(empty_classes))} " - msg += 'Supported extensions are: ' + \ - f"{', '.join(self.IMG_EXTENSIONS)}." - warnings.warn(msg) - - return data_infos - - def _load_annotations_from_file(self): - """load annotations from self.ann_file.""" - data_infos = [] - with open(self.ann_file) as f: - for line in f.readlines(): - if line == '': - continue - filepath, gt_label = line.strip().rsplit(' ', 1) - info = ImageInfo(filepath, int(gt_label)) - data_infos.append(info) - - return data_infos + return len(self.data_infos) diff --git a/tests/test_data/test_datasets/test_common.py b/tests/test_data/test_datasets/test_common.py index dfc5f4f2e83..b6bfe3bd341 100644 --- a/tests/test_data/test_datasets/test_common.py +++ b/tests/test_data/test_datasets/test_common.py @@ -1,378 +1,763 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+import os +import os.path as osp +import pickle import tempfile -from unittest.mock import MagicMock, patch +from unittest import TestCase +from unittest.mock import patch import numpy as np -import pytest import torch -from mmcls.datasets import (CUB, DATASETS, BaseDataset, ImageNet21k, - MultiLabelDataset) - - -@pytest.mark.parametrize('dataset_name', [ - 'MNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100', 'ImageNet', 'VOC', - 'ImageNet21k', 'CUB' -]) -def test_datasets_override_default(dataset_name): - dataset_class = DATASETS.get(dataset_name) - load_annotations_f = dataset_class.load_annotations - ann = [ - dict( - img_prefix='', - img_info=dict(), - gt_label=np.array(0, dtype=np.int64)) - ] - dataset_class.load_annotations = MagicMock(return_value=ann) - - original_classes = dataset_class.CLASSES - - # some datasets need extra argument to init - extra_kwargs_settings = { - 'CUB': - dict( - ann_file=None, - image_class_labels_file=None, - train_test_split_file=None), - } - extra_kwargs = extra_kwargs_settings.get(dataset_name, dict()) - # Test VOC year - if dataset_name == 'VOC': - dataset = dataset_class( - data_prefix='VOC2007', - pipeline=[], - classes=('bus', 'car'), - test_mode=True) - assert dataset.year == 2007 - with pytest.raises(ValueError): - dataset = dataset_class( - data_prefix='VOC', - pipeline=[], - classes=('bus', 'car'), - test_mode=True) - - # Test setting classes as a tuple - dataset = dataset_class( - data_prefix='VOC2007' if dataset_name == 'VOC' else '', - pipeline=[], - classes=('bus', 'car'), - test_mode=True, - **extra_kwargs) - assert dataset.CLASSES == ('bus', 'car') - - # Test get_cat_ids - if dataset_name not in ['ImageNet21k', 'VOC']: - assert isinstance(dataset.get_cat_ids(0), list) - assert len(dataset.get_cat_ids(0)) == 1 - assert isinstance(dataset.get_cat_ids(0)[0], int) - - # Test setting classes as a list - dataset = dataset_class( - data_prefix='VOC2007' if dataset_name == 'VOC' else '', - pipeline=[], - classes=['bus', 'car'], - test_mode=True, - **extra_kwargs) - assert dataset.CLASSES == ['bus', 'car'] - - # Test setting classes through a file - tmp_file = tempfile.NamedTemporaryFile() - with open(tmp_file.name, 'w') as f: - f.write('bus\ncar\n') - dataset = dataset_class( - data_prefix='VOC2007' if dataset_name == 'VOC' else '', - pipeline=[], - classes=tmp_file.name, - test_mode=True, - **extra_kwargs) - tmp_file.close() +from mmcls.datasets import DATASETS +from mmcls.datasets import BaseDataset as _BaseDataset +from mmcls.datasets import MultiLabelDataset as _MultiLabelDataset - assert dataset.CLASSES == ['bus', 'car'] +ASSETS_ROOT = osp.abspath( + osp.join(osp.dirname(__file__), '../../data/dataset')) - # Test overriding not a subset - dataset = dataset_class( - data_prefix='VOC2007' if dataset_name == 'VOC' else '', - pipeline=[], - classes=['foo'], - test_mode=True, - **extra_kwargs) - assert dataset.CLASSES == ['foo'] - - # Test default behavior - dataset = dataset_class( - data_prefix='VOC2007' if dataset_name == 'VOC' else '', - pipeline=[], - **extra_kwargs) - - if dataset_name == 'VOC': - assert dataset.data_prefix == 'VOC2007' - else: - assert dataset.data_prefix == '' - assert not dataset.test_mode - assert dataset.ann_file is None - assert dataset.CLASSES == original_classes - - dataset_class.load_annotations = load_annotations_f - - -@patch.multiple(MultiLabelDataset, __abstractmethods__=set()) -@patch.multiple(BaseDataset, __abstractmethods__=set()) -def test_dataset_evaluation(): - # test multi-class single-label evaluation - 
dataset = BaseDataset(data_prefix='', pipeline=[], test_mode=True) - dataset.data_infos = [ - dict(gt_label=0), - dict(gt_label=0), - dict(gt_label=1), - dict(gt_label=2), - dict(gt_label=1), - dict(gt_label=0) - ] - fake_results = np.array([[0.7, 0, 0.3], [0.5, 0.2, 0.3], [0.4, 0.5, 0.1], - [0, 0, 1], [0, 0, 1], [0, 0, 1]]) - eval_results = dataset.evaluate( - fake_results, - metric=['precision', 'recall', 'f1_score', 'support', 'accuracy'], - metric_options={'topk': 1}) - assert eval_results['precision'] == pytest.approx( - (1 + 1 + 1 / 3) / 3 * 100.0) - assert eval_results['recall'] == pytest.approx( - (2 / 3 + 1 / 2 + 1) / 3 * 100.0) - assert eval_results['f1_score'] == pytest.approx( - (4 / 5 + 2 / 3 + 1 / 2) / 3 * 100.0) - assert eval_results['support'] == 6 - assert eval_results['accuracy'] == pytest.approx(4 / 6 * 100) - - # test input as tensor - fake_results_tensor = torch.from_numpy(fake_results) - eval_results_ = dataset.evaluate( - fake_results_tensor, - metric=['precision', 'recall', 'f1_score', 'support', 'accuracy'], - metric_options={'topk': 1}) - assert eval_results_ == eval_results - - # test thr - eval_results = dataset.evaluate( - fake_results, - metric=['precision', 'recall', 'f1_score', 'accuracy'], - metric_options={ - 'thrs': 0.6, - 'topk': 1 - }) - assert eval_results['precision'] == pytest.approx( - (1 + 0 + 1 / 3) / 3 * 100.0) - assert eval_results['recall'] == pytest.approx((1 / 3 + 0 + 1) / 3 * 100.0) - assert eval_results['f1_score'] == pytest.approx( - (1 / 2 + 0 + 1 / 2) / 3 * 100.0) - assert eval_results['accuracy'] == pytest.approx(2 / 6 * 100) - # thrs must be a number or tuple - with pytest.raises(TypeError): + +class BaseDataset(_BaseDataset): + + def load_annotations(self): + pass + + +class MultiLabelDataset(_MultiLabelDataset): + + def load_annotations(self): + pass + + +DATASETS.module_dict['BaseDataset'] = BaseDataset +DATASETS.module_dict['MultiLabelDataset'] = MultiLabelDataset + + +class TestBaseDataset(TestCase): + DATASET_TYPE = 'BaseDataset' + + DEFAULT_ARGS = dict(data_prefix='', pipeline=[]) + + def test_initialize(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + + with patch.object(dataset_class, 'load_annotations'): + # Test default behavior + cfg = {**self.DEFAULT_ARGS, 'classes': None, 'ann_file': None} + dataset = dataset_class(**cfg) + self.assertEqual(dataset.CLASSES, dataset_class.CLASSES) + self.assertFalse(dataset.test_mode) + self.assertIsNone(dataset.ann_file) + + # Test setting classes as a tuple + cfg = {**self.DEFAULT_ARGS, 'classes': ('bus', 'car')} + dataset = dataset_class(**cfg) + self.assertEqual(dataset.CLASSES, ('bus', 'car')) + + # Test setting classes as a tuple + cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']} + dataset = dataset_class(**cfg) + self.assertEqual(dataset.CLASSES, ['bus', 'car']) + + # Test setting classes through a file + classes_file = osp.join(ASSETS_ROOT, 'classes.txt') + cfg = {**self.DEFAULT_ARGS, 'classes': classes_file} + dataset = dataset_class(**cfg) + self.assertEqual(dataset.CLASSES, ['bus', 'car']) + self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1}) + + # Test invalid classes + cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)} + with self.assertRaisesRegex(ValueError, "type "): + dataset_class(**cfg) + + def test_get_cat_ids(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + fake_ann = [ + dict( + img_prefix='', + img_info=dict(), + gt_label=np.array(0, dtype=np.int64)) + ] + + with patch.object(dataset_class, 'load_annotations') as mock_load: + 
mock_load.return_value = fake_ann + dataset = dataset_class(**self.DEFAULT_ARGS) + + cat_ids = dataset.get_cat_ids(0) + self.assertIsInstance(cat_ids, list) + self.assertEqual(len(cat_ids), 1) + self.assertIsInstance(cat_ids[0], int) + + def test_evaluate(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + + fake_ann = [ + dict(gt_label=np.array(0, dtype=np.int64)), + dict(gt_label=np.array(0, dtype=np.int64)), + dict(gt_label=np.array(1, dtype=np.int64)), + dict(gt_label=np.array(2, dtype=np.int64)), + dict(gt_label=np.array(1, dtype=np.int64)), + dict(gt_label=np.array(0, dtype=np.int64)), + ] + + with patch.object(dataset_class, 'load_annotations') as mock_load: + mock_load.return_value = fake_ann + dataset = dataset_class(**self.DEFAULT_ARGS) + + fake_results = np.array([ + [0.7, 0.0, 0.3], + [0.5, 0.2, 0.3], + [0.4, 0.5, 0.1], + [0.0, 0.0, 1.0], + [0.0, 0.0, 1.0], + [0.0, 0.0, 1.0], + ]) + + eval_results = dataset.evaluate( + fake_results, + metric=['precision', 'recall', 'f1_score', 'support', 'accuracy'], + metric_options={'topk': 1}) + + # Test results + self.assertAlmostEqual( + eval_results['precision'], (1 + 1 + 1 / 3) / 3 * 100.0, places=4) + self.assertAlmostEqual( + eval_results['recall'], (2 / 3 + 1 / 2 + 1) / 3 * 100.0, places=4) + self.assertAlmostEqual( + eval_results['f1_score'], (4 / 5 + 2 / 3 + 1 / 2) / 3 * 100.0, + places=4) + self.assertEqual(eval_results['support'], 6) + self.assertAlmostEqual(eval_results['accuracy'], 4 / 6 * 100, places=4) + + # test indices + eval_results_ = dataset.evaluate( + fake_results[:5], + metric=['precision', 'recall', 'f1_score', 'support', 'accuracy'], + metric_options={'topk': 1}, + indices=range(5)) + self.assertAlmostEqual( + eval_results_['precision'], (1 + 1 + 1 / 2) / 3 * 100.0, places=4) + self.assertAlmostEqual( + eval_results_['recall'], (1 + 1 / 2 + 1) / 3 * 100.0, places=4) + self.assertAlmostEqual( + eval_results_['f1_score'], (1 + 2 / 3 + 2 / 3) / 3 * 100.0, + places=4) + self.assertEqual(eval_results_['support'], 5) + self.assertAlmostEqual( + eval_results_['accuracy'], 4 / 5 * 100, places=4) + + # test input as tensor + fake_results_tensor = torch.from_numpy(fake_results) + eval_results_ = dataset.evaluate( + fake_results_tensor, + metric=['precision', 'recall', 'f1_score', 'support', 'accuracy'], + metric_options={'topk': 1}) + assert eval_results_ == eval_results + + # test thr eval_results = dataset.evaluate( fake_results, metric=['precision', 'recall', 'f1_score', 'accuracy'], metric_options={ - 'thrs': 'thr', + 'thrs': 0.6, 'topk': 1 }) - # test topk and thr as tuple - eval_results = dataset.evaluate( - fake_results, - metric=['precision', 'recall', 'f1_score', 'accuracy'], - metric_options={ - 'thrs': (0.5, 0.6), - 'topk': (1, 2) - }) - assert { - 'precision_thr_0.50', 'precision_thr_0.60', 'recall_thr_0.50', - 'recall_thr_0.60', 'f1_score_thr_0.50', 'f1_score_thr_0.60', - 'accuracy_top-1_thr_0.50', 'accuracy_top-1_thr_0.60', - 'accuracy_top-2_thr_0.50', 'accuracy_top-2_thr_0.60' - } == eval_results.keys() - assert type(eval_results['precision_thr_0.50']) == float - assert type(eval_results['recall_thr_0.50']) == float - assert type(eval_results['f1_score_thr_0.50']) == float - assert type(eval_results['accuracy_top-1_thr_0.50']) == float - - eval_results = dataset.evaluate( - fake_results, - metric='accuracy', - metric_options={ - 'thrs': 0.5, - 'topk': (1, 2) - }) - assert {'accuracy_top-1', 'accuracy_top-2'} == eval_results.keys() - assert type(eval_results['accuracy_top-1']) == float - - eval_results = 
dataset.evaluate( - fake_results, - metric='accuracy', - metric_options={ - 'thrs': (0.5, 0.6), - 'topk': 1 - }) - assert {'accuracy_thr_0.50', 'accuracy_thr_0.60'} == eval_results.keys() - assert type(eval_results['accuracy_thr_0.50']) == float - - # test evaluation results for classes - eval_results = dataset.evaluate( - fake_results, - metric=['precision', 'recall', 'f1_score', 'support'], - metric_options={'average_mode': 'none'}) - assert eval_results['precision'].shape == (3, ) - assert eval_results['recall'].shape == (3, ) - assert eval_results['f1_score'].shape == (3, ) - assert eval_results['support'].shape == (3, ) - - # the average_mode method must be valid - with pytest.raises(ValueError): + self.assertAlmostEqual( + eval_results['precision'], (1 + 0 + 1 / 3) / 3 * 100.0, places=4) + self.assertAlmostEqual( + eval_results['recall'], (1 / 3 + 0 + 1) / 3 * 100.0, places=4) + self.assertAlmostEqual( + eval_results['f1_score'], (1 / 2 + 0 + 1 / 2) / 3 * 100.0, + places=4) + self.assertAlmostEqual(eval_results['accuracy'], 2 / 6 * 100, places=4) + + # thrs must be a number or tuple + with self.assertRaises(TypeError): + dataset.evaluate( + fake_results, + metric=['precision', 'recall', 'f1_score', 'accuracy'], + metric_options={ + 'thrs': 'thr', + 'topk': 1 + }) + + # test topk and thr as tuple eval_results = dataset.evaluate( fake_results, - metric='precision', - metric_options={'average_mode': 'micro'}) - with pytest.raises(ValueError): + metric=['precision', 'recall', 'f1_score', 'accuracy'], + metric_options={ + 'thrs': (0.5, 0.6), + 'topk': (1, 2) + }) + self.assertEqual( + { + 'precision_thr_0.50', 'precision_thr_0.60', 'recall_thr_0.50', + 'recall_thr_0.60', 'f1_score_thr_0.50', 'f1_score_thr_0.60', + 'accuracy_top-1_thr_0.50', 'accuracy_top-1_thr_0.60', + 'accuracy_top-2_thr_0.50', 'accuracy_top-2_thr_0.60' + }, eval_results.keys()) + + self.assertIsInstance(eval_results['precision_thr_0.50'], float) + self.assertIsInstance(eval_results['recall_thr_0.50'], float) + self.assertIsInstance(eval_results['f1_score_thr_0.50'], float) + self.assertIsInstance(eval_results['accuracy_top-1_thr_0.50'], float) + + # test topk is tuple while thrs is number eval_results = dataset.evaluate( fake_results, - metric='recall', - metric_options={'average_mode': 'micro'}) - with pytest.raises(ValueError): + metric='accuracy', + metric_options={ + 'thrs': 0.5, + 'topk': (1, 2) + }) + self.assertEqual({'accuracy_top-1', 'accuracy_top-2'}, + eval_results.keys()) + self.assertIsInstance(eval_results['accuracy_top-1'], float) + + # test topk is number while thrs is tuple eval_results = dataset.evaluate( fake_results, - metric='f1_score', - metric_options={'average_mode': 'micro'}) - with pytest.raises(ValueError): + metric='accuracy', + metric_options={ + 'thrs': (0.5, 0.6), + 'topk': 1 + }) + self.assertEqual({'accuracy_thr_0.50', 'accuracy_thr_0.60'}, + eval_results.keys()) + self.assertIsInstance(eval_results['accuracy_thr_0.50'], float) + + # test evaluation results for classes eval_results = dataset.evaluate( fake_results, - metric='support', - metric_options={'average_mode': 'micro'}) - - # the metric must be valid for the dataset - with pytest.raises(ValueError): - eval_results = dataset.evaluate(fake_results, metric='map') - - # test multi-label evaluation - dataset = MultiLabelDataset(data_prefix='', pipeline=[], test_mode=True) - dataset.data_infos = [ - dict(gt_label=[1, 1, 0, -1]), - dict(gt_label=[1, 1, 0, -1]), - dict(gt_label=[0, -1, 1, -1]), - dict(gt_label=[0, 1, 0, -1]), - 
dict(gt_label=[0, 1, 0, -1]), - ] - fake_results = np.array([[0.9, 0.8, 0.3, 0.2], [0.1, 0.2, 0.2, 0.1], - [0.7, 0.5, 0.9, 0.3], [0.8, 0.1, 0.1, 0.2], - [0.8, 0.1, 0.1, 0.2]]) - - # the metric must be valid - with pytest.raises(ValueError): - metric = 'coverage' - dataset.evaluate(fake_results, metric=metric) - # only one metric - metric = 'mAP' - eval_results = dataset.evaluate(fake_results, metric=metric) - assert 'mAP' in eval_results.keys() - assert 'CP' not in eval_results.keys() - - # multiple metrics - metric = ['mAP', 'CR', 'OF1'] - eval_results = dataset.evaluate(fake_results, metric=metric) - assert 'mAP' in eval_results.keys() - assert 'CR' in eval_results.keys() - assert 'OF1' in eval_results.keys() - assert 'CF1' not in eval_results.keys() - - -def test_dataset_imagenet21k(): - base_dataset_cfg = dict( - data_prefix='tests/data/dataset', pipeline=[], recursion_subdir=True) - - with pytest.raises(NotImplementedError): - # multi_label have not be implemented - dataset_cfg = base_dataset_cfg.copy() - dataset_cfg.update({'multi_label': True}) - dataset = ImageNet21k(**dataset_cfg) - - with pytest.raises(TypeError): - # ann_file must be a string or None - dataset_cfg = base_dataset_cfg.copy() - ann_file = {'path': 'tests/data/dataset/ann.txt'} - dataset_cfg.update({'ann_file': ann_file}) - dataset = ImageNet21k(**dataset_cfg) - - # test with recursion_subdir is True - dataset = ImageNet21k(**base_dataset_cfg) - assert len(dataset) == 3 - assert isinstance(dataset[0], dict) - assert 'img_prefix' in dataset[0] - assert 'img_info' in dataset[0] - assert 'gt_label' in dataset[0] - - # Test get_cat_ids - assert isinstance(dataset.get_cat_ids(0), list) - assert len(dataset.get_cat_ids(0)) == 1 - assert isinstance(dataset.get_cat_ids(0)[0], int) - - # test with recursion_subdir is False - dataset_cfg = base_dataset_cfg.copy() - dataset_cfg['recursion_subdir'] = False - dataset = ImageNet21k(**dataset_cfg) - assert len(dataset) == 2 - assert isinstance(dataset[0], dict) - - # test with load annotation from ann file - dataset_cfg = base_dataset_cfg.copy() - dataset_cfg['ann_file'] = 'tests/data/dataset/ann.txt' - dataset = ImageNet21k(**dataset_cfg) - assert len(dataset) == 3 - assert isinstance(dataset[0], dict) - - -def test_dataset_cub(): - tmp_ann_file = tempfile.NamedTemporaryFile() - tmp_image_class_labels_file = tempfile.NamedTemporaryFile() - tmp_train_test_split_file = tempfile.NamedTemporaryFile() - - with open(tmp_ann_file.name, 'w') as f: - f.write('1 1.txt \n2 2.txt \n') - with open(tmp_image_class_labels_file.name, 'w') as f: - f.write('1 1 \n2 2 \n') - with open(tmp_train_test_split_file.name, 'w') as f: - f.write('1 0 \n2 1 \n') - - # test in train mode - dataset = CUB( - data_prefix='', - pipeline=[], - test_mode=False, - ann_file=tmp_ann_file.name, - image_class_labels_file=tmp_image_class_labels_file.name, - train_test_split_file=tmp_train_test_split_file.name) + metric=['precision', 'recall', 'f1_score', 'support'], + metric_options={'average_mode': 'none'}) + self.assertEqual(eval_results['precision'].shape, (3, )) + self.assertEqual(eval_results['recall'].shape, (3, )) + self.assertEqual(eval_results['f1_score'].shape, (3, )) + self.assertEqual(eval_results['support'].shape, (3, )) + + # the average_mode method must be valid + with self.assertRaises(ValueError): + dataset.evaluate( + fake_results, + metric=['precision', 'recall', 'f1_score', 'support'], + metric_options={'average_mode': 'micro'}) + + # the metric must be valid for the dataset + with 
self.assertRaisesRegex(ValueError, + "{'unknown'} is not supported"): + dataset.evaluate(fake_results, metric='unknown') + + +class TestMultiLabelDataset(TestBaseDataset): + DATASET_TYPE = 'MultiLabelDataset' + + def test_get_cat_ids(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + fake_ann = [ + dict( + img_prefix='', + img_info=dict(), + gt_label=np.array([0, 1, 1, 0], dtype=np.uint8)) + ] + + with patch.object(dataset_class, 'load_annotations') as mock_load: + mock_load.return_value = fake_ann + dataset = dataset_class(**self.DEFAULT_ARGS) + + cat_ids = dataset.get_cat_ids(0) + self.assertIsInstance(cat_ids, list) + self.assertEqual(len(cat_ids), 2) + self.assertIsInstance(cat_ids[0], int) + self.assertEqual(cat_ids, [1, 2]) + + def test_evaluate(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + + fake_ann = [ + dict(gt_label=np.array([1, 1, 0, -1], dtype=np.int8)), + dict(gt_label=np.array([1, 1, 0, -1], dtype=np.int8)), + dict(gt_label=np.array([0, -1, 1, -1], dtype=np.int8)), + dict(gt_label=np.array([0, 1, 0, -1], dtype=np.int8)), + dict(gt_label=np.array([0, 1, 0, -1], dtype=np.int8)), + ] + + with patch.object(dataset_class, 'load_annotations') as mock_load: + mock_load.return_value = fake_ann + dataset = dataset_class(**self.DEFAULT_ARGS) + + fake_results = np.array([ + [0.9, 0.8, 0.3, 0.2], + [0.1, 0.2, 0.2, 0.1], + [0.7, 0.5, 0.9, 0.3], + [0.8, 0.1, 0.1, 0.2], + [0.8, 0.1, 0.1, 0.2], + ]) + + # the metric must be valid for the dataset + with self.assertRaisesRegex(ValueError, + "{'unknown'} is not supported"): + dataset.evaluate(fake_results, metric='unknown') + + # only one metric + eval_results = dataset.evaluate(fake_results, metric='mAP') + self.assertEqual(eval_results.keys(), {'mAP'}) + self.assertAlmostEqual(eval_results['mAP'], 67.5, places=4) + + # multiple metrics + eval_results = dataset.evaluate( + fake_results, metric=['mAP', 'CR', 'OF1']) + self.assertEqual(eval_results.keys(), {'mAP', 'CR', 'OF1'}) + self.assertAlmostEqual(eval_results['mAP'], 67.50, places=2) + self.assertAlmostEqual(eval_results['CR'], 43.75, places=2) + self.assertAlmostEqual(eval_results['OF1'], 42.86, places=2) + + +class TestCustomDataset(TestBaseDataset): + DATASET_TYPE = 'CustomDataset' + + def test_load_annotations(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + + # test load without ann_file + cfg = { + **self.DEFAULT_ARGS, + 'data_prefix': ASSETS_ROOT, + 'ann_file': None, + } + dataset = dataset_class(**cfg) + self.assertEqual(len(dataset), 3) + self.assertEqual(dataset.CLASSES, ['a', 'b']) # auto infer classes + self.assertEqual( + dataset.data_infos[0], { + 'img_prefix': ASSETS_ROOT, + 'img_info': { + 'filename': 'a/1.JPG' + }, + 'gt_label': np.array(0) + }) + self.assertEqual( + dataset.data_infos[2], { + 'img_prefix': ASSETS_ROOT, + 'img_info': { + 'filename': 'b/subb/3.jpg' + }, + 'gt_label': np.array(1) + }) - assert len(dataset) == 1 + # test ann_file assertion + cfg = { + **self.DEFAULT_ARGS, + 'data_prefix': ASSETS_ROOT, + 'ann_file': ['ann_file.txt'], + } + with self.assertRaisesRegex(TypeError, 'must be a str'): + dataset_class(**cfg) + + # test load with ann_file + cfg = { + **self.DEFAULT_ARGS, + 'data_prefix': ASSETS_ROOT, + 'ann_file': osp.join(ASSETS_ROOT, 'ann.txt'), + } + dataset = dataset_class(**cfg) + self.assertEqual(len(dataset), 3) + # custom dataset won't infer CLASSES from ann_file + self.assertEqual(dataset.CLASSES, dataset_class.CLASSES) + self.assertEqual( + dataset.data_infos[0], { + 'img_prefix': ASSETS_ROOT, + 'img_info': { 
+ 'filename': 'a/1.JPG' + }, + 'gt_label': np.array(0) + }) + self.assertEqual( + dataset.data_infos[2], { + 'img_prefix': ASSETS_ROOT, + 'img_info': { + 'filename': 'b/subb/2.jpeg' + }, + 'gt_label': np.array(1) + }) + + # test extensions filter + cfg = { + **self.DEFAULT_ARGS, 'data_prefix': ASSETS_ROOT, + 'ann_file': None, + 'extensions': ('.txt', ) + } + with self.assertRaisesRegex(RuntimeError, + 'Supported extensions are: .txt'): + dataset_class(**cfg) + + cfg = { + **self.DEFAULT_ARGS, 'data_prefix': ASSETS_ROOT, + 'ann_file': None, + 'extensions': ('.jpeg', ) + } + with self.assertWarnsRegex(UserWarning, + 'Supported extensions are: .jpeg'): + dataset = dataset_class(**cfg) + self.assertEqual(len(dataset), 1) + self.assertEqual( + dataset.data_infos[0], { + 'img_prefix': ASSETS_ROOT, + 'img_info': { + 'filename': 'b/2.jpeg' + }, + 'gt_label': np.array(1) + }) - # test in test mode - dataset = CUB( - data_prefix='', + # test classes check + cfg = { + **self.DEFAULT_ARGS, + 'data_prefix': ASSETS_ROOT, + 'classes': ['apple', 'banana'], + 'ann_file': None, + } + dataset = dataset_class(**cfg) + self.assertEqual(dataset.CLASSES, ['apple', 'banana']) + + cfg['classes'] = ['apple', 'banana', 'dog'] + with self.assertRaisesRegex(AssertionError, + r"\(2\) doesn't match .* classes \(3\)"): + dataset_class(**cfg) + + +class TestImageNet(TestBaseDataset): + DATASET_TYPE = 'ImageNet' + + def test_load_annotations(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + + # test classes number + cfg = { + **self.DEFAULT_ARGS, + 'data_prefix': ASSETS_ROOT, + 'ann_file': None, + } + with self.assertRaisesRegex( + AssertionError, r"\(2\) doesn't match .* classes \(1000\)"): + dataset_class(**cfg) + + # test override classes + cfg = { + **self.DEFAULT_ARGS, + 'data_prefix': ASSETS_ROOT, + 'classes': ['cat', 'dog'], + 'ann_file': None, + } + dataset = dataset_class(**cfg) + self.assertEqual(len(dataset), 3) + self.assertEqual(dataset.CLASSES, ['cat', 'dog']) + + +class TestImageNet21k(TestBaseDataset): + DATASET_TYPE = 'ImageNet21k' + + DEFAULT_ARGS = dict( + data_prefix=ASSETS_ROOT, pipeline=[], - test_mode=True, - ann_file=tmp_ann_file.name, - image_class_labels_file=tmp_image_class_labels_file.name, - train_test_split_file=tmp_train_test_split_file.name) - - assert len(dataset) == 1 - - # test with different items in three files - with open(tmp_train_test_split_file.name, 'w') as f: - f.write('1 0 \n') - with pytest.raises(AssertionError, match='should have same length'): - dataset = CUB( - data_prefix='', + classes=['cat', 'dog'], + ann_file=osp.join(ASSETS_ROOT, 'ann.txt'), + serialize_data=False) + + def test_initialize(self): + super().test_initialize() + dataset_class = DATASETS.get(self.DATASET_TYPE) + + # The multi_label option is not implemented not. 
+ cfg = {**self.DEFAULT_ARGS, 'multi_label': True} + with self.assertRaisesRegex(NotImplementedError, 'not supported'): + dataset_class(**cfg) + + # Warn about ann_file + cfg = {**self.DEFAULT_ARGS, 'ann_file': None} + with self.assertWarnsRegex(UserWarning, 'specify the `ann_file`'): + dataset_class(**cfg) + + # Warn about classes + cfg = {**self.DEFAULT_ARGS, 'classes': None} + with self.assertWarnsRegex(UserWarning, 'specify the `classes`'): + dataset_class(**cfg) + + def test_load_annotations(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + + # Test with serialize_data=False + cfg = {**self.DEFAULT_ARGS, 'serialize_data': False} + dataset = dataset_class(**cfg) + self.assertEqual(len(dataset.data_infos), 3) + self.assertEqual(len(dataset), 3) + self.assertEqual( + dataset[0], { + 'img_prefix': ASSETS_ROOT, + 'img_info': { + 'filename': 'a/1.JPG' + }, + 'gt_label': np.array(0) + }) + self.assertEqual( + dataset[2], { + 'img_prefix': ASSETS_ROOT, + 'img_info': { + 'filename': 'b/subb/2.jpeg' + }, + 'gt_label': np.array(1) + }) + + # Test with serialize_data=True + cfg = {**self.DEFAULT_ARGS, 'serialize_data': True} + dataset = dataset_class(**cfg) + self.assertEqual(len(dataset.data_infos), 0) # data_infos is clear. + self.assertEqual(len(dataset), 3) + self.assertEqual( + dataset[0], { + 'img_prefix': ASSETS_ROOT, + 'img_info': { + 'filename': 'a/1.JPG' + }, + 'gt_label': np.array(0) + }) + self.assertEqual( + dataset[2], { + 'img_prefix': ASSETS_ROOT, + 'img_info': { + 'filename': 'b/subb/2.jpeg' + }, + 'gt_label': np.array(1) + }) + + +class TestMNIST(TestBaseDataset): + DATASET_TYPE = 'MNIST' + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + + tmpdir = tempfile.TemporaryDirectory() + cls.tmpdir = tmpdir + data_prefix = tmpdir.name + cls.DEFAULT_ARGS = dict(data_prefix=data_prefix, pipeline=[]) + + dataset_class = DATASETS.get(cls.DATASET_TYPE) + + def rm_suffix(s): + return s[:s.rfind('.')] + + train_image_file = osp.join( + data_prefix, + rm_suffix(dataset_class.resources['train_image_file'][0])) + train_label_file = osp.join( + data_prefix, + rm_suffix(dataset_class.resources['train_label_file'][0])) + test_image_file = osp.join( + data_prefix, + rm_suffix(dataset_class.resources['test_image_file'][0])) + test_label_file = osp.join( + data_prefix, + rm_suffix(dataset_class.resources['test_label_file'][0])) + cls.fake_img = np.random.randint(0, 255, size=(28, 28), dtype=np.uint8) + cls.fake_label = np.random.randint(0, 10, size=(1, ), dtype=np.uint8) + + for file in [train_image_file, test_image_file]: + magic = b'\x00\x00\x08\x03' # num_dims = 3, type = uint8 + head = b'\x00\x00\x00\x01' + b'\x00\x00\x00\x1c' * 2 # (1, 28, 28) + data = magic + head + cls.fake_img.flatten().tobytes() + with open(file, 'wb') as f: + f.write(data) + + for file in [train_label_file, test_label_file]: + magic = b'\x00\x00\x08\x01' # num_dims = 3, type = uint8 + head = b'\x00\x00\x00\x01' # (1, ) + data = magic + head + cls.fake_label.tobytes() + with open(file, 'wb') as f: + f.write(data) + + def test_load_annotations(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + + with patch.object(dataset_class, 'download'): + # Test default behavior + dataset = dataset_class(**self.DEFAULT_ARGS) + self.assertEqual(len(dataset), 1) + + data_info = dataset[0] + np.testing.assert_equal(data_info['img'], self.fake_img) + np.testing.assert_equal(data_info['gt_label'], self.fake_label) + + @classmethod + def tearDownClass(cls): + cls.tmpdir.cleanup() + + +class 
TestCIFAR10(TestBaseDataset): + DATASET_TYPE = 'CIFAR10' + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + + tmpdir = tempfile.TemporaryDirectory() + cls.tmpdir = tmpdir + data_prefix = tmpdir.name + cls.DEFAULT_ARGS = dict(data_prefix=data_prefix, pipeline=[]) + + dataset_class = DATASETS.get(cls.DATASET_TYPE) + base_folder = osp.join(data_prefix, dataset_class.base_folder) + os.mkdir(base_folder) + + cls.fake_imgs = np.random.randint( + 0, 255, size=(6, 3 * 32 * 32), dtype=np.uint8) + cls.fake_labels = np.random.randint(0, 10, size=(6, )) + cls.fake_classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + + batch1 = dict( + data=cls.fake_imgs[:2], labels=cls.fake_labels[:2].tolist()) + with open(osp.join(base_folder, 'data_batch_1'), 'wb') as f: + f.write(pickle.dumps(batch1)) + + batch2 = dict( + data=cls.fake_imgs[2:4], labels=cls.fake_labels[2:4].tolist()) + with open(osp.join(base_folder, 'data_batch_2'), 'wb') as f: + f.write(pickle.dumps(batch2)) + + test_batch = dict( + data=cls.fake_imgs[4:], labels=cls.fake_labels[4:].tolist()) + with open(osp.join(base_folder, 'test_batch'), 'wb') as f: + f.write(pickle.dumps(test_batch)) + + meta = {dataset_class.meta['key']: cls.fake_classes} + meta_filename = dataset_class.meta['filename'] + with open(osp.join(base_folder, meta_filename), 'wb') as f: + f.write(pickle.dumps(meta)) + + dataset_class.train_list = [['data_batch_1', None], + ['data_batch_2', None]] + dataset_class.test_list = [['test_batch', None]] + dataset_class.meta['md5'] = None + + def test_load_annotations(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + + # Test default behavior + dataset = dataset_class(**self.DEFAULT_ARGS) + self.assertEqual(len(dataset), 4) + self.assertEqual(dataset.CLASSES, self.fake_classes) + + data_info = dataset[0] + fake_img = self.fake_imgs[0].reshape(3, 32, 32).transpose(1, 2, 0) + np.testing.assert_equal(data_info['img'], fake_img) + np.testing.assert_equal(data_info['gt_label'], self.fake_labels[0]) + + # Test with test_mode=True + cfg = {**self.DEFAULT_ARGS, 'test_mode': True} + dataset = dataset_class(**cfg) + self.assertEqual(len(dataset), 2) + + data_info = dataset[0] + fake_img = self.fake_imgs[4].reshape(3, 32, 32).transpose(1, 2, 0) + np.testing.assert_equal(data_info['img'], fake_img) + np.testing.assert_equal(data_info['gt_label'], self.fake_labels[4]) + + @classmethod + def tearDownClass(cls): + cls.tmpdir.cleanup() + + +class TestCIFAR100(TestCIFAR10): + DATASET_TYPE = 'CIFAR100' + + +class TestVOC(TestMultiLabelDataset): + DATASET_TYPE = 'VOC' + + DEFAULT_ARGS = dict(data_prefix='VOC2007', pipeline=[]) + + +class TestCUB(TestBaseDataset): + DATASET_TYPE = 'CUB' + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + + tmpdir = tempfile.TemporaryDirectory() + cls.tmpdir = tmpdir + cls.data_prefix = tmpdir.name + cls.ann_file = osp.join(cls.data_prefix, 'ann_file.txt') + cls.image_class_labels_file = osp.join(cls.data_prefix, 'classes.txt') + cls.train_test_split_file = osp.join(cls.data_prefix, 'split.txt') + cls.train_test_split_file2 = osp.join(cls.data_prefix, 'split2.txt') + cls.DEFAULT_ARGS = dict( + data_prefix=cls.data_prefix, pipeline=[], - test_mode=True, - ann_file=tmp_ann_file.name, - image_class_labels_file=tmp_image_class_labels_file.name, - train_test_split_file=tmp_train_test_split_file.name) - - tmp_ann_file.close() - tmp_image_class_labels_file.close() - tmp_train_test_split_file.close() + ann_file=cls.ann_file, + image_class_labels_file=cls.image_class_labels_file, + 
train_test_split_file=cls.train_test_split_file) + + with open(cls.ann_file, 'w') as f: + f.write('\n'.join([ + '1 1.txt', + '2 2.txt', + '3 3.txt', + ])) + + with open(cls.image_class_labels_file, 'w') as f: + f.write('\n'.join([ + '1 2', + '2 3', + '3 1', + ])) + + with open(cls.train_test_split_file, 'w') as f: + f.write('\n'.join([ + '1 0', + '2 1', + '3 1', + ])) + + with open(cls.train_test_split_file2, 'w') as f: + f.write('\n'.join([ + '1 0', + '2 1', + ])) + + def test_load_annotations(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + + # Test default behavior + dataset = dataset_class(**self.DEFAULT_ARGS) + self.assertEqual(len(dataset), 2) + self.assertEqual(dataset.CLASSES, dataset_class.CLASSES) + + data_info = dataset[0] + np.testing.assert_equal(data_info['img_prefix'], self.data_prefix) + np.testing.assert_equal(data_info['img_info'], {'filename': '2.txt'}) + np.testing.assert_equal(data_info['gt_label'], 3 - 1) + + # Test with test_mode=True + cfg = {**self.DEFAULT_ARGS, 'test_mode': True} + dataset = dataset_class(**cfg) + self.assertEqual(len(dataset), 1) + + data_info = dataset[0] + np.testing.assert_equal(data_info['img_prefix'], self.data_prefix) + np.testing.assert_equal(data_info['img_info'], {'filename': '1.txt'}) + np.testing.assert_equal(data_info['gt_label'], 2 - 1) + + # Test if the numbers of line are not match + cfg = { + **self.DEFAULT_ARGS, 'train_test_split_file': + self.train_test_split_file2 + } + with self.assertRaisesRegex(AssertionError, 'should have same length'): + dataset_class(**cfg) + + @classmethod + def tearDownClass(cls): + cls.tmpdir.cleanup()
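
As a quick reference for the patch above, here is a minimal usage sketch of the new ``CustomDataset`` (the directory layout, file names and annotation path are hypothetical, and the empty pipeline is only to keep the example short): ::

    from mmcls.datasets import CustomDataset, build_dataset

    # 1. Sub-folder layout: every directory under `data_prefix` is treated as
    #    a class, and class names are inferred from the folder names when
    #    `classes` is None (hypothetical layout):
    #
    #    data/my_dataset/
    #    ├── cat/1.jpg ...
    #    └── dog/2.jpg ...
    dataset = CustomDataset(data_prefix='data/my_dataset', pipeline=[])
    print(dataset.CLASSES)  # ['cat', 'dog'], inferred from the sub-folders
    print(dataset[0])       # dict with 'img_prefix', 'img_info' and 'gt_label'

    # 2. Annotation-file layout: each line is "<relative path> <class index>",
    #    and class names should be given explicitly through `classes`.
    dataset = CustomDataset(
        data_prefix='data/my_dataset',
        ann_file='data/my_dataset/ann.txt',  # hypothetical annotation file
        classes=['cat', 'dog'],
        pipeline=[])

    # The class is registered in DATASETS, so the usual config style works too.
    dataset = build_dataset(
        dict(type='CustomDataset', data_prefix='data/my_dataset', pipeline=[]))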
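The ``serialize_data`` path added to ``ImageNet21k`` can be hard to read in diff form; the following self-contained sketch (an illustration of the same idea, not the library code) condenses what ``_serialize_data`` and ``get_data_info`` do — annotations are pickled into one flat ``uint8`` buffer plus a cumulative address array, so forked dataloader workers share a single buffer instead of each copying a large list of Python dicts: ::

    import pickle

    import numpy as np

    # A toy list of annotations standing in for `self.data_infos`.
    data_infos = [{'img_info': {'filename': f'{i}.jpg'}, 'gt_label': i}
                  for i in range(3)]

    # Serialize each annotation and pack everything into one contiguous array,
    # remembering the end offset of every record.
    chunks = [np.frombuffer(pickle.dumps(info, protocol=4), dtype=np.uint8)
              for info in data_infos]
    addresses = np.cumsum([len(c) for c in chunks]).astype(np.int64)
    buffer = np.concatenate(chunks)

    def get_data_info(idx):
        # Slice the shared buffer and unpickle a single record on demand.
        start = 0 if idx == 0 else addresses[idx - 1].item()
        end = addresses[idx].item()
        return pickle.loads(memoryview(buffer[start:end]))

    assert get_data_info(2) == data_infos[2]
    assert len(addresses) == len(data_infos)  # __len__ uses the address array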