From 9b8c69330f5f0093e9c1a14bb239b8b555513b10 Mon Sep 17 00:00:00 2001 From: brimoor Date: Sat, 30 Nov 2024 22:35:24 -0600 Subject: [PATCH 1/5] add support for selecting/excluding group slices --- docs/source/user_guide/groups.rst | 61 +++++-- fiftyone/__public__.py | 1 + fiftyone/core/collections.py | 116 +++++++++++-- fiftyone/core/dataset.py | 11 +- fiftyone/core/stages.py | 260 ++++++++++++++++++++++++++++-- fiftyone/core/view.py | 64 +++++--- tests/unittests/group_tests.py | 133 ++++++++++++++- 7 files changed, 576 insertions(+), 70 deletions(-) diff --git a/docs/source/user_guide/groups.rst b/docs/source/user_guide/groups.rst index 2a4f927830..331a2f190e 100644 --- a/docs/source/user_guide/groups.rst +++ b/docs/source/user_guide/groups.rst @@ -815,8 +815,8 @@ Selecting slices You can use :meth:`select_group_slices() ` -to create *non-grouped views* that contain one or more slices of data from a -grouped dataset. +to select one or more slices of data from a grouped dataset, either as a +grouped view or as a flattened *non-grouped* view. For example, you can create an image view that contains only the left camera images from the grouped dataset: @@ -843,7 +843,7 @@ images from the grouped dataset: View stages: 1. SelectGroupSlices(slices='left') -or you could create an image collection containing the left and right camera +or you can create an image collection containing the left and right camera images: .. code-block:: python @@ -882,30 +882,59 @@ the fact that their data is sourced from a grouped dataset! # Add fields/tags, run evaluation, export, etc -Also note that any filtering that you apply prior to a +.. note:: + + Any filtering that you apply prior to a + :meth:`select_group_slices() ` + stage in a view is **not** automatically reflected by the output view, as + the stage looks up unfiltered slice data from the source collection: + + .. code-block:: python + + # Filter the active slice to locate groups of interest + match_view = dataset.filter_labels(...).match(...) + + # Lookup all image slices for the matching groups + # This view contains *unfiltered* image slices + images_view = match_view.select_group_slices(media_type="image") + + Instead, you can apply the same (or different) filtering *after* the + :meth:`select_group_slices() ` + stage: + + .. code-block:: python + + # Now apply filters to the flattened collection + match_images_view = images_view.filter_labels(...).match(...) + +Alternatively, you can pass `flat=False` to :meth:`select_group_slices() ` -stage in a view is **not** automatically reflected by the output view, as the -stage looks up unfiltered slice data from the source collection: +to create a grouped view that only contains certain group slices: .. code-block:: python :linenos: - # Filter the active slice to locate groups of interest - match_view = dataset.filter_labels(...).match(...) + no_center_view = dataset.select_group_slices(["left", "right"], flat=False) - # Lookup all image slices for the matching groups - # This view contains *unfiltered* image slices - images_view = match_view.select_group_slices(media_type="image") + assert no_center_view.media_type == "group" + assert no_center_view.group_slices == ["left", "right"] -Instead, you can apply the same (or different) filtering *after* the -:meth:`select_group_slices() ` -stage: +.. _groups-excluding-slices: + +Excluding slices +---------------- + +You can use +:meth:`exclude_group_slices() ` +to create a grouped view that excludes certain slice(s) of a grouped dataset: .. code-block:: python :linenos: - # Now apply filters to the flattened collection - match_images_view = images_view.filter_labels(...).match(...) + no_center_view = dataset.exclude_group_slices("center") + + assert no_center_view.media_type == "group" + assert no_center_view.group_slices == ["left", "right"] .. _groups-aggregations: diff --git a/fiftyone/__public__.py b/fiftyone/__public__.py index 1ccd9d15a5..4b6c107589 100644 --- a/fiftyone/__public__.py +++ b/fiftyone/__public__.py @@ -198,6 +198,7 @@ ExcludeFields, ExcludeFrames, ExcludeGroups, + ExcludeGroupSlices, ExcludeLabels, Exists, FilterField, diff --git a/fiftyone/core/collections.py b/fiftyone/core/collections.py index 2aafc32ee2..a657cc5c3a 100644 --- a/fiftyone/core/collections.py +++ b/fiftyone/core/collections.py @@ -4618,6 +4618,80 @@ def exclude_groups(self, group_ids): """ return self._add_view_stage(fos.ExcludeGroups(group_ids)) + @view_stage + def exclude_group_slices(self, slices=None, media_type=None): + """Excludes the specified group slice(s) from the grouped collection. + + Examples:: + + import fiftyone as fo + + dataset = fo.Dataset() + dataset.add_group_field("group", default="ego") + + group1 = fo.Group() + group2 = fo.Group() + + dataset.add_samples( + [ + fo.Sample( + filepath="/path/to/left-image1.jpg", + group=group1.element("left"), + ), + fo.Sample( + filepath="/path/to/video1.mp4", + group=group1.element("ego"), + ), + fo.Sample( + filepath="/path/to/right-image1.jpg", + group=group1.element("right"), + ), + fo.Sample( + filepath="/path/to/left-image2.jpg", + group=group2.element("left"), + ), + fo.Sample( + filepath="/path/to/video2.mp4", + group=group2.element("ego"), + ), + fo.Sample( + filepath="/path/to/right-image2.jpg", + group=group2.element("right"), + ), + ] + ) + + # + # Exclude the samples from the "ego" group slice + # + + view = dataset.exclude_group_slices("ego") + + # + # Exclude the samples from the "left" or "right" group slices + # + + view = dataset.exclude_group_slices(["left", "right"]) + + # + # Exclude all image slices + # + + view = dataset.exclude_group_slices(media_type="image") + + Args: + slices (None): a group slice or iterable of group slices to + exclude + media_type (None): a media type or iterable of media types whose + slice(s) to exclude + + Returns: + a :class:`fiftyone.core.view.DatasetView` + """ + return self._add_view_stage( + fos.ExcludeGroupSlices(slices=slices, media_type=media_type) + ) + @view_stage def exclude_labels( self, labels=None, ids=None, tags=None, fields=None, omit_empty=True @@ -6594,19 +6668,24 @@ def select_group_slices( self, slices=None, media_type=None, + flat=True, _allow_mixed=False, _force_mixed=False, ): - """Selects the samples in the group collection from the given slice(s). + """Selects the specified group slice(s) from the grouped collection. + + When ``flat==True``, the returned view is a flattened non-grouped view + containing the samples from the slice(s) of interest. - The returned view is a flattened non-grouped view containing only the - slice(s) of interest. + When ``flat=False``, the returned view is a grouped collection + containing only the slice(s) of interest. .. note:: - This stage performs a ``$lookup`` that pulls the requested slice(s) - for each sample in the input collection from the source dataset. - As a result, this stage always emits *unfiltered samples*. + When ``flat=True``, this stage performs a ``$lookup`` that pulls + the requested slice(s) for each sample in the input collection from + the source dataset. As a result, the stage emits + *unfiltered samples*. Examples:: @@ -6659,6 +6738,12 @@ def select_group_slices( view = dataset.select_group_slices(["left", "right"]) + # + # Select only the "left" and "right" group slices + # + + view = dataset.select_group_slices(["left", "right"], flat=False) + # # Retrieve all image samples # @@ -6669,7 +6754,10 @@ def select_group_slices( slices (None): a group slice or iterable of group slices to select. If neither argument is provided, a flattened list of all samples is returned - media_type (None): a media type whose slice(s) to select + media_type (None): a media type or iterable of media types whose + slice(s) to select + flat (True): whether to return a flattened collection (True) or a + grouped collection (False) Returns: a :class:`fiftyone.core.view.DatasetView` @@ -6678,6 +6766,7 @@ def select_group_slices( fos.SelectGroupSlices( slices=slices, media_type=media_type, + flat=flat, _allow_mixed=_allow_mixed, _force_mixed=_force_mixed, ) @@ -10542,24 +10631,23 @@ def _contains_media_type(self, media_type, any_slice=False): return True if self.media_type == fom.GROUP: - if self.group_media_types is None: + group_media_types = self.group_media_types + if group_media_types is None: return self._dataset.media_type == media_type if any_slice: return any( slice_media_type == media_type - for slice_media_type in self.group_media_types.values() + for slice_media_type in group_media_types.values() ) - return ( - self.group_media_types.get(self.group_slice, None) - == media_type - ) + return group_media_types.get(self.group_slice, None) == media_type if self.media_type == fom.MIXED: + group_media_types = self._get_group_media_types() return any( slice_media_type == media_type - for slice_media_type in self._get_group_media_types().values() + for slice_media_type in group_media_types.values() ) return False diff --git a/fiftyone/core/dataset.py b/fiftyone/core/dataset.py index 65e75cd910..dd20f51b5a 100644 --- a/fiftyone/core/dataset.py +++ b/fiftyone/core/dataset.py @@ -8484,7 +8484,6 @@ def _clone_collection(sample_collection, name, persistent): slug = _validate_dataset_name(name) contains_videos = sample_collection._contains_videos(any_slice=True) - contains_groups = sample_collection.media_type == fom.GROUP if isinstance(sample_collection, fov.DatasetView): dataset = sample_collection._dataset @@ -8492,6 +8491,9 @@ def _clone_collection(sample_collection, name, persistent): if view.media_type == fom.MIXED: raise ValueError("Cloning mixed views is not allowed") + + if view._is_dynamic_groups: + raise ValueError("Cloning dynamic grouped views is not allowed") else: dataset = sample_collection view = None @@ -8525,10 +8527,9 @@ def _clone_collection(sample_collection, name, persistent): dataset_doc.sample_collection_name = sample_collection_name dataset_doc.frame_collection_name = frame_collection_name dataset_doc.media_type = sample_collection.media_type - if not contains_groups: - dataset_doc.group_field = None - dataset_doc.group_media_types = {} - dataset_doc.default_group_slice = None + dataset_doc.group_field = sample_collection.group_field + dataset_doc.group_media_types = sample_collection.group_media_types + dataset_doc.default_group_slice = sample_collection.default_group_slice for field in dataset_doc.sample_fields: field._set_created_at(now) diff --git a/fiftyone/core/stages.py b/fiftyone/core/stages.py index a71272bf15..7b726fa1a0 100644 --- a/fiftyone/core/stages.py +++ b/fiftyone/core/stages.py @@ -98,6 +98,18 @@ def outputs_dynamic_groups(self): """ return None + @property + def flattens_groups(self): + """Whether this stage flattens groups into a non-grouped collection. + + The possible return values are: + + - ``True``: this stage *flattens* groups + - ``False``: this stage *does not flatten* groups + - ``None``: this stage does not change group status + """ + return None + def get_edited_fields(self, sample_collection, frames=False): """Returns a list of names of fields or embedded fields that may have been edited by the stage, if any. @@ -210,6 +222,22 @@ def get_group_expr(self, sample_collection): """ return None, None + def get_group_media_types(self, sample_collection): + """Returns the group media types outputted by this stage, if any, when + applied to the given collection, if and only if they are different from + the input collection. + + Args: + sample_collection: the + :class:`fiftyone.core.collections.SampleCollection` to which + the stage is being applied + + Returns: + a dict mapping slice names to media types, or ``None`` if the stage + does not change the types + """ + return None + def load_view(self, sample_collection): """Loads the :class:`fiftyone.core.view.DatasetView` containing the output of the stage. @@ -4550,16 +4578,19 @@ def _params(cls): class SelectGroupSlices(ViewStage): - """Selects the samples in a group collection from the given slice(s). + """Selects the specified group slice(s) from a grouped collection. + + When ``flat==True``, the returned view is a flattened non-grouped view + containing the samples from the slice(s) of interest. - The returned view is a flattened non-grouped view containing only the - slice(s) of interest. + When ``flat=False``, the returned view is a grouped collection containing + only the slice(s) of interest. .. note:: - This stage performs a ``$lookup`` that pulls the requested slice(s) for - each sample in the input collection from the source dataset. As a - result, this stage always emits *unfiltered samples*. + When ``flat=True``, this stage performs a ``$lookup`` that pulls the + requested slice(s) for each sample in the input collection from the + source dataset. As a result, the stage emits *unfiltered samples*. Examples:: @@ -4614,6 +4645,13 @@ class SelectGroupSlices(ViewStage): stage = fo.SelectGroupSlices(["left", "right"]) view = dataset.add_stage(stage) + # + # Select only the "left" and "right" group slices + # + + stage = fo.SelectGroupSlices(["left", "right"], flat=False) + view = dataset.add_stage(stage) + # # Retrieve all image samples # @@ -4625,18 +4663,23 @@ class SelectGroupSlices(ViewStage): slices (None): a group slice or iterable of group slices to select. If neither argument is provided, a flattened list of all samples is returned - media_type (None): a media type whose slice(s) to select + media_type (None): a media type or iterable of media types whose + slice(s) to select + flat (True): whether to return a flattened collection (True) or a + grouped collection (False) """ def __init__( self, slices=None, media_type=None, + flat=True, _allow_mixed=False, _force_mixed=False, ): self._slices = slices self._media_type = media_type + self._flat = flat self._allow_mixed = _allow_mixed self._force_mixed = _force_mixed @@ -4647,10 +4690,22 @@ def slices(self): @property def media_type(self): - """The media type whose slices to select.""" + """The media type(s) whose slices to select.""" return self._media_type + @property + def flat(self): + """Whether to generate a flattened collection.""" + return self._flat + + @property + def flattens_groups(self): + return self._flat + def to_mongo(self, sample_collection): + if not self._flat: + return [] + if isinstance(sample_collection, fod.Dataset) or ( isinstance(sample_collection, fov.DatasetView) and len(sample_collection._stages) == 0 @@ -4677,6 +4732,13 @@ def _make_pipeline(self, sample_collection): name_field = group_field + ".name" slices = self._get_slices(sample_collection) + + # No $lookup needed because active slice is the only one requested + if ( + etau.is_str(slices) and slices == sample_collection.group_slice + ) or (slices is None and len(sample_collection.group_slices) == 1): + return [] + expr = F(id_field) == "$$group_id" if isinstance(slices, list): expr &= F(name_field).is_in(slices) @@ -4697,7 +4759,6 @@ def _make_pipeline(self, sample_collection): {"$replaceRoot": {"newRoot": "$groups"}}, ] - # @note(SelectGroupSlices) # Must re-apply field selection/exclusion after the $lookup if isinstance(sample_collection, fov.DatasetView): selected_fields, excluded_fields = _get_selected_excluded_fields( @@ -4715,6 +4776,9 @@ def _make_pipeline(self, sample_collection): return pipeline def get_media_type(self, sample_collection): + if not self._flat: + return sample_collection.media_type + if self._force_mixed: return fom.MIXED @@ -4776,15 +4840,25 @@ def validate(self, sample_collection): def _get_slices(self, sample_collection): if self._media_type is not None: + if etau.is_str(self._media_type): + media_types = {self._media_type} + else: + media_types = set(self._media_type) + group_media_types = sample_collection.group_media_types slices = [ slice_name for slice_name, media_type in group_media_types.items() - if media_type == self._media_type + if media_type in media_types ] else: slices = self._slices + if slices is None: + group_slices = sample_collection.group_slices + if group_slices != sample_collection._dataset.group_slices: + slices = group_slices + if not etau.is_container(slices): return slices @@ -4795,7 +4869,7 @@ def _get_slices(self, sample_collection): return slices - def _get_group_media_types(self, sample_collection): + def get_group_media_types(self, sample_collection): group_media_types = sample_collection.group_media_types slices = self._get_slices(sample_collection) @@ -4816,6 +4890,7 @@ def _kwargs(self): return [ ["slices", self._slices], ["media_type", self._media_type], + ["flat", self._flat], ["_allow_mixed", self._allow_mixed], ["_force_mixed", self._force_mixed], ] @@ -4831,10 +4906,16 @@ def _params(cls): }, { "name": "media_type", - "type": "NoneType|str", + "type": "NoneType|list|str", "placeholder": "media_type (default=None)", "default": "None", }, + { + "name": "flat", + "type": "bool", + "default": "True", + "placeholder": "flat (default=True)", + }, { "name": "_allow_mixed", "type": "NoneType|bool", @@ -4848,6 +4929,158 @@ def _params(cls): ] +class ExcludeGroupSlices(ViewStage): + """Excludes the specified group slice(s) from a grouped collection. + + Examples:: + + import fiftyone as fo + + dataset = fo.Dataset() + dataset.add_group_field("group", default="ego") + + group1 = fo.Group() + group2 = fo.Group() + + dataset.add_samples( + [ + fo.Sample( + filepath="/path/to/left-image1.jpg", + group=group1.element("left"), + ), + fo.Sample( + filepath="/path/to/video1.mp4", + group=group1.element("ego"), + ), + fo.Sample( + filepath="/path/to/right-image1.jpg", + group=group1.element("right"), + ), + fo.Sample( + filepath="/path/to/left-image2.jpg", + group=group2.element("left"), + ), + fo.Sample( + filepath="/path/to/video2.mp4", + group=group2.element("ego"), + ), + fo.Sample( + filepath="/path/to/right-image2.jpg", + group=group2.element("right"), + ), + ] + ) + + # + # Exclude the "ego" group slice + # + + stage = fo.ExcludeGroupSlices("ego") + view = dataset.add_stage(stage) + + # + # Exclude the "left" and "right" group slices + # + + stage = fo.ExcludeGroupSlices(["left", "right"]) + view = dataset.add_stage(stage) + + # + # Exclude all image slices + # + + stage = fo.ExcludeGroupSlices(media_type="image") + view = dataset.add_stage(stage) + + Args: + slices (None): a group slice or iterable of group slices to exclude. + media_type (None): a media type or iterable of media types whose + slice(s) to exclude + """ + + def __init__(self, slices=None, media_type=None): + self._slices = slices + self._media_type = media_type + + @property + def slices(self): + """The group slice(s) to exclude.""" + return self._slices + + @property + def media_type(self): + """The media type(s) whose slices to exclude.""" + return self._media_type + + def to_mongo(self, sample_collection): + return [] + + def validate(self, sample_collection): + if sample_collection.media_type != fom.GROUP: + raise ValueError("%s has no groups" % type(sample_collection)) + + def _get_slices(self, sample_collection): + if self._media_type is not None: + if etau.is_str(self._media_type): + media_types = {self._media_type} + else: + media_types = set(self._media_type) + + group_media_types = sample_collection.group_media_types + return [ + slice_name + for slice_name, media_type in group_media_types.items() + if media_type not in media_types + ] + + if self._slices is not None: + if etau.is_str(self._slices): + slices = {self._slices} + else: + slices = set(self._slices) + + return [ + slice_name + for slice_name in sample_collection.group_slices + if slice_name not in slices + ] + + return sample_collection.group_slices + + def get_group_media_types(self, sample_collection): + group_media_types = sample_collection.group_media_types + slices = set(self._get_slices(sample_collection)) + + return { + slice_name: media_type + for slice_name, media_type in group_media_types.items() + if slice_name in slices + } + + def _kwargs(self): + return [ + ["slices", self._slices], + ["media_type", self._media_type], + ] + + @classmethod + def _params(cls): + return [ + { + "name": "slices", + "type": "NoneType|list|str", + "placeholder": "slices (default=None)", + "default": "None", + }, + { + "name": "media_type", + "type": "NoneType|list|str", + "placeholder": "media_type (default=None)", + "default": "None", + }, + ] + + class MatchFrames(ViewStage): """Filters the frames in a video collection by the given filter. @@ -8611,6 +8844,7 @@ def repr_ViewExpression(self, expr, level): ExcludeFields, ExcludeFrames, ExcludeGroups, + ExcludeGroupSlices, ExcludeLabels, Exists, FilterField, @@ -8658,6 +8892,7 @@ def repr_ViewExpression(self, expr, level): # View stages that only select documents Exclude, ExcludeBy, + ExcludeGroupSlices, Exists, GeoNear, GeoWithin, @@ -8674,5 +8909,6 @@ def repr_ViewExpression(self, expr, level): # Registry of select stages that should select first _STAGES_THAT_SELECT_FIRST = { + ExcludeGroupSlices, SelectGroupSlices, } diff --git a/fiftyone/core/view.py b/fiftyone/core/view.py index a7ba59c983..cb277ad0b5 100644 --- a/fiftyone/core/view.py +++ b/fiftyone/core/view.py @@ -162,7 +162,7 @@ def _has_slices(self): return False for stage in self._stages: - if isinstance(stage, fost.SelectGroupSlices): + if stage.flattens_groups: return False return True @@ -244,7 +244,7 @@ def group_slice(self): return None if self.__group_slice is not None: - return self.__group_slice + return self.__group_slice or None return self._dataset.group_slice @@ -271,7 +271,7 @@ def group_slices(self): if not self._has_slices: return None - return self._dataset.group_slices + return list(self._get_group_media_types().keys()) @property def group_media_types(self): @@ -281,7 +281,7 @@ def group_media_types(self): if not self._has_slices: return None - return self._dataset.group_media_types + return self._get_group_media_types() @property def default_group_slice(self): @@ -291,6 +291,9 @@ def default_group_slice(self): if not self._has_slices: return None + if self._dataset.default_group_slice not in self.group_slices: + return self.group_slice + return self._dataset.default_group_slice @property @@ -1557,7 +1560,7 @@ def _pipeline( _view = self._base_view _contains_videos = self._dataset._contains_videos(any_slice=True) - _found_select_group_slice = False + _found_flattened_videos = False _attach_frames_idx = None _attach_frames_idx0 = None _attach_frames_idx1 = None @@ -1573,10 +1576,12 @@ def _pipeline( idx = 0 for stage in self._stages: - if isinstance(stage, fost.SelectGroupSlices): - # We might need to reattach frames after `SelectGroupSlices`, - # since it involves a `$lookup` that resets the samples - _found_select_group_slice = True + _pipeline = stage.to_mongo(_view) + + if stage.flattens_groups and _contains_videos and _pipeline: + # We might need to reattach frames after flattening groups + # since this involves a `$lookup` that resets the samples + _found_flattened_videos = True _attach_frames_idx0 = _attach_frames_idx _attach_frames_idx = None @@ -1604,15 +1609,12 @@ def _pipeline( _group_slices.update(_stage_group_slices) - _pipeline = stage.to_mongo(_view) - - # @note(SelectGroupSlices) - # Special case: when selecting group slices of a video dataset that + # Special case: when flattening group slices of a video dataset that # modifies the dataset's schema, frame lookups must be injected in # the middle of the stage's pipeline, after the group slice $lookup # but *before* the $project stage(s) that reapply schema changes if ( - isinstance(stage, fost.SelectGroupSlices) + stage.flattens_groups and _contains_videos and _pipeline and "$project" in _pipeline[-1] @@ -1646,8 +1648,8 @@ def _pipeline( attach_frames = True _pipeline = self._dataset._attach_frames_pipeline(support=support) _pipelines.insert(_attach_frames_idx, _pipeline) - elif _found_select_group_slice and _attach_frames_idx is not None: - # Must manually attach frames after the group selection + elif _found_flattened_videos and _attach_frames_idx is not None: + # Must manually attach frames after the group $lookup attach_frames = None # special syntax: frames already attached _pipeline = self._dataset._attach_frames_pipeline(support=support) _pipelines.insert(_attach_frames_idx, _pipeline) @@ -1699,7 +1701,7 @@ def _pipeline( media_type = self.media_type if group_slice is None and self._dataset.media_type == fom.GROUP: - group_slice = self.__group_slice or self._dataset.group_slice + group_slice = self.__group_slice return self._dataset._pipeline( pipeline=_pipeline, @@ -1812,6 +1814,15 @@ def _add_view_stage(self, stage, validate=True): if media_type is not None: view._set_media_type(media_type) + group_media_types = stage.get_group_media_types(self) + if ( + group_media_types is not None + and view.media_type == fom.GROUP + and view.group_slice not in group_media_types + ): + group_slice = next(iter(group_media_types.keys()), "") + view._set_group_slice(group_slice) + view._set_name(None) return view @@ -1819,6 +1830,9 @@ def _add_view_stage(self, stage, validate=True): def _set_media_type(self, media_type): self.__media_type = media_type + def _set_group_slice(self, slice_name): + self.__group_slice = slice_name + def _set_name(self, name): self.__name = name @@ -1922,11 +1936,17 @@ def _get_missing_fields(self, frames=False): return missing_fields def _get_group_media_types(self): - for stage in reversed(self._stages): - if isinstance(stage, fost.SelectGroupSlices): - return stage._get_group_media_types(self._dataset) + group_media_types = self._dataset.group_media_types + + _view = self._base_view + for stage in self._stages: + gmt = stage.get_group_media_types(_view) + if gmt is not None: + group_media_types = gmt + + _view = _view._add_view_stage(stage, validate=False) - return self._dataset.group_media_types + return group_media_types def make_optimized_select_view( @@ -1953,7 +1973,7 @@ def make_optimized_select_view( match the order of the provided IDs groups (False): whether the IDs are group IDs, not sample IDs flatten (False): whether to flatten group datasets before selecting - sample ids + sample IDs Returns: a :class:`DatasetView` diff --git a/tests/unittests/group_tests.py b/tests/unittests/group_tests.py index fbc66af2ba..b3163fe7a2 100644 --- a/tests/unittests/group_tests.py +++ b/tests/unittests/group_tests.py @@ -588,6 +588,18 @@ def test_field_schemas(self): view = dataset.select_fields() + # Selecting active slice maintains schema changes + video_view = view.select_group_slices("ego") + + self.assertEqual(view.group_slice, "ego") + self.assertEqual(video_view.media_type, "video") + self.assertNotIn("field", video_view.get_field_schema()) + self.assertNotIn("field", video_view.get_frame_field_schema()) + for sample in video_view: + self.assertFalse(sample.has_field("field")) + for frame in sample.frames.values(): + self.assertFalse(frame.has_field("field")) + # Cloning a grouped dataset maintains schema changes group_dataset = view.clone() @@ -611,7 +623,6 @@ def test_field_schemas(self): for sample in image_dataset: self.assertFalse(sample.has_field("field")) - # @note(SelectGroupSlices) # Selecting group slices maintains frame schema changes video_view = view.select_group_slices(media_type="video") @@ -631,6 +642,126 @@ def test_field_schemas(self): for frame in sample.frames.values(): self.assertFalse(frame.has_field("field")) + @drop_datasets + def test_select_exclude_slices(self): + dataset = _make_group_dataset() + + # Select slices by name + view = dataset.select_group_slices(["left", "right"], flat=False) + + self.assertEqual(len(view), 2) + self.assertEqual(view.media_type, "group") + self.assertSetEqual(set(view.group_slices), {"left", "right"}) + self.assertDictEqual( + view.group_media_types, {"left": "image", "right": "image"} + ) + self.assertIn(view.group_slice, ["left", "right"]) + self.assertIn(view.default_group_slice, ["left", "right"]) + + # Select slices by media type + view = dataset.select_group_slices(media_type="image", flat=False) + + self.assertEqual(len(view), 2) + self.assertEqual(view.media_type, "group") + self.assertSetEqual(set(view.group_slices), {"left", "right"}) + self.assertDictEqual( + view.group_media_types, {"left": "image", "right": "image"} + ) + self.assertIn(view.group_slice, ["left", "right"]) + self.assertIn(view.default_group_slice, ["left", "right"]) + + # Exclude slices by name + view = dataset.exclude_group_slices("ego") + + self.assertEqual(len(view), 2) + self.assertEqual(view.media_type, "group") + self.assertSetEqual(set(view.group_slices), {"left", "right"}) + self.assertDictEqual( + view.group_media_types, {"left": "image", "right": "image"} + ) + self.assertIn(view.group_slice, ["left", "right"]) + self.assertIn(view.default_group_slice, ["left", "right"]) + + # Exclude slices by media type + view = dataset.exclude_group_slices(media_type="video") + + self.assertEqual(len(view), 2) + self.assertEqual(view.media_type, "group") + self.assertSetEqual(set(view.group_slices), {"left", "right"}) + self.assertDictEqual( + view.group_media_types, {"left": "image", "right": "image"} + ) + self.assertIn(view.group_slice, ["left", "right"]) + self.assertIn(view.default_group_slice, ["left", "right"]) + + # Empty grouped view + view = dataset.select_group_slices( + ["left", "right"], flat=False + ).exclude_group_slices(media_type="image") + + self.assertEqual(len(view), 0) + self.assertEqual(view.media_type, "group") + self.assertListEqual(view.group_slices, []) + self.assertDictEqual(view.group_media_types, {}) + self.assertIsNone(view.group_slice) + self.assertIsNone(view.default_group_slice) + + # Empty grouped view clone + dataset2 = view.clone() + + self.assertEqual(len(dataset2), 0) + self.assertEqual(dataset2.media_type, "group") + self.assertListEqual(dataset2.group_slices, []) + self.assertDictEqual(dataset2.group_media_types, {}) + self.assertIsNone(dataset2.group_slice) + self.assertIsNone(dataset2.default_group_slice) + + # Select group slices with filtered schema + view = dataset.select_fields().select_group_slices( + media_type="video", flat=False + ) + + self.assertEqual(len(view), 2) + self.assertEqual(view.media_type, "group") + self.assertListEqual(view.group_slices, ["ego"]) + self.assertDictEqual(view.group_media_types, {"ego": "video"}) + self.assertEqual(view.group_slice, "ego") + self.assertEqual(view.default_group_slice, "ego") + + schema = view.get_field_schema() + frame_schema = view.get_frame_field_schema() + + self.assertNotIn("field", schema) + self.assertNotIn("field", frame_schema) + + sample_view = view.first() + frame_view = sample_view.frames.first() + + self.assertFalse(sample_view.has_field("field")) + self.assertFalse(frame_view.has_field("field")) + + # Clone selected group slices with filtered schema + dataset2 = view.clone() + + self.assertEqual(len(dataset2), 2) + self.assertEqual(dataset2.media_type, "group") + self.assertListEqual(dataset2.group_slices, ["ego"]) + self.assertDictEqual(dataset2.group_media_types, {"ego": "video"}) + self.assertEqual(dataset2.group_slice, "ego") + self.assertEqual(dataset2.default_group_slice, "ego") + + schema = dataset2.get_field_schema() + frame_schema = dataset2.get_frame_field_schema() + + self.assertNotIn("field", schema) + self.assertNotIn("field", frame_schema) + + sample2 = dataset2.first() + frame2 = sample2.frames.first() + + self.assertFalse(sample2.has_field("field")) + self.assertFalse(frame2.has_field("field")) + @drop_datasets def test_attached_groups(self): dataset = _make_group_dataset() From 63a55e2c217a29df386df143a33ae29b13a42c13 Mon Sep 17 00:00:00 2001 From: Benjamin Kane Date: Fri, 24 Jan 2025 10:32:14 -0500 Subject: [PATCH 2/5] server group media types --- fiftyone/server/query.py | 13 +-- tests/unittests/server_dataset_tests.py | 86 +++++++++++++++++++ tests/unittests/server_group_tests.py | 1 + ...ing_tests.py => server_lightning_tests.py} | 0 tests/unittests/server_state_tests.py | 2 +- .../{server_tests.py => server_view_tests.py} | 2 +- 6 files changed, 97 insertions(+), 7 deletions(-) create mode 100644 tests/unittests/server_dataset_tests.py rename tests/unittests/{lightning_tests.py => server_lightning_tests.py} (100%) rename tests/unittests/{server_tests.py => server_view_tests.py} (99%) diff --git a/fiftyone/server/query.py b/fiftyone/server/query.py index 3960796c0a..da0873c189 100644 --- a/fiftyone/server/query.py +++ b/fiftyone/server/query.py @@ -320,10 +320,7 @@ def modifier(doc: dict) -> dict: dict(name=name, **data) for name, data in doc.get("skeletons", {}).items() ) - doc["group_media_types"] = [ - Group(name=name, media_type=media_type) - for name, media_type in doc.get("group_media_types", {}).items() - ] + doc["group_media_types"] = [] doc["default_skeletons"] = doc.get("default_skeletons", None) # gql private fields must always be present @@ -597,7 +594,9 @@ def run(): for stage in serialized_view: view = view.add_stage(fosg.ViewStage._from_dict(stage)) except: - view = fov.DatasetView._build(dataset, serialized_view or []) + view: fov.DatasetView = fov.DatasetView._build( + dataset, serialized_view or [] + ) doc = dataset._doc.to_dict(no_dereference=True) Dataset.modifier(doc) @@ -605,6 +604,10 @@ def run(): data.view_cls = None data.view_name = view_name data.saved_view_slug = saved_view_slug + data.group_media_types = [ + Group(name=name, media_type=media_type) + for name, media_type in view._get_group_media_types().items() + ] collection = dataset.view() if view is not None: diff --git a/tests/unittests/server_dataset_tests.py b/tests/unittests/server_dataset_tests.py new file mode 100644 index 0000000000..f530cb87a1 --- /dev/null +++ b/tests/unittests/server_dataset_tests.py @@ -0,0 +1,86 @@ +""" +FiftyOne server dataset tests. + +| Copyright 2017-2025, Voxel51, Inc. +| `voxel51.com `_ +| +""" + +import typing as t + +import unittest +import strawberry as gql +from strawberry.schema.config import StrawberryConfig + +import fiftyone as fo +import fiftyone.core.media as fom + +from fiftyone.server.constants import SCALAR_OVERRIDES +from fiftyone.server.scalars import BSONArray +from fiftyone.server.query import Dataset + +from decorators import drop_async_dataset +from utils.graphql import execute + + +@gql.type +class DatasetQuery: + dataset: Dataset = gql.field(resolver=Dataset.resolver) + + +schema = gql.Schema( + query=DatasetQuery, + scalar_overrides=SCALAR_OVERRIDES, + config=StrawberryConfig(auto_camel_case=False), +) + +MEDIA_TYPES = {media_type: media_type for media_type in fom.MEDIA_TYPES} +MEDIA_TYPES[fom.POINT_CLOUD] = "point_cloud" +MEDIA_TYPES[fom.THREE_D] = "three_d" + + +class TestDataset(unittest.IsolatedAsyncioTestCase): + @drop_async_dataset + async def test_group_media_types(self, dataset: fo.Dataset): + dataset.media_type = "group" + for media_type in MEDIA_TYPES: + dataset.add_group_slice(media_type, media_type) + + query = """ + query Query($name: String!, $view: BSONArray) { + dataset(name: $name, view: $view) { + group_media_types { + media_type + name + } + } + } + """ + + response = lambda media_type: { + "group_media_types": [ + {"media_type": MEDIA_TYPES[media_type], "name": media_type} + ] + } + asserter = lambda result, media_type: self.assertEqual( + result.data["dataset"], response(media_type) + ) + + for media_type in fom.MEDIA_TYPES: + view = dataset.select_group_slices(slices=media_type, flat=False) + result = await _execute( + query, dataset.name, view=view._serialize() + ) + asserter(result, media_type) + + view = dataset.select_group_slices( + media_type=media_type, flat=False + ) + result = await _execute( + query, dataset.name, view=view._serialize() + ) + asserter(result, media_type) + + +async def _execute(query: str, name: str, view: t.Optional[BSONArray] = None): + return await execute(schema, query, variables={"name": name, "view": view}) diff --git a/tests/unittests/server_group_tests.py b/tests/unittests/server_group_tests.py index b65b020799..7a7af2f5fa 100644 --- a/tests/unittests/server_group_tests.py +++ b/tests/unittests/server_group_tests.py @@ -12,6 +12,7 @@ import fiftyone as fo from fiftyone import ViewExpression as F + from fiftyone.server.aggregations import GroupElementFilter import fiftyone.server.view as fosv diff --git a/tests/unittests/lightning_tests.py b/tests/unittests/server_lightning_tests.py similarity index 100% rename from tests/unittests/lightning_tests.py rename to tests/unittests/server_lightning_tests.py diff --git a/tests/unittests/server_state_tests.py b/tests/unittests/server_state_tests.py index 2913184b99..c2fbbf2fc9 100644 --- a/tests/unittests/server_state_tests.py +++ b/tests/unittests/server_state_tests.py @@ -1,5 +1,5 @@ """ -FiftyOne server state tests. +FiftyOne Server state tests. | Copyright 2017-2025, Voxel51, Inc. | `voxel51.com `_ diff --git a/tests/unittests/server_tests.py b/tests/unittests/server_view_tests.py similarity index 99% rename from tests/unittests/server_tests.py rename to tests/unittests/server_view_tests.py index b09f80f290..11c6f35edd 100644 --- a/tests/unittests/server_tests.py +++ b/tests/unittests/server_view_tests.py @@ -1,5 +1,5 @@ """ -FiftyOne server-related unit tests. +FiftyOne Server view tests. | Copyright 2017-2025, Voxel51, Inc. | `voxel51.com `_ From a78eb7be8720e59c3702cc9d2db2819cc36cc416 Mon Sep 17 00:00:00 2001 From: Benjamin Kane Date: Fri, 24 Jan 2025 10:38:28 -0500 Subject: [PATCH 3/5] support exclude group slices --- fiftyone/server/view.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fiftyone/server/view.py b/fiftyone/server/view.py index 1ecda7797c..aa743c4b71 100644 --- a/fiftyone/server/view.py +++ b/fiftyone/server/view.py @@ -250,7 +250,8 @@ def handle_group_filter( group_field = dataset.group_field unselected = not any( - isinstance(stage, fosg.SelectGroupSlices) for stage in stages + isinstance(stage, (fosg.ExcludeGroupSlices, fosg.SelectGroupSlices)) + for stage in stages ) group_by = any(isinstance(stage, fosg.GroupBy) for stage in stages) From 98918196642a188c10fa4a1c5068449d52a0da60 Mon Sep 17 00:00:00 2001 From: Benjamin Kane Date: Fri, 24 Jan 2025 11:51:52 -0500 Subject: [PATCH 4/5] update handle group filter --- fiftyone/server/view.py | 16 ++++++++---- tests/unittests/server_group_tests.py | 37 +++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/fiftyone/server/view.py b/fiftyone/server/view.py index aa743c4b71..8dc37439cc 100644 --- a/fiftyone/server/view.py +++ b/fiftyone/server/view.py @@ -249,17 +249,18 @@ def handle_group_filter( stages = view._stages group_field = dataset.group_field - unselected = not any( - isinstance(stage, (fosg.ExcludeGroupSlices, fosg.SelectGroupSlices)) - for stage in stages - ) + selected = False + for stage in stages: + if isinstance(stage, fosg.SelectGroupSlices) and stage.flat: + selected = True + group_by = any(isinstance(stage, fosg.GroupBy) for stage in stages) view = dataset.view() if filter.slice: view.group_slice = filter.slice - if unselected and filter.slices: + if not selected and filter.slices: # flatten the collection if the view has no slice(s) selected view = dataset.select_group_slices(_force_mixed=True) @@ -276,6 +277,11 @@ def handle_group_filter( {group_field + ".name": {"$in": filter.slices}} ) + if isinstance( + stage, (fosg.SelectGroupSlices, fosg.ExcludeGroupSlices) + ): + continue + # if selecting a group, filter out select/reorder stages if ( not filter.id diff --git a/tests/unittests/server_group_tests.py b/tests/unittests/server_group_tests.py index 7a7af2f5fa..d319c73483 100644 --- a/tests/unittests/server_group_tests.py +++ b/tests/unittests/server_group_tests.py @@ -122,3 +122,40 @@ def test_group_selection(self): GroupElementFilter(id=group.id, slices=["one", "two"]), ) self.assertEqual(len(without_slices), 2) + + @drop_datasets + def test_slice_selection(self): + dataset: fo.Dataset = fo.Dataset() + dataset.media_type = "group" + dataset.add_group_slice("one", "image") + dataset.add_group_slice("two", "image") + dataset.add_group_slice("three", "image") + + group = fo.Group() + one = fo.Sample( + filepath="image.png", + group=group.element("one"), + ) + two = fo.Sample( + filepath="image.png", + group=group.element("two"), + ) + three = fo.Sample( + filepath="image.png", + group=group.element("three"), + ) + dataset.add_samples([one, two, three]) + + exclude_three, _ = fosv.handle_group_filter( + dataset, + dataset.exclude_group_slices("three"), + GroupElementFilter(id=group.id, slices=["one", "two"]), + ) + self.assertEqual(len(exclude_three), 2) + + select_one_two, _ = fosv.handle_group_filter( + dataset, + dataset.select_group_slices(("one", "two"), flat=False), + GroupElementFilter(id=group.id, slices=["one", "two"]), + ) + self.assertEqual(len(select_one_two), 2) From 1c5b695fc51cf3038d18a0b572bfd72658571bcd Mon Sep 17 00:00:00 2001 From: Benjamin Kane Date: Fri, 24 Jan 2025 12:09:14 -0500 Subject: [PATCH 5/5] handle none --- fiftyone/core/view.py | 1 + fiftyone/server/query.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/fiftyone/core/view.py b/fiftyone/core/view.py index d77b6ba175..438acbb8fe 100644 --- a/fiftyone/core/view.py +++ b/fiftyone/core/view.py @@ -5,6 +5,7 @@ | `voxel51.com `_ | """ + from collections import defaultdict, OrderedDict import contextlib from copy import copy, deepcopy diff --git a/fiftyone/server/query.py b/fiftyone/server/query.py index da0873c189..824f26d52b 100644 --- a/fiftyone/server/query.py +++ b/fiftyone/server/query.py @@ -604,9 +604,11 @@ def run(): data.view_cls = None data.view_name = view_name data.saved_view_slug = saved_view_slug + + group_media_types = view._get_group_media_types() or {} data.group_media_types = [ Group(name=name, media_type=media_type) - for name, media_type in view._get_group_media_types().items() + for name, media_type in group_media_types.items() ] collection = dataset.view()