Skip to content

Commit

Permalink
Support hierarchical structure for ImageNet format (#1562)
Browse files Browse the repository at this point in the history
Duplicate of #1528 but with `releases/1.8.0` base.

---------

Signed-off-by: Ilya Trushkin <[email protected]>
  • Loading branch information
itrushkin authored Jul 10, 2024
1 parent 4ae4326 commit b2b8533
Show file tree
Hide file tree
Showing 18 changed files with 135 additions and 106 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/openvinotoolkit/datumaro/pull/1492>)
- Pass Keyword Argument to TabularDataBase
(<https://github.com/openvinotoolkit/datumaro/pull/1522>)
- Support hierarchical structure for ImageNet dataset format
(<https://github.com/openvinotoolkit/datumaro/pull/1528>)

### Bug fixes
- Preserve end_frame information of a video when it is zero.
Expand Down
2 changes: 1 addition & 1 deletion src/datumaro/cli/commands/downloaders/kaggle.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def make_all_paths_absolute(d: Dict, root: str = "."):


KAGGLE_API_KEY_EXISTS = bool(os.environ.get("KAGGLE_KEY")) or os.path.exists(
os.path.join(os.path.expanduser("~"), ".kaggle")
os.path.join(os.path.expanduser("~"), ".kaggle", "kaggle.json")
)


Expand Down
4 changes: 2 additions & 2 deletions src/datumaro/components/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from datumaro.components.errors import DatasetImportError, DatasetNotFoundError
from datumaro.components.format_detection import FormatDetectionConfidence, FormatDetectionContext
from datumaro.components.merge.extractor_merger import ExtractorMerger
from datumaro.util.definitions import SUBSET_NAME_BLACKLIST
from datumaro.util.definitions import SUBSET_NAME_WHITELIST

T = TypeVar("T")

Expand Down Expand Up @@ -197,7 +197,7 @@ def _change_context_root_path(context: FormatDetectionContext, path: str):
)

for sub_dir in os.listdir(path):
if sub_dir.lower() in SUBSET_NAME_BLACKLIST:
if sub_dir.lower() not in SUBSET_NAME_WHITELIST:
continue

sub_path = osp.join(path, sub_dir)
Expand Down
28 changes: 20 additions & 8 deletions src/datumaro/plugins/data_formats/image_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

import logging as log
import os
import os.path as osp
from pathlib import Path
from typing import List, Optional

from datumaro.components.dataset_base import DatasetItem, SubsetBase
from datumaro.components.exporter import Exporter
from datumaro.components.format_detection import FormatDetectionConfidence
from datumaro.components.format_detection import FormatDetectionConfidence, FormatDetectionContext
from datumaro.components.importer import ImportContext, Importer
from datumaro.components.media import Image
from datumaro.util.image import IMAGE_EXTENSIONS, find_images
Expand All @@ -31,11 +31,23 @@ def build_cmdline_parser(cls, **kwargs):
)
return parser

@classmethod
def detect(cls, context: FormatDetectionContext) -> FormatDetectionConfidence:
path = Path(context.root_path)
for item in path.iterdir():
if item.is_dir():
context.fail("Only flat image directories are supported")
elif item.suffix.lower() not in IMAGE_EXTENSIONS:
context.fail(f"File {item} is not an image.")
return super().detect(context)

@classmethod
def find_sources(cls, path):
if not osp.isdir(path):
path = Path(path)
if not path.is_dir():
return []
return [{"url": path, "format": ImageDirBase.NAME}]

return [{"url": str(path), "format": ImageDirBase.NAME}]

@classmethod
def get_file_extensions(cls) -> List[str]:
Expand All @@ -51,11 +63,11 @@ def __init__(
ctx: Optional[ImportContext] = None,
):
super().__init__(subset=subset, ctx=ctx)
url = Path(url)
assert url.is_dir(), url

assert osp.isdir(url), url

for path in find_images(url, recursive=True):
item_id = osp.relpath(osp.splitext(path)[0], url)
for path in find_images(str(url)):
item_id = Path(path).stem
self._items.append(
DatasetItem(id=item_id, subset=self._subset, media=Image.from_file(path=path))
)
Expand Down
152 changes: 90 additions & 62 deletions src/datumaro/plugins/data_formats/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
import errno
import logging as log
import os
import os.path as osp
import warnings
from typing import List, Optional
from pathlib import Path
from typing import List

from datumaro.components.annotation import AnnotationType, Label, LabelCategories
from datumaro.components.dataset_base import DatasetItem, SubsetBase
Expand All @@ -16,8 +15,9 @@
from datumaro.components.format_detection import FormatDetectionConfidence, FormatDetectionContext
from datumaro.components.importer import ImportContext, Importer, with_subset_dirs
from datumaro.components.media import Image
from datumaro.util.definitions import SUBSET_NAME_BLACKLIST
from datumaro.util.definitions import SUBSET_NAME_BLACKLIST, SUBSET_NAME_WHITELIST
from datumaro.util.image import IMAGE_EXTENSIONS, find_images
from datumaro.util.os_util import walk


