Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support on-disk instance segmentations in SDK #5256

Merged
merged 3 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions docs/source/user_guide/using_datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2542,7 +2542,7 @@ Object detections stored in |Detections| may also have instance segmentation
masks.

These masks can be stored in one of two ways: either directly in the database
via the :attr:`mask<fiftyone.core.labels.Detection.mask>` attribute, or on
via the :attr:`mask <fiftyone.core.labels.Detection.mask>` attribute, or on
disk referenced by the
:attr:`mask_path <fiftyone.core.labels.Detection.mask_path>` attribute.

Expand Down Expand Up @@ -2605,8 +2605,10 @@ object's bounding box when visualizing in the App.
<Detection: {
'id': '5f8709282018186b6ef6682b',
'attributes': {},
'tags': [],
'label': 'cat',
'bounding_box': [0.48, 0.513, 0.397, 0.288],
'mask': None,
'mask_path': '/path/to/mask.png',
'confidence': 0.96,
'index': None,
Expand All @@ -2615,8 +2617,8 @@ object's bounding box when visualizing in the App.
}>,
}>

Like all |Label| types, you can also add custom attributes to your detections
by dynamically adding new fields to each |Detection| instance:
Like all |Label| types, you can also add custom attributes to your instance
segmentations by dynamically adding new fields to each |Detection| instance:

.. code-block:: python
:linenos:
Expand Down
46 changes: 33 additions & 13 deletions fiftyone/core/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -10662,9 +10662,7 @@ def _handle_db_fields(self, paths, frames=False):
db_fields_map = self._get_db_fields_map(frames=frames)
return [db_fields_map.get(p, p) for p in paths]

def _get_media_fields(
self, include_filepath=True, whitelist=None, frames=False
):
def _get_media_fields(self, whitelist=None, blacklist=None, frames=False):
media_fields = {}
Comment on lines +10665 to 10666
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Breaking change: Method signature updated

The _get_media_fields method signature has changed from include_filepath to whitelist/blacklist parameters. This is a breaking change that may affect existing code that calls this method.

Consider:

  1. Adding a deprecation warning for any code using the old signature
  2. Updating the documentation to highlight this breaking change
  3. Providing migration guidance for users


