From 1594990bde82aab739177c9a7bfd0202f648dc86 Mon Sep 17 00:00:00 2001 From: brimoor Date: Wed, 11 Dec 2024 00:28:45 -0500 Subject: [PATCH] support on-disk instance segmentations in SDK --- docs/source/user_guide/using_datasets.rst | 8 +- fiftyone/core/collections.py | 3 + fiftyone/core/labels.py | 7 +- fiftyone/utils/data/exporters.py | 102 +++++++++++-------- fiftyone/utils/data/importers.py | 52 ++++++---- fiftyone/utils/labels.py | 113 +++++++++++++++++----- tests/unittests/import_export_tests.py | 109 +++++++++++++++++++++ 7 files changed, 303 insertions(+), 91 deletions(-) diff --git a/docs/source/user_guide/using_datasets.rst b/docs/source/user_guide/using_datasets.rst index 729434e07e..8bab5b83a6 100644 --- a/docs/source/user_guide/using_datasets.rst +++ b/docs/source/user_guide/using_datasets.rst @@ -2542,7 +2542,7 @@ Object detections stored in |Detections| may also have instance segmentation masks. These masks can be stored in one of two ways: either directly in the database -via the :attr:`mask` attribute, or on +via the :attr:`mask ` attribute, or on disk referenced by the :attr:`mask_path ` attribute. @@ -2605,8 +2605,10 @@ object's bounding box when visualizing in the App. , }> -Like all |Label| types, you can also add custom attributes to your detections -by dynamically adding new fields to each |Detection| instance: +Like all |Label| types, you can also add custom attributes to your instance +segmentations by dynamically adding new fields to each |Detection| instance: .. code-block:: python :linenos: diff --git a/fiftyone/core/collections.py b/fiftyone/core/collections.py index d53a25fce1..8200255fb4 100644 --- a/fiftyone/core/collections.py +++ b/fiftyone/core/collections.py @@ -10681,6 +10681,9 @@ def _get_media_fields( app_media_fields.discard("filepath") for field_name, field in schema.items(): + while isinstance(field, fof.ListField): + field = field.field + if field_name in app_media_fields: media_fields[field_name] = None elif isinstance(field, fof.EmbeddedDocumentField) and issubclass( diff --git a/fiftyone/core/labels.py b/fiftyone/core/labels.py index e8b9bd9390..e6b09d8267 100644 --- a/fiftyone/core/labels.py +++ b/fiftyone/core/labels.py @@ -409,7 +409,8 @@ class Detection(_HasAttributesDict, _HasID, _HasMedia, Label): its bounding box, which should be a 2D binary or 0/1 integer numpy array mask_path (None): the absolute path to the instance segmentation image - on disk + on disk, which should be a single-channel PNG image where any + non-zero values represent the instance's extent confidence (None): a confidence in ``[0, 1]`` for the detection index (None): an index for the object attributes ({}): a dict mapping attribute names to :class:`Attribute` @@ -532,8 +533,8 @@ def to_segmentation(self, mask=None, frame_size=None, target=255): """ if not self.has_mask: raise ValueError( - "Only detections with their `mask` attributes populated can " - "be converted to segmentations" + "Only detections with their `mask` or `mask_path` attribute " + "populated can be converted to segmentations" ) mask, target = _parse_segmentation_target(mask, frame_size, target) diff --git a/fiftyone/utils/data/exporters.py b/fiftyone/utils/data/exporters.py index e2a0780380..7a9b7da68e 100644 --- a/fiftyone/utils/data/exporters.py +++ b/fiftyone/utils/data/exporters.py @@ -12,11 +12,13 @@ import warnings from collections import defaultdict +from bson import json_util +import pydash + import eta.core.datasets as etad import eta.core.frameutils as etaf import eta.core.serial as etas import eta.core.utils as etau -from bson import json_util import fiftyone as fo import fiftyone.core.collections as foc @@ -2029,34 +2031,38 @@ def _export_frame_labels(self, sample, uuid): def _export_media_fields(self, sd): for field_name, key in self._media_fields.items(): - value = sd.get(field_name, None) - if value is None: - continue - - if key is not None: - self._export_media_field(value, field_name, key=key) - else: - self._export_media_field(sd, field_name) + self._export_media_field(sd, field_name, key=key) def _export_media_field(self, d, field_name, key=None): - if key is not None: - value = d.get(key, None) - else: - key = field_name - value = d.get(field_name, None) - + value = pydash.get(d, field_name, None) if value is None: return media_exporter = self._get_media_field_exporter(field_name) - outpath, _ = media_exporter.export(value) - if self.abs_paths: - d[key] = outpath - else: - d[key] = fou.safe_relpath( - outpath, self.export_dir, default=outpath - ) + if not isinstance(value, (list, tuple)): + value = [value] + + for _d in value: + if key is not None: + _value = _d.get(key, None) + else: + _value = _d + + if _value is None: + continue + + outpath, _ = media_exporter.export(_value) + + if not self.abs_paths: + outpath = fou.safe_relpath( + outpath, self.export_dir, default=outpath + ) + + if key is not None: + _d[key] = outpath + else: + pydash.set_(d, field_name, outpath) def _get_media_field_exporter(self, field_name): media_exporter = self._media_field_exporters.get(field_name, None) @@ -2333,33 +2339,43 @@ def _prep_sample(sd): def _export_media_fields(self, sd): for field_name, key in self._media_fields.items(): - value = sd.get(field_name, None) - if value is None: - continue + self._export_media_field(sd, field_name, key=key) + + def _export_media_field(self, d, field_name, key=None): + value = pydash.get(d, field_name, None) + if value is None: + return + media_exporter = self._get_media_field_exporter(field_name) + + if not isinstance(value, (list, tuple)): + value = [value] + + for _d in value: if key is not None: - self._export_media_field(value, field_name, key=key) + _value = _d.get(key, None) else: - self._export_media_field(sd, field_name) + _value = _d - def _export_media_field(self, d, field_name, key=None): - if key is not None: - value = d.get(key, None) - else: - key = field_name - value = d.get(field_name, None) + if _value is None: + continue - if value is None: - return + if self.export_media is not False: + # Store relative path + _, uuid = media_exporter.export(_value) + outpath = os.path.join("fields", field_name, uuid) + elif self.rel_dir is not None: + # Remove `rel_dir` prefix from path + outpath = fou.safe_relpath( + _value, self.rel_dir, default=_value + ) + else: + continue - if self.export_media is not False: - # Store relative path - media_exporter = self._get_media_field_exporter(field_name) - _, uuid = media_exporter.export(value) - d[key] = os.path.join("fields", field_name, uuid) - elif self.rel_dir is not None: - # Remove `rel_dir` prefix from path - d[key] = fou.safe_relpath(value, self.rel_dir, default=value) + if key is not None: + _d[key] = outpath + else: + pydash.set_(d, field_name, outpath) def _get_media_field_exporter(self, field_name): media_exporter = self._media_field_exporters.get(field_name, None) diff --git a/fiftyone/utils/data/importers.py b/fiftyone/utils/data/importers.py index 11c50f45a5..299827f3c0 100644 --- a/fiftyone/utils/data/importers.py +++ b/fiftyone/utils/data/importers.py @@ -14,6 +14,7 @@ from bson import json_util from mongoengine.base import get_document +import pydash import eta.core.datasets as etad import eta.core.image as etai @@ -2151,32 +2152,43 @@ def _import_runs(dataset, runs, results_dir, run_cls): def _parse_media_fields(sd, media_fields, rel_dir): for field_name, key in media_fields.items(): - value = sd.get(field_name, None) + value = pydash.get(sd, field_name, None) if value is None: continue if isinstance(value, dict): - if key is False: - try: - _cls = value.get("_cls", None) - key = get_document(_cls)._MEDIA_FIELD - except Exception as e: - logger.warning( - "Failed to infer media field for '%s'. Reason: %s", - field_name, - e, - ) - key = None - - media_fields[field_name] = key - - if key is not None: - path = value.get(key, None) - if path is not None and not os.path.isabs(path): - value[key] = os.path.join(rel_dir, path) + _parse_nested_media_field( + value, media_fields, rel_dir, field_name, key + ) + elif isinstance(value, list): + for d in value: + _parse_nested_media_field( + d, media_fields, rel_dir, field_name, key + ) elif etau.is_str(value): if not os.path.isabs(value): - sd[field_name] = os.path.join(rel_dir, value) + pydash.set_(sd, field_name, os.path.join(rel_dir, value)) + + +def _parse_nested_media_field(d, media_fields, rel_dir, field_name, key): + if key is False: + try: + _cls = d.get("_cls", None) + key = get_document(_cls)._MEDIA_FIELD + except Exception as e: + logger.warning( + "Failed to infer media field for '%s'. Reason: %s", + field_name, + e, + ) + key = None + + media_fields[field_name] = key + + if key is not None: + path = d.get(key, None) + if path is not None and not os.path.isabs(path): + d[key] = os.path.join(rel_dir, path) class ImageDirectoryImporter(UnlabeledImageDatasetImporter): diff --git a/fiftyone/utils/labels.py b/fiftyone/utils/labels.py index 7071d1f001..f28bfac205 100644 --- a/fiftyone/utils/labels.py +++ b/fiftyone/utils/labels.py @@ -155,8 +155,8 @@ def export_segmentations( overwrite=False, progress=None, ): - """Exports the segmentations (or heatmaps) stored as in-database arrays in - the specified field to images on disk. + """Exports the semantic segmentations, instance segmentations, or heatmaps + stored as in-database arrays in the specified field to images on disk. Any labels without in-memory arrays are skipped. @@ -164,7 +164,9 @@ def export_segmentations( sample_collection: a :class:`fiftyone.core.collections.SampleCollection` in_field: the name of the - :class:`fiftyone.core.labels.Segmentation` or + :class:`fiftyone.core.labels.Segmentation`, + :class:`fiftyone.core.labels.Detection`, + :class:`fiftyone.core.labels.Detections`, or :class:`fiftyone.core.labels.Heatmap` field output_dir: the directory in which to write the images rel_dir (None): an optional relative directory to strip from each input @@ -183,7 +185,9 @@ def export_segmentations( """ fov.validate_non_grouped_collection(sample_collection) fov.validate_collection_label_fields( - sample_collection, in_field, (fol.Segmentation, fol.Heatmap) + sample_collection, + in_field, + (fol.Segmentation, fol.Detection, fol.Detections, fol.Heatmap), ) samples = sample_collection.select_fields(in_field) @@ -207,16 +211,31 @@ def export_segmentations( if label is None: continue - outpath = filename_maker.get_output_path( - image.filepath, output_ext=".png" - ) - - if isinstance(label, fol.Heatmap): - if label.map is not None: - label.export_map(outpath, update=update) - else: + if isinstance(label, fol.Segmentation): + if label.mask is not None: + outpath = filename_maker.get_output_path( + image.filepath, output_ext=".png" + ) + label.export_mask(outpath, update=update) + elif isinstance(label, fol.Detection): if label.mask is not None: + outpath = filename_maker.get_output_path( + image.filepath, output_ext=".png" + ) label.export_mask(outpath, update=update) + elif isinstance(label, fol.Detections): + for detection in label.detections: + if detection.mask is not None: + outpath = filename_maker.get_output_path( + image.filepath, output_ext=".png" + ) + detection.export_mask(outpath, update=update) + elif isinstance(label, fol.Heatmap): + if label.map is not None: + outpath = filename_maker.get_output_path( + image.filepath, output_ext=".png" + ) + label.export_map(outpath, update=update) def import_segmentations( @@ -226,8 +245,8 @@ def import_segmentations( delete_images=False, progress=None, ): - """Imports the segmentations (or heatmaps) stored on disk in the specified - field to in-database arrays. + """Imports the semantic segmentations, instance segmentations, or heatmaps + stored on disk in the specified field to in-database arrays. Any labels without images on disk are skipped. @@ -235,7 +254,9 @@ def import_segmentations( sample_collection: a :class:`fiftyone.core.collections.SampleCollection` in_field: the name of the - :class:`fiftyone.core.labels.Segmentation` or + :class:`fiftyone.core.labels.Segmentation`, + :class:`fiftyone.core.labels.Detection`, + :class:`fiftyone.core.labels.Detections`, or :class:`fiftyone.core.labels.Heatmap` field update (True): whether to delete the image paths from the labels delete_images (False): whether to delete any imported images from disk @@ -245,7 +266,9 @@ def import_segmentations( """ fov.validate_non_grouped_collection(sample_collection) fov.validate_collection_label_fields( - sample_collection, in_field, (fol.Segmentation, fol.Heatmap) + sample_collection, + in_field, + (fol.Segmentation, fol.Detection, fol.Detections, fol.Heatmap), ) samples = sample_collection.select_fields(in_field) @@ -262,18 +285,33 @@ def import_segmentations( if label is None: continue - if isinstance(label, fol.Heatmap): - if label.map_path is not None: - del_path = label.map_path if delete_images else None - label.import_map(update=update) + if isinstance(label, fol.Segmentation): + if label.mask_path is not None: + del_path = label.mask_path if delete_images else None + label.import_mask(update=update) if del_path: etau.delete_file(del_path) - else: + elif isinstance(label, fol.Detection): if label.mask_path is not None: del_path = label.mask_path if delete_images else None label.import_mask(update=update) if del_path: etau.delete_file(del_path) + elif isinstance(label, fol.Detections): + for detection in label.detections: + if detection.mask_path is not None: + del_path = ( + detection.mask_path if delete_images else None + ) + detection.import_mask(update=update) + if del_path: + etau.delete_file(del_path) + elif isinstance(label, fol.Heatmap): + if label.map_path is not None: + del_path = label.map_path if delete_images else None + label.import_map(update=update) + if del_path: + etau.delete_file(del_path) def transform_segmentations( @@ -389,6 +427,9 @@ def segmentations_to_detections( out_field, mask_targets=None, mask_types="stuff", + output_dir=None, + rel_dir=None, + overwrite=False, progress=None, ): """Converts the semantic segmentations masks in the specified field of the @@ -423,6 +464,18 @@ def segmentations_to_detections( - ``"thing"`` if all classes are thing classes - a dict mapping pixel values (2D masks) or RGB hex strings (3D masks) to ``"stuff"`` or ``"thing"`` for each class + output_dir (None): an optional output directory in which to write + instance segmentation images. If none is provided, the instance + segmentations are stored in the database + rel_dir (None): an optional relative directory to strip from each input + filepath to generate a unique identifier that is joined with + ``output_dir`` to generate an output path for each instance + segmentation image. This argument allows for populating nested + subdirectories in ``output_dir`` that match the shape of the input + paths. The path is converted to an absolute path (if necessary) via + :func:`fiftyone.core.storage.normalize_path` + overwrite (False): whether to delete ``output_dir`` prior to exporting + if it exists progress (None): whether to render a progress bar (True/False), use the default value ``fiftyone.config.show_progress_bars`` (None), or a progress callback function to invoke instead @@ -438,6 +491,14 @@ def segmentations_to_detections( in_field, processing_frames = samples._handle_frame_field(in_field) out_field, _ = samples._handle_frame_field(out_field) + if overwrite and output_dir is not None: + etau.delete_dir(output_dir) + + if output_dir is not None: + filename_maker = fou.UniqueFilenameMaker( + output_dir=output_dir, rel_dir=rel_dir, idempotent=False + ) + for sample in samples.iter_samples(autosave=True, progress=progress): if processing_frames: images = sample.frames.values() @@ -449,9 +510,17 @@ def segmentations_to_detections( if label is None: continue - image[out_field] = label.to_detections( + detections = label.to_detections( mask_targets=mask_targets, mask_types=mask_types ) + if output_dir is not None: + for detection in detections.detections: + mask_path = filename_maker.get_output_path( + image.filepath, output_ext=".png" + ) + detection.export_mask(mask_path, update=True) + + image[out_field] = detections def instances_to_polylines( diff --git a/tests/unittests/import_export_tests.py b/tests/unittests/import_export_tests.py index 896429d8a7..d7f601a9e4 100644 --- a/tests/unittests/import_export_tests.py +++ b/tests/unittests/import_export_tests.py @@ -2218,6 +2218,115 @@ def _test_image_segmentation_fiftyone_dataset(self, dataset_type): dataset2.values("segmentations.mask_path"), ) + @drop_datasets + def test_instance_segmentation_fiftyone_dataset(self): + self._test_instance_segmentation_fiftyone_dataset( + fo.types.FiftyOneDataset + ) + + @drop_datasets + def test_instance_segmentation_legacy_fiftyone_dataset(self): + self._test_instance_segmentation_fiftyone_dataset( + fo.types.LegacyFiftyOneDataset + ) + + def _test_instance_segmentation_fiftyone_dataset(self, dataset_type): + dataset = self._make_dataset() + + # In-database instance segmentations + + export_dir = self._new_dir() + + dataset.export( + export_dir=export_dir, + dataset_type=dataset_type, + ) + + dataset2 = fo.Dataset.from_dir( + dataset_dir=export_dir, + dataset_type=dataset_type, + ) + + self.assertEqual(len(dataset), len(dataset2)) + self.assertEqual(dataset.count("detections.detections.mask_path"), 0) + self.assertEqual(dataset2.count("detections.detections.mask_path"), 0) + self.assertEqual( + dataset.count("detections.detections.mask"), + dataset2.count("detections.detections.mask"), + ) + + # Convert to on-disk instance segmentations + + segmentations_dir = self._new_dir() + + foul.export_segmentations(dataset, "detections", segmentations_dir) + + self.assertEqual(dataset.count("detections.detections.mask"), 0) + for mask_path in dataset.values("detections.detections[].mask_path"): + if mask_path is not None: + self.assertTrue(mask_path.startswith(segmentations_dir)) + + # On-disk instance segmentations + + export_dir = self._new_dir() + field_dir = os.path.join(export_dir, "fields", "detections.detections") + + dataset.export( + export_dir=export_dir, + dataset_type=dataset_type, + ) + + dataset2 = fo.Dataset.from_dir( + dataset_dir=export_dir, + dataset_type=dataset_type, + ) + + self.assertEqual(len(dataset), len(dataset2)) + self.assertEqual(dataset2.count("detections.detections.mask"), 0) + self.assertEqual( + dataset.count("detections.detections.mask_path"), + dataset2.count("detections.detections.mask_path"), + ) + + for mask_path in dataset2.values("detections.detections[].mask_path"): + if mask_path is not None: + self.assertTrue(mask_path.startswith(field_dir)) + + # On-disk instance segmentations (don't export media) + + export_dir = self._new_dir() + + dataset.export( + export_dir=export_dir, + dataset_type=dataset_type, + export_media=False, + ) + + dataset2 = fo.Dataset.from_dir( + dataset_dir=export_dir, + dataset_type=dataset_type, + ) + + self.assertEqual(len(dataset), len(dataset2)) + self.assertListEqual( + dataset.values("filepath"), + dataset2.values("filepath"), + ) + self.assertListEqual( + dataset.values("detections.detections[].mask_path"), + dataset2.values("detections.detections[].mask_path"), + ) + + # Convert to in-database instance segmentations + + foul.import_segmentations(dataset2, "detections") + + self.assertEqual(dataset2.count("detections.detections.mask_path"), 0) + self.assertEqual( + dataset2.count("detections.detections.mask"), + dataset.count("detections.detections.mask_path"), + ) + class DICOMDatasetTests(ImageDatasetTests): def _get_dcm_path(self):