class ImagenetPath:
Expand All @@ -30,40 +30,39 @@ def __init__(
self,
path: str,
*,
subset: Optional[str] = None,
ctx: Optional[ImportContext] = None,
subset: str | None = None,
ctx: ImportContext | None = None,
min_depth: int | None = None,
max_depth: int | None = None,
):
if not osp.isdir(path):
if not Path(path).is_dir():
raise NotADirectoryError(errno.ENOTDIR, "Can't find dataset directory", path)

super().__init__(subset=subset, ctx=ctx)

self._max_depth = min_depth
self._min_depth = max_depth
self._categories = self._load_categories(path)
self._items = list(self._load_items(path).values())

def _load_categories(self, path):
label_cat = LabelCategories()
for dirname in sorted(os.listdir(path)):
if not os.path.isdir(os.path.join(path, dirname)):
warnings.warn(
f"{dirname} is not a directory in the folder {path}, so this will"
"be skipped when declaring the cateogries of `imagenet` dataset."
)
continue
if dirname != ImagenetPath.IMAGE_DIR_NO_LABEL:
label_cat.add(dirname)
path = Path(path)
for dirname in sorted(d for d in path.rglob("*") if d.is_dir()):
dirname = dirname.relative_to(path)
if str(dirname) != ImagenetPath.IMAGE_DIR_NO_LABEL:
label_cat.add(str(dirname))
return {AnnotationType.label: label_cat}

def _load_items(self, path):
items = {}

# Images should be in root/label_dir/*.img and root/*.img is not allowed.
# => max_depth=1, min_depth=1
for image_path in find_images(path, recursive=True, max_depth=1, min_depth=1):
label = osp.basename(osp.dirname(image_path))
image_name = osp.splitext(osp.basename(image_path))[0]

item_id = label + ImagenetPath.SEP_TOKEN + image_name
for image_path in find_images(
path, recursive=True, max_depth=self._max_depth, min_depth=self._min_depth
):
label = str(Path(image_path).parent.relative_to(path))
if label == ".": # image is located in the root directory
label = ImagenetPath.IMAGE_DIR_NO_LABEL
image_name = Path(image_path).stem
item_id = str(label) + ImagenetPath.SEP_TOKEN + image_name
item = items.get(item_id)
try:
if item is None:
Expand All @@ -89,45 +88,70 @@ def _load_items(self, path):