if frames:
Expand All @@ -10674,13 +10672,13 @@ def _get_media_fields(
schema = self.get_field_schema(flat=True)
app_media_fields = set(self._dataset.app_config.media_fields)

if include_filepath:
# 'filepath' should already be in set, but add it just in case
app_media_fields.add("filepath")
else:
app_media_fields.discard("filepath")
# 'filepath' should already be in set, but add it just in case
app_media_fields.add("filepath")

for field_name, field in schema.items():
while isinstance(field, fof.ListField):
field = field.field

if field_name in app_media_fields:
media_fields[field_name] = None
elif isinstance(field, fof.EmbeddedDocumentField) and issubclass(
Expand All @@ -10695,14 +10693,28 @@ def _get_media_fields(
whitelist = {whitelist}

media_fields = {
k: v for k, v in media_fields.items() if k in whitelist
k: v
for k, v in media_fields.items()
if any(w == k or k.startswith(w + ".") for w in whitelist)
}

if blacklist is not None:
if etau.is_container(blacklist):
blacklist = set(blacklist)
else:
blacklist = {blacklist}

media_fields = {
k: v
for k, v in media_fields.items()
if not any(w == k or k.startswith(w + ".") for w in blacklist)
}

return media_fields

def _resolve_media_field(self, media_field):
def _parse_media_field(self, media_field):
if media_field in self._dataset.app_config.media_fields:
return media_field
return media_field, None

_media_field, is_frame_field = self._handle_frame_field(media_field)

Expand All @@ -10711,12 +10723,20 @@ def _resolve_media_field(self, media_field):
if leaf is not None:
leaf = root + "." + leaf

if _media_field in (root, leaf):
if _media_field in (root, leaf) or root.startswith(
_media_field + "."
):
_resolved_field = leaf if leaf is not None else root
if is_frame_field:
_resolved_field = self._FRAMES_PREFIX + _resolved_field

return _resolved_field
_list_fields = self._parse_field_name(
_resolved_field, auto_unwind=False
)[-2]
if _list_fields:
return _resolved_field, _list_fields[0]

return _resolved_field, None

raise ValueError("'%s' is not a valid media field" % media_field)

Expand Down
7 changes: 4 additions & 3 deletions fiftyone/core/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,8 @@ class Detection(_HasAttributesDict, _HasID, _HasMedia, Label):
its bounding box, which should be a 2D binary or 0/1 integer numpy
array
mask_path (None): the absolute path to the instance segmentation image
on disk
on disk, which should be a single-channel PNG image where any
non-zero values represent the instance's extent
Comment on lines +412 to +413
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the app technically doesn't mind multi-channel, too, but it makes sense that we write an imperative statement about masks being single-channel for clarity.

if the app runs into multi-channel pngs for masks, it uses just the first channel

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see. I mostly copied this verbatim from the user guide documentation you had added:

We recommend storing masks as single-channel PNG images, where a pixel value of 0 indicates the background (rendered as transparent in the App), and any other value indicates the object.

But I was not certain whether single-channel was indeed a recommendation or a hard requirement, so the version here came out sounding more imperative just to be safe.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to clarify both instances of this documentation if you want!

confidence (None): a confidence in ``[0, 1]`` for the detection
index (None): an index for the object
attributes ({}): a dict mapping attribute names to :class:`Attribute`
Expand Down Expand Up @@ -532,8 +533,8 @@ def to_segmentation(self, mask=None, frame_size=None, target=255):
"""
if not self.has_mask:
raise ValueError(
"Only detections with their `mask` attributes populated can "
"be converted to segmentations"
"Only detections with their `mask` or `mask_path` attribute "
"populated can be converted to segmentations"
)

mask, target = _parse_segmentation_target(mask, frame_size, target)
Expand Down
106 changes: 61 additions & 45 deletions fiftyone/utils/data/exporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
import warnings
from collections import defaultdict

from bson import json_util
import pydash

import eta.core.datasets as etad
import eta.core.frameutils as etaf
import eta.core.serial as etas
import eta.core.utils as etau
from bson import json_util

import fiftyone as fo
import fiftyone.core.collections as foc
Expand Down Expand Up @@ -1892,7 +1894,7 @@ def log_collection(self, sample_collection):
self._metadata["frame_fields"] = schema

self._media_fields = sample_collection._get_media_fields(
include_filepath=False
blacklist="filepath",
)

info = dict(sample_collection.info)
Expand Down Expand Up @@ -2029,34 +2031,38 @@ def _export_frame_labels(self, sample, uuid):

def _export_media_fields(self, sd):
for field_name, key in self._media_fields.items():
value = sd.get(field_name, None)
if value is None:
continue

if key is not None:
self._export_media_field(value, field_name, key=key)
else:
self._export_media_field(sd, field_name)
self._export_media_field(sd, field_name, key=key)

def _export_media_field(self, d, field_name, key=None):
if key is not None:
value = d.get(key, None)
else:
key = field_name
value = d.get(field_name, None)

value = pydash.get(d, field_name, None)
if value is None:
return

media_exporter = self._get_media_field_exporter(field_name)
outpath, _ = media_exporter.export(value)

if self.abs_paths:
d[key] = outpath
else:
d[key] = fou.safe_relpath(
outpath, self.export_dir, default=outpath
)
if not isinstance(value, (list, tuple)):
value = [value]

for _d in value:
if key is not None:
_value = _d.get(key, None)
else:
_value = _d

if _value is None:
continue

outpath, _ = media_exporter.export(_value)

if not self.abs_paths:
outpath = fou.safe_relpath(
outpath, self.export_dir, default=outpath
)

if key is not None:
_d[key] = outpath
else:
pydash.set_(d, field_name, outpath)

def _get_media_field_exporter(self, field_name):
media_exporter = self._media_field_exporters.get(field_name, None)
Expand Down Expand Up @@ -2196,7 +2202,7 @@ def export_samples(self, sample_collection, progress=None):
_sample_collection = sample_collection

self._media_fields = sample_collection._get_media_fields(
include_filepath=False
blacklist="filepath"
)

logger.info("Exporting samples...")
Expand Down Expand Up @@ -2333,33 +2339,43 @@ def _prep_sample(sd):

def _export_media_fields(self, sd):
for field_name, key in self._media_fields.items():
value = sd.get(field_name, None)
if value is None:
continue
self._export_media_field(sd, field_name, key=key)

def _export_media_field(self, d, field_name, key=None):
value = pydash.get(d, field_name, None)
if value is None:
return

media_exporter = self._get_media_field_exporter(field_name)

if not isinstance(value, (list, tuple)):
value = [value]

for _d in value:
if key is not None:
self._export_media_field(value, field_name, key=key)
_value = _d.get(key, None)
else:
self._export_media_field(sd, field_name)
_value = _d

def _export_media_field(self, d, field_name, key=None):
if key is not None:
value = d.get(key, None)
else:
key = field_name
value = d.get(field_name, None)
if _value is None:
continue

if value is None:
return
if self.export_media is not False:
# Store relative path
_, uuid = media_exporter.export(_value)
outpath = os.path.join("fields", field_name, uuid)
elif self.rel_dir is not None:
# Remove `rel_dir` prefix from path
outpath = fou.safe_relpath(
_value, self.rel_dir, default=_value
)
else:
continue

if self.export_media is not False:
# Store relative path
media_exporter = self._get_media_field_exporter(field_name)
_, uuid = media_exporter.export(value)
d[key] = os.path.join("fields", field_name, uuid)
elif self.rel_dir is not None:
# Remove `rel_dir` prefix from path
d[key] = fou.safe_relpath(value, self.rel_dir, default=value)
if key is not None:
_d[key] = outpath
else:
pydash.set_(d, field_name, outpath)

def _get_media_field_exporter(self, field_name):
media_exporter = self._media_field_exporters.get(field_name, None)
Expand Down
52 changes: 32 additions & 20 deletions fiftyone/utils/data/importers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from bson import json_util
from mongoengine.base import get_document
import pydash

import eta.core.datasets as etad
import eta.core.image as etai
Expand Down Expand Up @@ -2151,32 +2152,43 @@ def _import_runs(dataset, runs, results_dir, run_cls):

def _parse_media_fields(sd, media_fields, rel_dir):
for field_name, key in media_fields.items():
value = sd.get(field_name, None)
value = pydash.get(sd, field_name, None)
if value is None:
continue

if isinstance(value, dict):
if key is False:
try:
_cls = value.get("_cls", None)
key = get_document(_cls)._MEDIA_FIELD
except Exception as e:
logger.warning(
"Failed to infer media field for '%s'. Reason: %s",
field_name,
e,
)
key = None

media_fields[field_name] = key

if key is not None:
path = value.get(key, None)
if path is not None and not os.path.isabs(path):
value[key] = os.path.join(rel_dir, path)
_parse_nested_media_field(
value, media_fields, rel_dir, field_name, key
)
elif isinstance(value, list):
for d in value:
_parse_nested_media_field(
d, media_fields, rel_dir, field_name, key
)
elif etau.is_str(value):
if not os.path.isabs(value):
sd[field_name] = os.path.join(rel_dir, value)
pydash.set_(sd, field_name, os.path.join(rel_dir, value))


def _parse_nested_media_field(d, media_fields, rel_dir, field_name, key):
if key is False:
try:
_cls = d.get("_cls", None)
key = get_document(_cls)._MEDIA_FIELD
except Exception as e:
logger.warning(
"Failed to infer media field for '%s'. Reason: %s",
field_name,
e,
)
key = None

media_fields[field_name] = key

if key is not None:
path = d.get(key, None)
if path is not None and not os.path.isabs(path):
d[key] = os.path.join(rel_dir, path)


class ImageDirectoryImporter(UnlabeledImageDatasetImporter):
Expand Down
Loading
Loading