-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Add CustomDataset
#738
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,229 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import warnings | ||
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union | ||
|
||
import mmcv | ||
import numpy as np | ||
from mmcv import FileClient | ||
|
||
from .base_dataset import BaseDataset | ||
from .builder import DATASETS | ||
|
||
|
||
def find_folders(root: str, | ||
file_client: FileClient) -> Tuple[List[str], Dict[str, int]]: | ||
"""Find classes by folders under a root. | ||
|
||
Args: | ||
root (string): root directory of folders | ||
|
||
Returns: | ||
Tuple[List[str], Dict[str, int]]: | ||
|
||
- folders: The name of sub folders under the root. | ||
- folder_to_idx: The map from folder name to class idx. | ||
""" | ||
folders = list( | ||
file_client.list_dir_or_file( | ||
root, | ||
list_dir=True, | ||
list_file=False, | ||
recursive=False, | ||
)) | ||
folders.sort() | ||
folder_to_idx = {folders[i]: i for i in range(len(folders))} | ||
return folders, folder_to_idx | ||
|
||
|
||
def get_samples(root: str, folder_to_idx: Dict[str, int], | ||
is_valid_file: Callable, file_client: FileClient): | ||
"""Make dataset by walking all images under a root. | ||
|
||
Args: | ||
root (string): root directory of folders | ||
folder_to_idx (dict): the map from class name to class idx | ||
is_valid_file (Callable): A function that takes path of a file | ||
and check if the file is a valid sample file. | ||
|
||
Returns: | ||
Tuple[list, set]: | ||
|
||
- samples: a list of tuple where each element is (image, class_idx) | ||
- empty_folders: The folders don't have any valid files. | ||
""" | ||
samples = [] | ||
available_classes = set() | ||
|
||
for folder_name in sorted(list(folder_to_idx.keys())): | ||
_dir = file_client.join_path(root, folder_name) | ||
files = list( | ||
file_client.list_dir_or_file( | ||
_dir, | ||
list_dir=False, | ||
list_file=True, | ||
recursive=True, | ||
)) | ||
for file in sorted(list(files)): | ||
if is_valid_file(file): | ||
path = file_client.join_path(folder_name, file) | ||
item = (path, folder_to_idx[folder_name]) | ||
samples.append(item) | ||
available_classes.add(folder_name) | ||
|
||
empty_folders = set(folder_to_idx.keys()) - available_classes | ||
|
||
return samples, empty_folders | ||
|
||
|
||
@DATASETS.register_module() | ||
class CustomDataset(BaseDataset): | ||
"""Custom dataset for classification. | ||
|
||
The dataset supports two kinds of annotation format. | ||
|
||
1. An annotation file is provided, and each line indicates a sample: | ||
|
||
The sample files: :: | ||
|
||
data_prefix/ | ||
├── folder_1 | ||
│ ├── xxx.png | ||
│ ├── xxy.png | ||
│ └── ... | ||
└── folder_2 | ||
├── 123.png | ||
├── nsdf3.png | ||
└── ... | ||
|
||
The annotation file (the first column is the image path and the second | ||
column is the index of category): :: | ||
|
||
folder_1/xxx.png 0 | ||
folder_1/xxy.png 1 | ||
folder_2/123.png 5 | ||
mzr1996 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
folder_2/nsdf3.png 3 | ||
... | ||
|
||
Please specify the name of categories by the argument ``classes``. | ||
|
||
2. The samples are arranged in the specific way: :: | ||
|
||
data_prefix/ | ||
├── class_x | ||
│ ├── xxx.png | ||
│ ├── xxy.png | ||
│ └── ... | ||
│ └── xxz.png | ||
└── class_y | ||
├── 123.png | ||
├── nsdf3.png | ||
├── ... | ||
└── asd932_.png | ||
|
||
If the ``ann_file`` is specified, the dataset will be generated by the | ||
first way, otherwise, try the second way. | ||
|
||
Args: | ||
data_prefix (str): The path of data directory. | ||
pipeline (Sequence[dict]): A list of dict, where each element | ||
represents a operation defined in :mod:`mmcls.datasets.pipelines`. | ||
Defaults to an empty tuple. | ||
classes (str | Sequence[str], optional): Specify names of classes. | ||
|
||
- If is string, it should be a file path, and the every line of | ||
the file is a name of a class. | ||
- If is a sequence of string, every item is a name of class. | ||
- If is None, use ``cls.CLASSES`` or the names of sub folders | ||
(If use the second way to arrange samples). | ||
|
||
Defaults to None. | ||
ann_file (str, optional): The annotation file. If is string, read | ||
samples paths from the ann_file. If is None, find samples in | ||
``data_prefix``. Defaults to None. | ||
extensions (Sequence[str]): A sequence of allowed extensions. Defaults | ||
to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'). | ||
test_mode (bool): In train mode or test mode. It's only a mark and | ||
won't be used in this class. Defaults to False. | ||
file_client_args (dict, optional): Arguments to instantiate a | ||
FileClient. See :class:`mmcv.fileio.FileClient` for details. | ||
If None, automatically inference from the specified path. | ||
Defaults to None. | ||
""" | ||
|
||
def __init__(self, | ||
data_prefix: str, | ||
pipeline: Sequence = (), | ||
classes: Union[str, Sequence[str], None] = None, | ||
ann_file: Optional[str] = None, | ||
extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm', | ||
'.bmp', '.pgm', '.tif'), | ||
test_mode: bool = False, | ||
file_client_args: Optional[dict] = None): | ||
self.extensions = tuple(set([i.lower() for i in extensions])) | ||
self.file_client_args = file_client_args | ||
|
||
super().__init__( | ||
data_prefix=data_prefix, | ||
pipeline=pipeline, | ||
classes=classes, | ||
ann_file=ann_file, | ||
test_mode=test_mode) | ||
|
||
def _find_samples(self): | ||
"""find samples from ``data_prefix``.""" | ||
file_client = FileClient.infer_client(self.file_client_args, | ||
self.data_prefix) | ||
classes, folder_to_idx = find_folders(self.data_prefix, file_client) | ||
samples, empty_classes = get_samples( | ||
self.data_prefix, | ||
folder_to_idx, | ||
is_valid_file=self.is_valid_file, | ||
file_client=file_client, | ||
) | ||
|
||
if len(samples) == 0: | ||
raise RuntimeError( | ||
f'Found 0 files in subfolders of: {self.data_prefix}. ' | ||
f'Supported extensions are: {",".join(self.extensions)}') | ||
|
||
if self.CLASSES is not None: | ||
assert len(self.CLASSES) == len(classes), \ | ||
f"The number of subfolders ({len(classes)}) doesn't match " \ | ||
f'the number of specified classes ({len(self.CLASSES)}). ' \ | ||
'Please check the data folder.' | ||
else: | ||
self.CLASSES = classes | ||
|
||
if empty_classes: | ||
warnings.warn( | ||
'Found no valid file in the folder ' | ||
f'{", ".join(empty_classes)}. ' | ||
f"Supported extensions are: {', '.join(self.extensions)}", | ||
UserWarning) | ||
|
||
self.folder_to_idx = folder_to_idx | ||
|
||
return samples | ||
|
||
def load_annotations(self): | ||
"""Load image paths and gt_labels.""" | ||
if self.ann_file is None: | ||
samples = self._find_samples() | ||
elif isinstance(self.ann_file, str): | ||
lines = mmcv.list_from_file( | ||
self.ann_file, file_client_args=self.file_client_args) | ||
samples = [x.strip().rsplit(' ', 1) for x in lines] | ||
else: | ||
raise TypeError('ann_file must be a str or None') | ||
|
||
data_infos = [] | ||
for filename, gt_label in samples: | ||
info = {'img_prefix': self.data_prefix} | ||
info['img_info'] = {'filename': filename} | ||
info['gt_label'] = np.array(gt_label, dtype=np.int64) | ||
data_infos.append(info) | ||
return data_infos | ||
|
||
def is_valid_file(self, filename: str) -> bool: | ||
"""Check if a file is a valid sample.""" | ||
return filename.lower().endswith(self.extensions) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the meta info of CLASSES is in our code, it is so long, especially in 'imagenet.py'. It may be better to create a metafile.bin that saves all the CLASSES info. In that way, the code will be purer and users may read our code easily.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, I think so too, but move the categories info to another file may cause unexpected problem, especially for deployment.
The code of
ImageNet
is short, which is only aCustomDataset
with preset attributes. I think we can keep it.