diff --git a/datumaro/datumaro/components/algorithms/rise.py b/datumaro/datumaro/components/algorithms/rise.py index 78e936392c2c..8e75f10a201f 100644 --- a/datumaro/datumaro/components/algorithms/rise.py +++ b/datumaro/datumaro/components/algorithms/rise.py @@ -80,10 +80,10 @@ def normalize_hmaps(self, heatmaps, counts): def apply(self, image, progressive=False): import cv2 - assert len(image.shape) == 3, \ + assert len(image.shape) in [2, 3], \ "Expected an input image in (H, W, C) format" - assert image.shape[2] in [3, 4], \ - "Expected BGR or BGRA input" + if len(image.shape) == 3: + assert image.shape[2] in [3, 4], "Expected BGR or BGRA input" image = image[:, :, :3].astype(np.float32) model = self.model diff --git a/datumaro/datumaro/components/converters/ms_coco.py b/datumaro/datumaro/components/converters/ms_coco.py index f629f72dabc9..469184aceba4 100644 --- a/datumaro/datumaro/components/converters/ms_coco.py +++ b/datumaro/datumaro/components/converters/ms_coco.py @@ -62,7 +62,7 @@ def is_empty(self): def save_image_info(self, item, filename): if item.has_image: - h, w, _ = item.image.shape + h, w = item.image.shape[:2] else: h = 0 w = 0 @@ -187,7 +187,7 @@ def save_annotations(self, item): p.label == ann.label] if polygons: segmentation = [p.get_points() for p in polygons] - h, w, _ = item.image.shape + h, w = item.image.shape[:2] rles = mask_utils.frPyObjects(segmentation, h, w) rle = mask_utils.merge(rles) area = mask_utils.area(rle) @@ -211,7 +211,7 @@ def save_annotations(self, item): area = ann.area() if self._context._merge_polygons: - h, w, _ = item.image.shape + h, w = item.image.shape[:2] rles = mask_utils.frPyObjects(segmentation, h, w) rle = mask_utils.merge(rles) area = mask_utils.area(rle) diff --git a/datumaro/datumaro/components/converters/tfrecord.py b/datumaro/datumaro/components/converters/tfrecord.py index 447a8359dee3..7d6c5c19548a 100644 --- a/datumaro/datumaro/components/converters/tfrecord.py +++ b/datumaro/datumaro/components/converters/tfrecord.py @@ -48,7 +48,7 @@ def float_list_feature(value): if not item.has_image: raise Exception( "Failed to export dataset item '%s': item has no image" % item.id) - height, width, _ = item.image.shape + height, width = item.image.shape[:2] features.update({ 'image/height': int64_feature(height), diff --git a/datumaro/datumaro/components/converters/voc.py b/datumaro/datumaro/components/converters/voc.py index 810036786119..c296c351f3be 100644 --- a/datumaro/datumaro/components/converters/voc.py +++ b/datumaro/datumaro/components/converters/voc.py @@ -153,7 +153,9 @@ def save_subsets(self): ET.SubElement(source_elem, 'image').text = 'Unknown' if item.has_image: - h, w, c = item.image.shape + image_shape = item.image.shape + h, w = image_shape[:2] + c = 1 if len(image_shape) == 2 else image_shape[2] size_elem = ET.SubElement(root_elem, 'size') ET.SubElement(size_elem, 'width').text = str(w) ET.SubElement(size_elem, 'height').text = str(h) diff --git a/datumaro/datumaro/components/converters/yolo.py b/datumaro/datumaro/components/converters/yolo.py index 4bf746939d01..cf0d1db788f1 100644 --- a/datumaro/datumaro/components/converters/yolo.py +++ b/datumaro/datumaro/components/converters/yolo.py @@ -92,7 +92,7 @@ def __call__(self, extractor, save_dir): if not osp.exists(image_path): save_image(image_path, item.image) - height, width, _ = item.image.shape + height, width = item.image.shape[:2] yolo_annotation = '' for bbox in item.annotations: diff --git a/datumaro/datumaro/components/dataset_filter.py b/datumaro/datumaro/components/dataset_filter.py index 157720f36519..28339df098a7 100644 --- a/datumaro/datumaro/components/dataset_filter.py +++ b/datumaro/datumaro/components/dataset_filter.py @@ -43,7 +43,8 @@ def encode_item(self, item): def encode_image(cls, image): image_elem = ET.Element('image') - h, w, c = image.shape + h, w = image.shape[:2] + c = 1 if len(image.shape) == 2 else image.shape[2] ET.SubElement(image_elem, 'width').text = str(w) ET.SubElement(image_elem, 'height').text = str(h) ET.SubElement(image_elem, 'depth').text = str(c) diff --git a/datumaro/datumaro/components/importers/cvat.py b/datumaro/datumaro/components/importers/cvat.py index efdeff2963e7..6f831a7b90fa 100644 --- a/datumaro/datumaro/components/importers/cvat.py +++ b/datumaro/datumaro/components/importers/cvat.py @@ -40,7 +40,7 @@ def __call__(self, path, **extra_params): project.add_source(subset_name, { 'url': subset_path, 'format': self.EXTRACTOR_NAME, - 'options': extra_params, + 'options': dict(extra_params), }) return project diff --git a/datumaro/datumaro/components/importers/datumaro.py b/datumaro/datumaro/components/importers/datumaro.py index 9c2a162b8cc8..828208d8d204 100644 --- a/datumaro/datumaro/components/importers/datumaro.py +++ b/datumaro/datumaro/components/importers/datumaro.py @@ -40,7 +40,7 @@ def __call__(self, path, **extra_params): project.add_source(subset_name, { 'url': subset_path, 'format': self.EXTRACTOR_NAME, - 'options': extra_params, + 'options': dict(extra_params), }) return project diff --git a/datumaro/datumaro/components/importers/ms_coco.py b/datumaro/datumaro/components/importers/ms_coco.py index e7a0d26ca018..cb0fb838d2df 100644 --- a/datumaro/datumaro/components/importers/ms_coco.py +++ b/datumaro/datumaro/components/importers/ms_coco.py @@ -37,7 +37,7 @@ def __call__(self, path, **extra_params): project.add_source(source_name, { 'url': ann_file, 'format': self._COCO_EXTRACTORS[ann_type], - 'options': extra_params, + 'options': dict(extra_params), }) return project diff --git a/datumaro/datumaro/components/importers/tfrecord.py b/datumaro/datumaro/components/importers/tfrecord.py index c1506211142d..368c3d0fa9b3 100644 --- a/datumaro/datumaro/components/importers/tfrecord.py +++ b/datumaro/datumaro/components/importers/tfrecord.py @@ -35,7 +35,7 @@ def __call__(self, path, **extra_params): project.add_source(subset_name, { 'url': subset_path, 'format': self.EXTRACTOR_NAME, - 'options': extra_params, + 'options': dict(extra_params), }) return project diff --git a/datumaro/datumaro/components/importers/voc.py b/datumaro/datumaro/components/importers/voc.py index e71327893b0e..bc0409df805f 100644 --- a/datumaro/datumaro/components/importers/voc.py +++ b/datumaro/datumaro/components/importers/voc.py @@ -31,7 +31,7 @@ def __call__(self, path, **extra_params): project.add_source(task.name, { 'url': path, 'format': extractor_type, - 'options': extra_params, + 'options': dict(extra_params), }) if len(project.config.sources) == 0: @@ -69,7 +69,7 @@ def __call__(self, path, **extra_params): project.add_source(task_name, { 'url': task_dir, 'format': extractor_type, - 'options': extra_params, + 'options': dict(extra_params), }) if len(project.config.sources) == 0: diff --git a/datumaro/datumaro/components/importers/yolo.py b/datumaro/datumaro/components/importers/yolo.py index 2a22117edd23..df8f739626a1 100644 --- a/datumaro/datumaro/components/importers/yolo.py +++ b/datumaro/datumaro/components/importers/yolo.py @@ -28,7 +28,7 @@ def __call__(self, path, **extra_params): project.add_source(source_name, { 'url': config_path, 'format': 'yolo', - 'options': extra_params, + 'options': dict(extra_params), }) return project \ No newline at end of file diff --git a/datumaro/datumaro/components/project.py b/datumaro/datumaro/components/project.py index e03aad631a5d..6fc16c1533de 100644 --- a/datumaro/datumaro/components/project.py +++ b/datumaro/datumaro/components/project.py @@ -97,7 +97,7 @@ class GitWrapper: def __init__(self, config=None): self.repo = None - if config is not None: + if config is not None and osp.isdir(config.project_dir): self.init(config.project_dir) @staticmethod @@ -335,7 +335,7 @@ def __init__(self, project): own_source = None own_source_dir = osp.join(config.project_dir, config.dataset_dir) - if osp.isdir(own_source_dir): + if osp.isdir(config.project_dir) and osp.isdir(own_source_dir): log.disable(log.INFO) own_source = env.make_importer(DEFAULT_FORMAT)(own_source_dir) \ .make_dataset() diff --git a/datumaro/tests/test_project.py b/datumaro/tests/test_project.py index 7f67e1d9e18f..a66668fdcd0d 100644 --- a/datumaro/tests/test_project.py +++ b/datumaro/tests/test_project.py @@ -273,6 +273,19 @@ def __iter__(self): self.assertEqual(5, len(dataset)) + def test_can_save_and_load_own_dataset(self): + with TestDir() as test_dir: + src_project = Project() + src_dataset = src_project.make_dataset() + item = DatasetItem(id=1) + src_dataset.put(item) + src_dataset.save(test_dir.path) + + loaded_project = Project.load(test_dir.path) + loaded_dataset = loaded_project.make_dataset() + + self.assertEqual(list(src_dataset), list(loaded_dataset)) + def test_project_own_dataset_can_be_modified(self): project = Project() dataset = project.make_dataset()