class ImagenetImporter(Importer):
"""TorchVision's ImageFolder style importer.
For example, it imports the following directory structure.
"""
Multi-level version of ImagenetImporter.
For example, it imports the following directory structure.
.. code-block:: text
root
├── label_0
│ ├── label_0_1.jpg
│ └── label_0_2.jpg
│ ├── label_0_1
│ │ └── img1.jpg
│ └── label_0_2
│ └── img2.jpg
└── label_1
└── label_1_1.jpg
└── img3.jpg
"""

_MIN_DEPTH = None
_MAX_DEPTH = None
_FORMAT = ImagenetBase.NAME
DETECT_CONFIDENCE = FormatDetectionConfidence.EXTREME_LOW

@classmethod
def detect(cls, context: FormatDetectionContext) -> FormatDetectionConfidence:
# Images must not be under a directory whose name is blacklisted.
for dname in os.listdir(context.root_path):
dpath = osp.join(context.root_path, dname)
if osp.isdir(dpath) and dname.lower() in SUBSET_NAME_BLACKLIST:
for dname, dirnames, filenames in os.walk(context.root_path):
if dname in SUBSET_NAME_WHITELIST:
context.fail(
f"{dname} is found in {context.root_path}. "
"However, Images must not be under a directory whose name is blacklisted "
f"(SUBSET_NAME_BLACKLIST={SUBSET_NAME_BLACKLIST})."
f"Following directory names are not permitted: {SUBSET_NAME_WHITELIST}"
)
rel_dname = Path(dname).relative_to(context.root_path)
level = len(rel_dname.parts)
if cls._MIN_DEPTH is not None and level < cls._MIN_DEPTH and filenames:
context.fail("Found files out of the directory level bounds.")
if cls._MAX_DEPTH is not None and level > cls._MAX_DEPTH and filenames:
context.fail("Found files out of the directory level bounds.")
dpath = Path(context.root_path) / rel_dname
if dpath.is_dir():
if str(rel_dname).lower() in SUBSET_NAME_BLACKLIST:
context.fail(
f"{dname} is found in {context.root_path}. "
"However, Images must not be under a directory whose name is blacklisted "
f"(SUBSET_NAME_BLACKLIST={SUBSET_NAME_BLACKLIST})."
)

return super().detect(context)

@classmethod
def contains_only_images(cls, path: str | Path):
for _, dirnames, filenames in walk(path, cls._MAX_DEPTH, cls._MIN_DEPTH):
if filenames:
for filename in filenames:
if Path(filename).suffix.lower() not in IMAGE_EXTENSIONS:
return False
elif not dirnames:
return False
return True

@classmethod
def find_sources(cls, path):
if not osp.isdir(path):
if not Path(path).is_dir():
return []

# Images should be in root/label_dir/*.img and root/*.img is not allowed.
# => max_depth=1, min_depth=1
for _ in find_images(path, recursive=True, max_depth=1, min_depth=1):
return [{"url": path, "format": ImagenetBase.NAME}]

return []
return [{"url": path, "format": cls._FORMAT}] if cls.contains_only_images(path) else []

@classmethod
def get_file_extensions(cls) -> List[str]:
Expand All @@ -144,32 +168,36 @@ def build_cmdline_parser(cls, **kwargs):

@with_subset_dirs
class ImagenetWithSubsetDirsImporter(ImagenetImporter):
"""TorchVision ImageFolder style importer.
For example, it imports the following directory structure.
"""Multi-level image directory structure importer.
Example:
.. code-block::
root
├── train
│ ├── label_0
│ │ ├── label_0_1.jpg
│ │ └── label_0_2.jpg
│ │ ├── label_0_1
│ │ │ └── img1.jpg
│ │ └── label_0_2
│ │ └── img2.jpg
│ └── label_1
│ └── label_1_1.jpg
│ └── img3.jpg
├── val
│ ├── label_0
│ │ ├── label_0_1.jpg
│ │ └── label_0_2.jpg
│ │ ├── label_0_1
│ │ │ └── img1.jpg
│ │ └── label_0_2
│ │ └── img2.jpg
│ └── label_1
│ └── label_1_1.jpg
│ └── img3.jpg
└── test
├── label_0
│ ├── label_0_1.jpg
│ └── label_0_2.jpg
│ ├── label_0
│ ├── label_0_1
│ │ └── img1.jpg
│ └── label_0_2
│ └── img2.jpg
└── label_1
└── label_1_1.jpg
Then, it will have three subsets: train, val, and test and they have label_0 and label_1 labels.
└── img3.jpg
"""


Expand Down Expand Up @@ -199,7 +227,7 @@ def _get_name(item: DatasetItem) -> str:
'For example, dataset.export("<path/to/output>", format="imagenet_with_subset_dirs").'
)

root_dir = self._save_dir
root_dir = Path(self._save_dir)
extractor = self._extractor
labels = {}
for item in self._extractor:
Expand All @@ -210,18 +238,18 @@ def _get_name(item: DatasetItem) -> str:
label_name = extractor.categories()[AnnotationType.label][label].name
self._save_image(
item,
subdir=osp.join(root_dir, item.subset, label_name)
subdir=root_dir / item.subset / label_name
if self.USE_SUBSET_DIRS
else osp.join(root_dir, label_name),
else root_dir / label_name,
name=file_name,
)

if not labels:
self._save_image(
item,
subdir=osp.join(root_dir, item.subset, ImagenetPath.IMAGE_DIR_NO_LABEL)
subdir=root_dir / item.subset / ImagenetPath.IMAGE_DIR_NO_LABEL
if self.USE_SUBSET_DIRS
else osp.join(root_dir, ImagenetPath.IMAGE_DIR_NO_LABEL),
else root_dir / ImagenetPath.IMAGE_DIR_NO_LABEL,
name=file_name,
)

Expand Down
6 changes: 3 additions & 3 deletions src/datumaro/plugins/data_formats/yolo/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from datumaro.components.exporter import Exporter
from datumaro.components.media import Image
from datumaro.util import str_to_bool
from datumaro.util.definitions import SUBSET_NAME_WHITELIST

from .format import YoloPath

Expand Down Expand Up @@ -195,7 +196,6 @@ def can_stream(self) -> bool:


class YoloUltralyticsExporter(YoloExporter):
allowed_subset_names = {"train", "val", "test"}
must_subset_names = {"train", "val"}

def __init__(self, extractor: IDataset, save_dir: str, **kwargs) -> None:
Expand All @@ -214,9 +214,9 @@ def _check_dataset(self):
subset_names = set(self._extractor.subsets().keys())

for subset in subset_names:
if subset not in self.allowed_subset_names:
if subset not in SUBSET_NAME_WHITELIST:
raise DatasetExportError(
f"The allowed subset name is in {self.allowed_subset_names}, "
f"The allowed subset name should be in {SUBSET_NAME_WHITELIST}, "
f'so that subset "{subset}" is not allowed.'
)

Expand Down
Loading

0 comments on commit b2b8533

Please sign in to comment.