Skip to content

Commit

Permalink
support on-disk instance segmentations in SDK
Browse files Browse the repository at this point in the history
  • Loading branch information
brimoor committed Dec 11, 2024
1 parent 81db334 commit 1594990
Show file tree
Hide file tree
Showing 7 changed files with 303 additions and 91 deletions.
8 changes: 5 additions & 3 deletions docs/source/user_guide/using_datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2542,7 +2542,7 @@ Object detections stored in |Detections| may also have instance segmentation
masks.

These masks can be stored in one of two ways: either directly in the database
via the :attr:`mask<fiftyone.core.labels.Detection.mask>` attribute, or on
via the :attr:`mask <fiftyone.core.labels.Detection.mask>` attribute, or on
disk referenced by the
:attr:`mask_path <fiftyone.core.labels.Detection.mask_path>` attribute.

Expand Down Expand Up @@ -2605,8 +2605,10 @@ object's bounding box when visualizing in the App.
<Detection: {
'id': '5f8709282018186b6ef6682b',
'attributes': {},
'tags': [],
'label': 'cat',
'bounding_box': [0.48, 0.513, 0.397, 0.288],
'mask': None,
'mask_path': '/path/to/mask.png',
'confidence': 0.96,
'index': None,
Expand All @@ -2615,8 +2617,8 @@ object's bounding box when visualizing in the App.
}>,
}>
Like all |Label| types, you can also add custom attributes to your detections
by dynamically adding new fields to each |Detection| instance:
Like all |Label| types, you can also add custom attributes to your instance
segmentations by dynamically adding new fields to each |Detection| instance:

