diff --git a/CHANGELOG.md b/CHANGELOG.md index b138c11aa9..e25fe4dc15 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 () - Convert Cuboid2D annotation to/from 3D data () +- Add label groups for hierarchical classification in ImageNet + () ### Enhancements - Enhance 'id_from_image_name' transform to ensure each identifier is unique diff --git a/src/datumaro/plugins/data_formats/imagenet.py b/src/datumaro/plugins/data_formats/imagenet.py index 673cfd132d..30b9dfaa5a 100644 --- a/src/datumaro/plugins/data_formats/imagenet.py +++ b/src/datumaro/plugins/data_formats/imagenet.py @@ -48,8 +48,18 @@ def _load_categories(self, path): path = Path(path) for dirname in sorted(d for d in path.rglob("*") if d.is_dir()): dirname = dirname.relative_to(path) + level = len(dirname.parts) if str(dirname) != ImagenetPath.IMAGE_DIR_NO_LABEL: - label_cat.add(str(dirname)) + parent = None + if level > 1: + parent = str(dirname.parents[0]) + if not any([g.name == parent for g in label_cat.label_groups]): + label_cat.add_label_group(parent, [str(dirname.name)], group_type=0) + else: + g = next(x for x in label_cat.label_groups if x.name == parent) + g.labels.append(str(dirname.name)) + label_cat.add(str(dirname), parent) + return {AnnotationType.label: label_cat} def _load_items(self, path): diff --git a/tests/unit/test_imagenet_format.py b/tests/unit/test_imagenet_format.py index e84b9406ea..5bd8a1c9bb 100644 --- a/tests/unit/test_imagenet_format.py +++ b/tests/unit/test_imagenet_format.py @@ -182,6 +182,12 @@ class ImagenetImporterTest: IMPORTER_NAME = ImagenetImporter.NAME def _create_expected_dataset(self): + label_categories = LabelCategories.from_iterable( + ("label_0", "label_1", f"{Path('label_1', 'label_1_1')}") + ) + label_categories[-1].parent = "label_1" + label_categories.add_label_group(name="label_1", labels=["label_1_1"], group_type=0) + return Dataset.from_iterable( [ DatasetItem( @@ -204,11 +210,7 @@ def _create_expected_dataset(self): annotations=[Label(1)], ), ], - categories={ - AnnotationType.label: LabelCategories.from_iterable( - ("label_0", "label_1", f"{Path('label_1', 'label_1_1')}") - ), - }, + categories={AnnotationType.label: label_categories}, ) @mark_requirement(Requirements.DATUM_GENERAL_REQ) diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index f9a2e72d59..3a25c1d158 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -108,6 +108,17 @@ def compare_categories(test, expected, actual): sorted(expected[AnnotationType.label].items, key=lambda t: t.name), sorted(actual[AnnotationType.label].items, key=lambda t: t.name), ) + if expected[AnnotationType.label].label_groups: + assert len(expected[AnnotationType.label].label_groups) == len( + actual[AnnotationType.label].label_groups + ) + for expected_group, actual_group in zip( + expected[AnnotationType.label].label_groups, + actual[AnnotationType.label].label_groups, + ): + test.assertEqual(set(expected_group.labels), set(actual_group.labels)) + test.assertEqual(expected_group.group_type, actual_group.group_type) + if AnnotationType.mask in expected: test.assertEqual( expected[AnnotationType.mask].colormap,