diff --git a/fiftyone/core/labels.py b/fiftyone/core/labels.py index 1b345253d3..b91f110e07 100644 --- a/fiftyone/core/labels.py +++ b/fiftyone/core/labels.py @@ -8,6 +8,7 @@ from functools import partial import itertools import warnings +import pathlib from bson import ObjectId import cv2 @@ -15,6 +16,7 @@ import scipy.ndimage as spn import skimage.measure as skm import skimage.segmentation as sks +import tifffile import eta.core.frameutils as etaf import eta.core.image as etai @@ -447,7 +449,9 @@ def to_polyline(self, tolerance=2, filled=True): **attributes, ) - def to_segmentation(self, mask=None, frame_size=None, target=255): + def to_segmentation( + self, mask=None, frame_size=None, target=255, panoptic=False + ): """Returns a :class:`Segmentation` representation of this instance. The detection must have an instance mask, i.e., its :attr:`mask` @@ -463,6 +467,7 @@ def to_segmentation(self, mask=None, frame_size=None, target=255): provided target (255): the pixel value or RGB hex string to use to render the object + panoptic (False): whether to return a panoptic segmentation Returns: a :class:`Segmentation` @@ -473,9 +478,14 @@ def to_segmentation(self, mask=None, frame_size=None, target=255): "be converted to segmentations" ) - mask, target = _parse_segmentation_target(mask, frame_size, target) + mask, target = _parse_segmentation_target( + mask, frame_size, target, panoptic + ) _render_instance(mask, self, target) - return Segmentation(mask=mask) + if panoptic: + instance_mask = (mask == target).astype(np.uint8) + mask = np.stack([mask, instance_mask], axis=-1) + return Segmentation(mask=mask, is_panoptic=panoptic) def to_shapely(self, frame_size=None): """Returns a Shapely representation of this instance. 
@@ -557,7 +567,9 @@ def to_polylines(self, tolerance=2, filled=True): ] ) - def to_segmentation(self, mask=None, frame_size=None, mask_targets=None): + def to_segmentation( + self, mask=None, frame_size=None, mask_targets=None, panoptic=False + ): """Returns a :class:`Segmentation` representation of this instance. Only detections with instance masks (i.e., their :attr:`mask` @@ -576,16 +588,19 @@ def to_segmentation(self, mask=None, frame_size=None, mask_targets=None): object classes to render and which pixel values to use for each class. If omitted, all objects are rendered with pixel value 255 + panoptic (False): whether to return a panoptic segmentation Returns: a :class:`Segmentation` """ mask, labels_to_targets = _parse_segmentation_mask_targets( - mask, frame_size, mask_targets + mask, frame_size, mask_targets, panoptic ) + if panoptic: + instance_mask = mask.copy() # pylint: disable=not-an-iterable - for detection in self.detections: + for idx, detection in enumerate(self.detections): if detection.mask is None: msg = "Skipping detection(s) with no instance mask" warnings.warn(msg) @@ -600,7 +615,13 @@ def to_segmentation(self, mask=None, frame_size=None, mask_targets=None): _render_instance(mask, detection, target) - return Segmentation(mask=mask) + if panoptic: + _render_instance(instance_mask, detection, idx + 1) + + if panoptic: + mask = np.stack([mask, instance_mask], axis=-1) + + return Segmentation(mask=mask, is_panoptic=panoptic) class Polyline(_HasAttributesDict, _HasID, Label): @@ -680,7 +701,12 @@ def to_detection(self, mask_size=None, frame_size=None): ) def to_segmentation( - self, mask=None, frame_size=None, target=255, thickness=1 + self, + mask=None, + frame_size=None, + target=255, + thickness=1, + panoptic=False, ): """Returns a :class:`Segmentation` representation of this instance. 
@@ -696,12 +722,18 @@ def to_segmentation( the object thickness (1): the thickness, in pixels, at which to render (non-filled) polylines + panoptic (False): whether to return a panoptic segmentation Returns: a :class:`Segmentation` """ - mask, target = _parse_segmentation_target(mask, frame_size, target) + mask, target = _parse_segmentation_target( + mask, frame_size, target, panoptic + ) _render_polyline(mask, self, target, thickness) + if panoptic: + instance_mask = (mask == target).astype(np.uint8) + mask = np.stack([mask, instance_mask], axis=-1) - return Segmentation(mask=mask) + return Segmentation(mask=mask, is_panoptic=panoptic) def to_shapely(self, frame_size=None, filled=None): @@ -907,7 +939,12 @@ def to_detections(self, mask_size=None, frame_size=None): ) def to_segmentation( - self, mask=None, frame_size=None, mask_targets=None, thickness=1 + self, + mask=None, + frame_size=None, + mask_targets=None, + thickness=1, + panoptic=False, ): """Returns a :class:`Segmentation` representation of this instance. @@ -926,16 +963,19 @@ def to_segmentation( 255 thickness (1): the thickness, in pixels, at which to render (non-filled) polylines + panoptic (False): whether to return a panoptic segmentation Returns: a :class:`Segmentation` """ mask, labels_to_targets = _parse_segmentation_mask_targets( - mask, frame_size, mask_targets + mask, frame_size, mask_targets, panoptic ) + if panoptic: + instance_mask = mask.copy() # pylint: disable=not-an-iterable - for polyline in self.polylines: + for idx, polyline in enumerate(self.polylines): if labels_to_targets is not None: target = labels_to_targets.get(polyline.label, None) if target is None: @@ -945,6 +985,12 @@ def to_segmentation( _render_polyline(mask, polyline, target, thickness) + if panoptic: + _render_polyline(instance_mask, polyline, idx + 1, thickness) + + if panoptic: + mask = np.stack([mask, instance_mask], axis=-1) + - return Segmentation(mask=mask) + return Segmentation(mask=mask, is_panoptic=panoptic) @@ -1004,21 +1050,36 @@ class _HasMedia(object): class Segmentation(_HasID, _HasMedia, Label): - """A semantic
segmentation for an image. + """A segmentation for an image. Provide either the ``mask`` or ``mask_path`` argument to define the segmentation. + Set ``is_panoptic`` to ``True`` for panoptic segmentations, which + encode both class and instance labels. Otherwise this object + represents a semantic segmentation, which only encodes class + labels. + + If ``is_panoptic`` is ``True``, ``mask`` must have two channels. + The first channel encodes the class, and the second the instance. + Grayscale (one channel) and RGB (three channel) masks are only + supported for semantic segmentation. + + ``mask_path`` may be a file in TIFF or BigTIFF format, to allow + enough bit depth to represent large numbers of objects. They must use + the ".tif" or ".tiff" extension. + Args: - mask (None): a numpy array with integer values encoding the semantic - labels + mask (None): a numpy array with integer values encoding the labels mask_path (None): the absolute path to the segmentation image on disk + """ _MEDIA_FIELD = "mask_path" mask = fof.ArrayField() mask_path = fof.StringField() + is_panoptic = fof.BooleanField() @property def has_mask(self): @@ -1035,7 +1096,7 @@ def get_mask(self): return self.mask if self.mask_path is not None: - return _read_mask(self.mask_path) + return _read_mask(self.mask_path, self.is_panoptic) return None @@ -1049,8 +1110,7 @@ def import_mask(self, update=False): attribute after importing """ if self.mask_path is not None: - self.mask = _read_mask(self.mask_path) - + self.mask = _read_mask(self.mask_path, self.is_panoptic) if update: self.mask_path = None @@ -1083,6 +1143,9 @@ def transform_mask(self, targets_map, outpath=None, update=False): Note that any pixel values not in ``targets_map`` will be zero in the transformed mask. + If this is a panoptic segmentation, only the class labels will + be transformed. + Args: targets_map: a dict mapping existing pixel values (2D masks) or RGB hex strings (3D masks) to new pixel values or RGB hex strings. 
@@ -1094,12 +1157,16 @@ def transform_mask(self, targets_map, outpath=None, update=False): Returns: the transformed mask + """ mask = self.get_mask() if mask is None: return - mask = _transform_mask(mask, targets_map) + if self.is_panoptic: + mask[:, :, 0] = _transform_mask(mask[:, :, 0], targets_map) + else: + mask = _transform_mask(mask, targets_map) if outpath is not None: _write_mask(mask, outpath) @@ -1115,7 +1182,11 @@ def transform_mask(self, targets_map, outpath=None, update=False): return mask - def to_detections(self, mask_targets=None, mask_types="stuff"): + def to_detections( + self, + mask_targets=None, + mask_types=None, + ): """Returns a :class:`Detections` representation of this instance with instance masks populated. @@ -1125,6 +1196,11 @@ def to_detections(self, mask_targets=None, mask_types="stuff"): Each ``"thing"`` class will result in one :class:`Detection` instance per connected region of that class in the segmentation. + For panoptic segmentations, instances with a "thing" class + each get a separate detection. Any pixels with that class but + without an instance label are processed according to connected + regions. + Args: mask_targets (None): a dict mapping integer pixel values (2D masks) or RGB hex strings (3D masks) to label strings defining which @@ -1135,14 +1211,28 @@ def to_detections(self, mask_targets=None, mask_types="stuff"): regions, each representing an instance of the thing). Can be any of the following: - - ``"stuff"`` if all classes are stuff classes - - ``"thing"`` if all classes are thing classes + - ``"stuff"`` if all classes are stuff + classes. Ignores panoptic instances. + - ``"thing"`` if all classes are thing classes. Panoptic + instances are handled first, then any remaining labels + are separated into connected components. + - ``"panoptic"`` if all panoptic instances are thing classes + and all remaining are stuff classes. Panoptic only.
+ - ``"object"``: only keep panoptic object instances - a dict mapping pixel values (2D masks) or RGB hex strings - (3D masks) to ``"stuff"`` or ``"thing"`` for each class + (3D masks) to ``"stuff"``, ``"thing"``, ``"panoptic"``, + or ``"object"`` for each class + + Defaults to ``"stuff"`` for semantic segmentations and + ``"panoptic"`` for panoptic segmentations. Returns: a :class:`Detections` + """ + if mask_types is None: + mask_types = "panoptic" if self.is_panoptic else "stuff" + detections = _segmentation_to_detections( self, mask_targets, mask_types ) @@ -1151,7 +1241,7 @@ def to_detections(self, mask_targets=None, mask_types="stuff"): def to_polylines( self, mask_targets=None, - mask_types="stuff", + mask_types=None, tolerance=2, ): """Returns a :class:`Polylines` representation of this instance. @@ -1167,26 +1257,92 @@ def to_polylines( or RGB hex strings (3D masks) to label strings defining which classes to generate detections for. If omitted, all labels are assigned to their pixel values - mask_types ("stuff"): whether the classes are ``"stuff"`` + mask_types (None): whether the classes are ``"stuff"`` (amorphous regions of pixels) or ``"thing"`` (connected regions, each representing an instance of the thing). Can be any of the following: - ``"stuff"`` if all classes are stuff classes - - ``"thing"`` if all classes are thing classes + - ``"thing"`` if all classes are thing classes. Panoptic + instances are handled first, then any remaining labels + are separated into connected components. + - ``"panoptic"`` if all instances are thing classes and all + remaining are stuff classes. + - ``"object"``: only keep panoptic object instances - a dict mapping pixel values (2D masks) or RGB hex strings - (3D masks) to ``"stuff"`` or ``"thing"`` for each class + (3D masks) to ``"stuff"``, ``"thing"``, ``"object"``, + or ``"panoptic"`` for each class + + Defaults to ``"stuff"`` for semantic segmentations and + ``"panoptic"`` for panoptic segmentations. 
+ tolerance (2): a tolerance, in pixels, when generating approximate polylines for each region. Typical values are 1-3 pixels Returns: a :class:`Polylines` + """ + if mask_types is None: + mask_types = "panoptic" if self.is_panoptic else "stuff" + polylines = _segmentation_to_polylines( self, mask_targets, mask_types, tolerance ) return Polylines(polylines=polylines) + def to_semantic(self, to_rgb=False): + """Convert panoptic segmentation to semantic. + + Optionally converts the class integer values back to 3D pixel + values. + + Args: + to_rgb (False): convert integer classes to RGB hex values + + Returns: + a :class:`Segmentation` + + """ + mask = self.get_mask() + if self.is_panoptic: + mask = mask[..., 0] + if to_rgb: + mask = _int_array_to_rgb(mask) + + return Segmentation(mask=mask, is_panoptic=False) + + def to_panoptic(self, mask_types="thing"): + """Convert semantic segmentation to panoptic. + + Each ``"stuff"`` class will be kept as-is. + + Each ``"thing"`` class will result in one instance per + connected region of that class. + + If the mask is 3D, its pixel values will be converted to the + integer value of the equivalent hex string. + + Args: + mask_types ("thing"): whether the classes are ``"stuff"`` + (amorphous regions of pixels) or ``"thing"`` (connected + regions, each representing an instance of the thing). Can be + any of the following: + + - ``"stuff"`` if all classes are stuff classes + - ``"thing"`` if all classes are thing classes + - a dict mapping pixel values (2D masks) or RGB hex strings + (3D masks) to ``"stuff"`` or ``"thing"`` for each class + + Returns: + a :class:`Segmentation` + + """ + mask = self.get_mask() + if not self.is_panoptic: + mask = _mask_to_panoptic(mask, mask_types) + return Segmentation(mask=mask, is_panoptic=True) + class Heatmap(_HasID, _HasMedia, Label): """A heatmap for an image.
@@ -1480,14 +1636,85 @@ def from_geo_json(cls, d): } -def _read_mask(mask_path): +# NOTE: returns the smallest unsigned dtype that can represent maxval +def _get_uint_dtype(maxval): + if maxval < 2**8: + return np.uint8 + elif maxval < 2**16: + return np.uint16 + elif maxval < 2**32: + return np.uint32 + elif maxval < 2**64: + return np.uint64 + + raise ValueError(f"max value of {maxval} exceeds upper limit of 2^64") + + +def _mask_to_image(mask): + shape_is_valid = False + if mask.ndim == 2: + shape_is_valid = True + elif mask.ndim == 3: + if mask.shape[-1] in (1, 2, 3): + shape_is_valid = True + + if not shape_is_valid: + raise ValueError(f"unsupported detection mask shape: {mask.shape}") + + allowed_types = (np.uint8, np.uint16, np.uint32, np.uint64) + + if mask.dtype not in allowed_types: + raise ValueError(f"unsupported detection mask dtype: {mask.dtype}") + + # cast to smallest dtype possible + maxval = mask.max() + if mask.ndim == 3 and mask.shape[-1] == 3 and maxval >= 2**16: + raise ValueError( + f"3-channel masks must be saved as a PNG, with a max bit" + f" depth of 16, but mask has max value of {maxval} >= 2**16" + ) + + dtype = _get_uint_dtype(maxval) + return mask.astype(dtype) + + +def _get_extension(path): + extension = pathlib.Path(path).suffix + return extension + + +def _read_mask(mask_path, panoptic=False): # pylint: disable=no-member - return foui.read(mask_path, flag=cv2.IMREAD_UNCHANGED) + extension = _get_extension(mask_path) + if extension in (".tif", ".tiff"): + mask = tifffile.imread(mask_path) + else: + mask = foui.read(mask_path, flag=cv2.IMREAD_UNCHANGED) + + # do this here even though we're not writing a mask because it + # converts to smallest possible dtype, and also checks type + # and bounds + mask = _mask_to_image(mask) + + if panoptic and mask.ndim == 3 and mask.shape[-1] == 3: + mask = mask[..., 0:2] + + return mask def _write_mask(mask, mask_path): + extension = _get_extension(mask_path) mask = _mask_to_image(mask) - foui.write(mask, mask_path) + if extension in (".tif", ".tiff"): + 
bigtiff = mask.dtype == np.uint64 + tifffile.imwrite(mask_path, mask, compression="zlib", bigtiff=bigtiff) + else: + if mask.ndim == 3 and mask.shape[-1] == 2: + # add empty third channel + mask = np.dstack( + (mask, np.zeros(mask.shape[:2], dtype=mask.dtype)) + ) + foui.write(mask, mask_path) def _transform_mask(in_mask, targets_map): @@ -1531,18 +1758,6 @@ def _transform_mask(in_mask, targets_map): return out_mask -def _mask_to_image(mask): - if mask.dtype in (np.uint8, np.uint16): - return mask - - # Masks should contain integer values, so cast to the closest suitable - # unsigned type - if mask.max() <= 255: - return mask.astype(np.uint8) - - return mask.astype(np.uint16) - - def _read_heatmap(map_path): # pylint: disable=no-member return foui.read(map_path, flag=cv2.IMREAD_UNCHANGED) @@ -1572,12 +1787,15 @@ def _heatmap_to_image(map, range): return map.astype(np.uint8) -def _parse_segmentation_target(mask, frame_size, target): +def _parse_segmentation_target(mask, frame_size, target, panoptic): if target is not None: is_rgb = fof.is_rgb_target(target) else: is_rgb = False + if panoptic and is_rgb: + raise ValueError("panoptic segmentation cannot have RGB mask") + if mask is None: if frame_size is None: raise ValueError("Either `mask` or `frame_size` must be provided") @@ -1599,12 +1817,15 @@ def _parse_segmentation_target(mask, frame_size, target): return mask, target -def _parse_segmentation_mask_targets(mask, frame_size, mask_targets): +def _parse_segmentation_mask_targets(mask, frame_size, mask_targets, panoptic): if mask_targets is not None: is_rgb = fof.is_rgb_mask_targets(mask_targets) else: is_rgb = False + if panoptic and is_rgb: + raise ValueError("panoptic segmentation cannot have RGB mask") + if mask is None: if frame_size is None: raise ValueError("Either `mask` or `frame_size` must be provided") @@ -1676,6 +1897,80 @@ def _find_slices(mask): return dict((backward[idx + 1], slc) for idx, slc in enumerate(slices)) +def _get_label_and_type(target, 
mask_targets, mask_types, default): + if mask_targets is not None: + label = mask_targets.get(target, None) + else: + label = str(target) + + label_type = mask_types.get(target, None) + if label_type is None: + label_type = default + + return label, label_type + + +def _mask_is_rgb(mask): + return mask.ndim == 3 and mask.shape[-1] == 3 + + +def _mask_to_panoptic(mask, mask_types): + """Convert semantic segmentation to panoptic segmentation.""" + if isinstance(mask_types, dict): + default = None + else: + default = mask_types + mask_types = {} + + mask = mask.squeeze() + if _mask_is_rgb(mask): + mask = _rgb_array_to_int(mask) + if mask_types is not None: + mask_types = {_hex_to_int(k): v for k, v in mask_types.items()} + + panoptic_class_mask = np.zeros(mask.shape[:2], dtype=np.uint64) + panoptic_instance_mask = np.zeros(mask.shape[:2], dtype=np.uint64) + + # ensure instance indices are unique across multiple classes + total_instances = 0 + + class_objects = _find_slices(mask) + for class_target, class_slices in class_objects.items(): + label, label_type = _get_label_and_type( + class_target, None, mask_types, default + ) + if label is None or label_type is None: + continue # skip unknown class_target + + class_mask = mask[class_slices] == class_target + panoptic_class_mask[class_slices][class_mask] = class_target + + if label_type == "stuff": + pass + elif label_type == "thing": + labeled = skm.label(class_mask) + instance_objects = _find_slices(labeled) + + for instance_target, instance_slices in instance_objects.items(): + instance_mask = labeled[instance_slices] == instance_target + panoptic_instance_mask[class_slices][instance_slices][ + instance_mask + ] = (total_instances + instance_target) + + total_instances += len(instance_objects) + + else: + raise ValueError( + f"Unsupported mask type '{label_type}'. 
Supported values are " + "('stuff', 'thing')" + ) + + panoptic_mask = np.stack( + [panoptic_class_mask, panoptic_instance_mask], axis=-1 + ) + return panoptic_mask + + def _convert_segmentation(segmentation, mask_targets, mask_types, converter): """Convert segmentation to a collection of detections, polylines, etc. @@ -1691,41 +1986,86 @@ def _convert_segmentation(segmentation, mask_targets, mask_types, converter): mask_types = {} mask = segmentation.get_mask() - is_rgb = mask.ndim == 3 + mask = mask.squeeze() - if is_rgb: + if _mask_is_rgb(mask): # convert to int, like in transform_mask mask = _rgb_array_to_int(mask) if mask_targets is not None: mask_targets = {_hex_to_int(k): v for k, v in mask_targets.items()} - mask = mask.squeeze() + results = [] + if segmentation.is_panoptic: + # handle all individual instances, zeroing them out as they + # are added. any remaining segmentation pixels are handled + # like semantic segmentations. + + if mask.ndim != 3 or mask.shape[-1] != 2: + raise ValueError(f"Unsupported panoptic mask shape: {mask.ndim}") + + instances_mask = mask[:, :, 1] + mask = mask[:, :, 0].copy() # class mask. copy so we can modify it. 
+ + instance_objects = _find_slices(instances_mask) + for idx, slices in instance_objects.items(): + class_targets = list( + set(mask[slices][instances_mask[slices] == idx]) + ) + if len(class_targets) != 1: + raise ValueError( + f"instance {idx} has multiple segmentation classes" + ) + class_target = class_targets[0] + if class_target == 0: + raise ValueError(f"instance {idx} has no segmentation class") + + label, label_type = _get_label_and_type( + class_target, mask_targets, mask_types, default + ) + if label is None or label_type is None: + continue # skip unknown target + + if label_type == "stuff": + continue # leave for semantic segmentation + + label_mask = instances_mask[slices] == idx + offset = list(s.start for s in slices)[::-1] + frame_size = mask.shape[:2][::-1] + new_results = converter( + label_mask, label, label_type, offset, frame_size + ) + results.extend(new_results) + + # zero out this object in semantic segmentation + mask[slices][label_mask] = 0 + + # continue with semantic segmentation of remaining labels if mask.ndim != 2: raise ValueError(f"Unsupported mask dimensions: {mask.ndim}") objects = _find_slices(mask) - results = [] for target, slices in objects.items(): - if mask_targets is not None: - label = mask_targets.get(target, None) - - if label is None: - continue # skip unknown target - else: - label = str(target) - - label_type = mask_types.get(target, None) + label, label_type = _get_label_and_type( + target, mask_targets, mask_types, default + ) + if label is None or label_type is None: + continue # skip unknown target - if label_type is None: - if default is None: - continue # skip unknown type + if ( + label_type in ("panoptic", "object") + and not segmentation.is_panoptic + ): + raise ValueError( + f"Unsupported label type for semantic segmentation: {label_type}" + ) - label_type = default + if label_type == "object": + # skip semantic segmentations for this class + continue label_mask = mask[slices] == target offset = 
list(s.start for s in slices)[::-1] frame_size = mask.shape[:2][::-1] - new_results = converter( label_mask, label, label_type, offset, frame_size ) @@ -1735,14 +2075,14 @@ def _convert_segmentation(segmentation, mask_targets, mask_types, converter): def _mask_to_detections(label_mask, label, label_type, offset, frame_size): - if label_type == "stuff": + if label_type in ("stuff", "panoptic", "object"): instances = [_parse_stuff_instance(label_mask, offset, frame_size)] elif label_type == "thing": instances = _parse_thing_instances(label_mask, offset, frame_size) else: raise ValueError( - "Unsupported mask type '%s'. Supported values are " - "('stuff', 'thing')" + f"Unsupported mask type '{label_type}'. Supported values are " + "('stuff', 'thing', 'panoptic', 'object')" ) return list( @@ -1761,14 +2101,14 @@ def _mask_to_polylines( frame_size=frame_size, ) - if label_type == "stuff": + if label_type in ("stuff", "panoptic", "object"): polygons = [polygons] elif label_type == "thing": polygons = [[p] for p in polygons] else: raise ValueError( - "Unsupported mask type '%s'. Supported values are " - "('stuff', 'thing')" + f"Unsupported mask type '{label_type}'. 
Supported values are " + "('stuff', 'thing', 'panoptic', 'object')" ) return list( diff --git a/setup.py b/setup.py index f5bc265a82..869a00919b 100644 --- a/setup.py +++ b/setup.py @@ -69,6 +69,7 @@ def get_version(): "starlette>=0.24.0", "strawberry-graphql", "tabulate", + "tifffile", "xmltodict", "universal-analytics-python3>=1.0.1,<2", "pydash", diff --git a/tests/unittests/label_tests.py b/tests/unittests/label_tests.py index de7ab8ed9a..b1da12269e 100644 --- a/tests/unittests/label_tests.py +++ b/tests/unittests/label_tests.py @@ -6,6 +6,8 @@ | """ import unittest +from pathlib import Path +from tempfile import TemporaryDirectory from bson import Binary, ObjectId import numpy as np @@ -19,6 +21,45 @@ from decorators import drop_datasets +def _make_panoptic(dtype=np.uint8): + max_value = np.iinfo(dtype).max + + instance_mask = np.zeros((8, 8), dtype=dtype) + instance_mask[1:2, 1:2] = 1 + instance_mask[2:3, 2:3] = 2 + instance_mask[3:4, 3:4] = 3 + class_mask = (instance_mask > 0).astype(dtype) + + class_mask[4:5, 4:5] = 1 + class_mask[5:6, 5:6] = max_value + class_mask[6:7, 6:7] = 1 + + panoptic_mask = np.stack([class_mask, instance_mask], axis=-1).astype( + dtype + ) + seg = fo.Segmentation(mask=panoptic_mask, is_panoptic=True) + + return seg + + +def _make_1d_segmentation(dtype=np.uint8): + max_value = np.iinfo(dtype).max + mask = np.zeros((4, 4), dtype=dtype) + mask[0:2, 0:2] = 1 + mask[2:4, 2:4] = max_value + seg = fo.Segmentation(mask=mask, is_panoptic=False) + return seg + + +def _make_3d_segmentation(dtype=np.uint8): + max_value = np.iinfo(dtype).max + mask = np.zeros((4, 4, 3), dtype=dtype) + mask[0:2, 0:2, 2] = 1 + mask[2:4, 2:4, :] = max_value + seg = fo.Segmentation(mask=mask, is_panoptic=False) + return seg + + class LabelTests(unittest.TestCase): @drop_datasets def test_id(self): @@ -532,6 +573,153 @@ def test_transform_mask(self): rgb_to_rgb = focl._transform_mask(int_to_rgb, targets_map) nptest.assert_array_equal(rgb_to_rgb, np.zeros((3, 3, 3), 
dtype=int)) + @drop_datasets + def test_panoptic_segmentation_conversion(self): + seg = _make_panoptic() + frame_size = seg.mask.shape[:2][::-1] + mask_targets = dict( + (int(idx), str(idx)) for idx in seg.mask[..., 0].flatten() + ) + + for mask_types in ( + None, + "panoptic", + "stuff", + "thing", + "object", + ): + if mask_types is None: + n_expected = 5 + elif mask_types == "panoptic": + n_expected = 5 + elif mask_types == "stuff": + n_expected = 2 + elif mask_types == "thing": + n_expected = 6 + elif mask_types == "object": + n_expected = 3 + + dets = seg.to_detections(mask_types=mask_types) + self.assertEqual(len(dets.detections), n_expected) + + sseg1 = dets.to_segmentation( + panoptic=False, + frame_size=frame_size, + mask_targets=mask_targets, + ) + pseg1 = dets.to_segmentation( + panoptic=True, frame_size=frame_size, mask_targets=mask_targets + ) + # TODO: check masks here + + single_seg = dets.detections[0].to_segmentation( + panoptic=True, frame_size=frame_size + ) + + poly = seg.to_polylines(mask_types=mask_types, tolerance=0) + self.assertEqual(len(poly.polylines), n_expected) + + sseg2 = poly.to_segmentation( + panoptic=False, + frame_size=frame_size, + mask_targets=mask_targets, + ) + pseg2 = poly.to_segmentation( + panoptic=True, frame_size=frame_size, mask_targets=mask_targets + ) + # TODO: check masks here + + single_seg = poly.polylines[0].to_segmentation( + panoptic=True, frame_size=frame_size + ) + + def test_1d_segmentation_conversion(self): + # 1d to panoptic + seg = _make_1d_segmentation() + pseg = seg.to_panoptic() + + class_mask = seg.mask + instance_mask = np.array( + [[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 2, 2], [0, 0, 2, 2]], dtype=int + ) + + assert np.all(pseg.mask[..., 0] == seg.mask) + assert np.all(pseg.mask[..., 1] == instance_mask) + + # back to 1d semantic + seg2 = pseg.to_semantic() + assert np.all(seg2.mask == seg.mask) + + # check that this throws an error + with self.assertRaises(ValueError): + 
seg2.to_detections(mask_types="panoptic") + + # to rgb semantic + seg3 = pseg.to_semantic(to_rgb=True) + assert np.all(seg3.mask[..., 2] == seg.mask) + assert np.all(seg3.mask[..., 1] == 0) + assert np.all(seg3.mask[..., 0] == 0) + + def test_3d_segmentation_conversion(self): + # 3d to panoptic + seg = _make_3d_segmentation() + pseg = seg.to_panoptic() + + x = 2**24 - 1 + class_mask = np.array( + [[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, x, x], [0, 0, x, x]], dtype=int + ) + + instance_mask = np.array( + [[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 2, 2], [0, 0, 2, 2]], dtype=int + ) + + assert np.all(pseg.mask[..., 0] == class_mask) + assert np.all(pseg.mask[..., 1] == instance_mask) + + seg2 = pseg.to_semantic(to_rgb=False) + assert np.all(seg2.mask == class_mask) + + seg3 = pseg.to_semantic(to_rgb=True) + assert np.all(seg3.mask == seg.mask) + + def test_segmentation_io(self): + def _test_io(dims, tif, dtype): + with TemporaryDirectory() as temp_dir: + if tif: + mask_path = Path(temp_dir) / "mask.tif" + else: + mask_path = Path(temp_dir) / "mask.png" + + mask_path = str(mask_path) + + if dims == 1: + seg = _make_1d_segmentation(dtype=dtype) + if dims == 2: + seg = _make_panoptic(dtype=dtype) + if dims == 3: + seg = _make_3d_segmentation(dtype=dtype) + seg.export_mask(mask_path, update=False) + + seg2 = fo.Segmentation( + mask_path=mask_path, is_panoptic=(dims == 2) + ) + seg2.import_mask() + + assert np.all(seg.mask == seg2.mask) + + for dims in (1, 2, 3): + for tif in (False, True): + if dims == 3: + dtypes = [np.uint8] + elif tif: + dtypes = [np.uint8, np.uint16, np.uint32, np.uint64] + else: + dtypes = [np.uint8, np.uint16] + + for dtype in dtypes: + _test_io(dims=dims, tif=tif, dtype=dtype) + class LabelUtilsTests(unittest.TestCase): @drop_datasets