Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change common semantic segmentation dataset detection rule #1572

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# SPDX-License-Identifier: MIT

import errno
import glob
import os.path as osp
from typing import List, Optional

Expand Down Expand Up @@ -69,11 +68,11 @@ def __init__(
self._image_prefix = image_prefix
self._mask_prefix = mask_prefix

meta_file = glob.glob(osp.join(path, "**", DATASET_META_FILE), recursive=True)
if is_meta_file(meta_file[0]):
self._root_dir = osp.dirname(meta_file[0])
meta_file = osp.join(path, DATASET_META_FILE)
if is_meta_file(meta_file):
self._root_dir = osp.dirname(meta_file)

label_map = parse_meta_file(meta_file[0])
label_map = parse_meta_file(meta_file)
self._categories = make_categories(label_map)
else:
raise FileNotFoundError(errno.ENOENT, "Dataset meta info file was not found", path)
Expand Down Expand Up @@ -163,11 +162,10 @@ def build_cmdline_parser(cls, **kwargs):

# Format detection hook: calls context.require_file for the dataset meta file
# and for patterns under the images and masks directories, then returns
# FormatDetectionConfidence.MEDIUM.
# NOTE(review): this is a rendered diff hunk whose +/- markers and indentation
# were lost, so two variants of the changed statements appear back to back.
# The hunk header (11 lines -> 10 lines) suggests the `path`-based variant is
# the removed one -- confirm against the actual commit before reusing.
@classmethod
def detect(cls, context: FormatDetectionContext) -> FormatDetectionConfidence:
# Variant A (presumably removed): meta file pattern carries a "**/" prefix and
# the directory of the returned path is kept in `path`.
path = context.require_file(f"**/{DATASET_META_FILE}")
path = osp.dirname(path)
# Variant B (presumably added): meta file pattern has no "**/" prefix and the
# return value is not used.
context.require_file(DATASET_META_FILE)

# Variant A: images/masks patterns are joined onto `path`.
context.require_file(osp.join(path, CommonSemanticSegmentationPath.IMAGES_DIR, "**", "*"))
context.require_file(osp.join(path, CommonSemanticSegmentationPath.MASKS_DIR, "**", "*"))
# Variant B: the same images/masks patterns without the `path` prefix.
context.require_file(osp.join(CommonSemanticSegmentationPath.IMAGES_DIR, "**", "*"))
context.require_file(osp.join(CommonSemanticSegmentationPath.MASKS_DIR, "**", "*"))

return FormatDetectionConfidence.MEDIUM

Expand Down
12 changes: 7 additions & 5 deletions tests/unit/data_formats/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from datumaro import Dataset

from tests.utils.test_utils import TestDir


@pytest.fixture
def fxt_dummy_dataset():
Expand All @@ -35,12 +37,12 @@ def fxt_export_kwargs():
# Fixture that copies the parametrized dataset directory (request.param) into
# "train", "val" and "test" subdirectories and yields the directory that
# contains them.
# NOTE(review): rendered diff hunk with +/- markers and indentation lost; two
# variants of the copy loop and of the `yield` are interleaved below. The
# TestDir-based variant is presumably the added one -- confirm against the
# actual commit.
@pytest.fixture
def fxt_dataset_dir_with_subset_dirs(test_dir: str, request: pytest.FixtureRequest):
fxt_dataset_dir = request.param
# Variant B (presumably added): copy into a separate TestDir named
# f"{test_dir}_with_subsets" instead of into test_dir itself.
with TestDir(f"{test_dir}_with_subsets") as new_test_dir:
for subset in ["train", "val", "test"]:
dst = os.path.join(new_test_dir, subset)
shutil.copytree(fxt_dataset_dir, dst)

# Variant A (presumably removed): copy the subsets directly under test_dir.
for subset in ["train", "val", "test"]:
dst = os.path.join(test_dir, subset)
shutil.copytree(fxt_dataset_dir, dst)

# Variant A yields test_dir; variant B yields new_test_dir.
yield test_dir
yield new_test_dir


@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@
#
# SPDX-License-Identifier: MIT

import os
import shutil
from collections import OrderedDict
from typing import Any, Dict, Optional
from typing import Any, Dict

import numpy as np
import pytest

from datumaro.components.annotation import Mask
from datumaro.components.dataset import Dataset
from datumaro.components.dataset_base import DatasetItem
from datumaro.components.errors import DatasetImportError
from datumaro.components.media import Image
from datumaro.plugins.data_formats.common_semantic_segmentation import (
CommonSemanticSegmentationImporter,
Expand Down Expand Up @@ -143,6 +146,40 @@ def test_can_import(
fxt_dataset_dir, fxt_expected_dataset, fxt_import_kwargs, request
)

# Negative import test: the dataset is copied into test_dir and then every
# top-level entry is moved one level down into "subdir"; importing from
# test_dir is then expected to raise DatasetImportError chained from a
# FileNotFoundError.
# NOTE(review): indentation was lost in extraction, so statement nesting
# (e.g. whether the final asserts sit inside or after the `with` block) cannot
# be read from this text -- confirm against the actual file.
@pytest.mark.parametrize(
["fxt_dataset_dir", "fxt_expected_dataset", "fxt_import_kwargs"],
[
(DUMMY_DATASET_DIR, "fxt_dataset", {}),
(
DUMMY_NON_STANDARD_DATASET_DIR,
"fxt_non_standard_dataset",
{"image_prefix": "image_", "mask_prefix": "gt_"},
),
],
indirect=["fxt_expected_dataset"],
ids=IDS,
)
def test_cannot_import_nested(
self,
fxt_dataset_dir: str,
fxt_expected_dataset: Dataset,
fxt_import_kwargs: Dict[str, Any],
request: pytest.FixtureRequest,
test_dir: str,
):
# Copy the source dataset into test_dir; dirs_exist_ok=True tolerates
# test_dir already existing.
shutil.copytree(fxt_dataset_dir, test_dir, dirs_exist_ok=True)
subdir_name = "subdir"
subdir = os.path.join(test_dir, subdir_name)
os.makedirs(subdir)
# Move every top-level entry except the new subdirectory itself into it,
# so nothing from the dataset remains directly under test_dir.
for _file in os.listdir(test_dir):
if _file != subdir_name:
file_path = os.path.join(test_dir, _file)
shutil.move(file_path, subdir)
# Reuse the positive-path helper; here it is expected to fail on import.
with pytest.raises(DatasetImportError) as exc_info:
super().test_can_import(test_dir, fxt_expected_dataset, fxt_import_kwargs, request)
# The import error must have been raised from a FileNotFoundError.
assert exc_info.value.__cause__ is not None
assert isinstance(exc_info.value.__cause__, FileNotFoundError)


class CommonSemanticSegmentationWithSubsetDirsImporterTest(TestDataFormatBase):
IMPORTER = CommonSemanticSegmentationWithSubsetDirsImporter
Expand Down
Loading