.. code-block:: python
:linenos:
Expand Down
3 changes: 3 additions & 0 deletions fiftyone/core/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -10681,6 +10681,9 @@ def _get_media_fields(
app_media_fields.discard("filepath")

for field_name, field in schema.items():
while isinstance(field, fof.ListField):
field = field.field

if field_name in app_media_fields:
media_fields[field_name] = None
elif isinstance(field, fof.EmbeddedDocumentField) and issubclass(
Expand Down
7 changes: 4 additions & 3 deletions fiftyone/core/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,8 @@ class Detection(_HasAttributesDict, _HasID, _HasMedia, Label):
its bounding box, which should be a 2D binary or 0/1 integer numpy
array
mask_path (None): the absolute path to the instance segmentation image
on disk
on disk, which should be a single-channel PNG image where any
non-zero values represent the instance's extent
confidence (None): a confidence in ``[0, 1]`` for the detection
index (None): an index for the object
attributes ({}): a dict mapping attribute names to :class:`Attribute`
Expand Down Expand Up @@ -532,8 +533,8 @@ def to_segmentation(self, mask=None, frame_size=None, target=255):
"""
if not self.has_mask:
raise ValueError(
"Only detections with their `mask` attributes populated can "
"be converted to segmentations"
"Only detections with their `mask` or `mask_path` attribute "
"populated can be converted to segmentations"
)

mask, target = _parse_segmentation_target(mask, frame_size, target)
Expand Down
102 changes: 59 additions & 43 deletions fiftyone/utils/data/exporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
import warnings
from collections import defaultdict

from bson import json_util
import pydash

import eta.core.datasets as etad
import eta.core.frameutils as etaf
import eta.core.serial as etas
import eta.core.utils as etau
from bson import json_util

import fiftyone as fo
import fiftyone.core.collections as foc
Expand Down Expand Up @@ -2029,34 +2031,38 @@ def _export_frame_labels(self, sample, uuid):

def _export_media_fields(self, sd):
for field_name, key in self._media_fields.items():
value = sd.get(field_name, None)
if value is None:
continue

if key is not None:
self._export_media_field(value, field_name, key=key)
else:
self._export_media_field(sd, field_name)
self._export_media_field(sd, field_name, key=key)

def _export_media_field(self, d, field_name, key=None):
if key is not None:
value = d.get(key, None)
else:
key = field_name
value = d.get(field_name, None)

value = pydash.get(d, field_name, None)
if value is None:
return

media_exporter = self._get_media_field_exporter(field_name)
outpath, _ = media_exporter.export(value)

if self.abs_paths:
d[key] = outpath
else:
d[key] = fou.safe_relpath(
outpath, self.export_dir, default=outpath
)
if not isinstance(value, (list, tuple)):
value = [value]

for _d in value:
if key is not None:
_value = _d.get(key, None)
else:
_value = _d

if _value is None:
continue

outpath, _ = media_exporter.export(_value)

if not self.abs_paths:
outpath = fou.safe_relpath(
outpath, self.export_dir, default=outpath
)

if key is not None:
_d[key] = outpath
else:
pydash.set_(d, field_name, outpath)

def _get_media_field_exporter(self, field_name):
media_exporter = self._media_field_exporters.get(field_name, None)
Expand Down Expand Up @@ -2333,33 +2339,43 @@ def _prep_sample(sd):

def _export_media_fields(self, sd):
for field_name, key in self._media_fields.items():
value = sd.get(field_name, None)
if value is None:
continue
self._export_media_field(sd, field_name, key=key)

def _export_media_field(self, d, field_name, key=None):
value = pydash.get(d, field_name, None)
if value is None:
return

media_exporter = self._get_media_field_exporter(field_name)

if not isinstance(value, (list, tuple)):
value = [value]

for _d in value:
if key is not None:
self._export_media_field(value, field_name, key=key)
_value = _d.get(key, None)
else:
self._export_media_field(sd, field_name)
_value = _d

def _export_media_field(self, d, field_name, key=None):
if key is not None:
value = d.get(key, None)
else:
key = field_name
value = d.get(field_name, None)
if _value is None:
continue

if value is None:
return
if self.export_media is not False:
# Store relative path
_, uuid = media_exporter.export(_value)
outpath = os.path.join("fields", field_name, uuid)
elif self.rel_dir is not None:
# Remove `rel_dir` prefix from path
outpath = fou.safe_relpath(
_value, self.rel_dir, default=_value
)
else:
continue

if self.export_media is not False:
# Store relative path
media_exporter = self._get_media_field_exporter(field_name)
_, uuid = media_exporter.export(value)
d[key] = os.path.join("fields", field_name, uuid)
elif self.rel_dir is not None:
# Remove `rel_dir` prefix from path
d[key] = fou.safe_relpath(value, self.rel_dir, default=value)
if key is not None:
_d[key] = outpath
else:
pydash.set_(d, field_name, outpath)

def _get_media_field_exporter(self, field_name):
media_exporter = self._media_field_exporters.get(field_name, None)
Expand Down
52 changes: 32 additions & 20 deletions fiftyone/utils/data/importers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from bson import json_util
from mongoengine.base import get_document
import pydash

import eta.core.datasets as etad
import eta.core.image as etai
Expand Down Expand Up @@ -2151,32 +2152,43 @@ def _import_runs(dataset, runs, results_dir, run_cls):

def _parse_media_fields(sd, media_fields, rel_dir):
for field_name, key in media_fields.items():
value = sd.get(field_name, None)
value = pydash.get(sd, field_name, None)
if value is None:
continue

if isinstance(value, dict):
if key is False:
try:
_cls = value.get("_cls", None)
key = get_document(_cls)._MEDIA_FIELD
except Exception as e:
logger.warning(
"Failed to infer media field for '%s'. Reason: %s",
field_name,
e,
)
key = None

media_fields[field_name] = key

if key is not None:
path = value.get(key, None)
if path is not None and not os.path.isabs(path):
value[key] = os.path.join(rel_dir, path)
_parse_nested_media_field(
value, media_fields, rel_dir, field_name, key
)
elif isinstance(value, list):
for d in value:
_parse_nested_media_field(
d, media_fields, rel_dir, field_name, key
)
elif etau.is_str(value):
if not os.path.isabs(value):
sd[field_name] = os.path.join(rel_dir, value)
pydash.set_(sd, field_name, os.path.join(rel_dir, value))


def _parse_nested_media_field(d, media_fields, rel_dir, field_name, key):
if key is False:
try:
_cls = d.get("_cls", None)
key = get_document(_cls)._MEDIA_FIELD
except Exception as e:
logger.warning(
"Failed to infer media field for '%s'. Reason: %s",
field_name,
e,
)
key = None

media_fields[field_name] = key

if key is not None:
path = d.get(key, None)
if path is not None and not os.path.isabs(path):
d[key] = os.path.join(rel_dir, path)


class ImageDirectoryImporter(UnlabeledImageDatasetImporter):
Expand Down
Loading

0 comments on commit 1594990

Please sign in to comment.