From 73a775ba494aa10d05b3d81fc80e4b0f3c0e7cea Mon Sep 17 00:00:00 2001 From: brimoor Date: Sat, 30 Nov 2024 22:35:24 -0600 Subject: [PATCH] add support for selecting/excluding group slices --- docs/source/user_guide/groups.rst | 61 +++++-- fiftyone/__public__.py | 1 + fiftyone/core/collections.py | 116 +++++++++++-- fiftyone/core/dataset.py | 11 +- fiftyone/core/stages.py | 260 ++++++++++++++++++++++++++++-- fiftyone/core/view.py | 64 +++++--- tests/unittests/group_tests.py | 133 ++++++++++++++- 7 files changed, 576 insertions(+), 70 deletions(-) diff --git a/docs/source/user_guide/groups.rst b/docs/source/user_guide/groups.rst index 2a4f927830..331a2f190e 100644 --- a/docs/source/user_guide/groups.rst +++ b/docs/source/user_guide/groups.rst @@ -815,8 +815,8 @@ Selecting slices You can use :meth:`select_group_slices() ` -to create *non-grouped views* that contain one or more slices of data from a -grouped dataset. +to select one or more slices of data from a grouped dataset, either as a +grouped view or as a flattened *non-grouped* view. For example, you can create an image view that contains only the left camera images from the grouped dataset: @@ -843,7 +843,7 @@ images from the grouped dataset: View stages: 1. SelectGroupSlices(slices='left') -or you could create an image collection containing the left and right camera +or you can create an image collection containing the left and right camera images: .. code-block:: python @@ -882,30 +882,59 @@ the fact that their data is sourced from a grouped dataset! # Add fields/tags, run evaluation, export, etc -Also note that any filtering that you apply prior to a +.. note:: + + Any filtering that you apply prior to a + :meth:`select_group_slices() ` + stage in a view is **not** automatically reflected by the output view, as + the stage looks up unfiltered slice data from the source collection: + + .. code-block:: python + + # Filter the active slice to locate groups of interest + match_view = dataset.filter_labels(...).match(...) + + # Lookup all image slices for the matching groups + # This view contains *unfiltered* image slices + images_view = match_view.select_group_slices(media_type="image") + + Instead, you can apply the same (or different) filtering *after* the + :meth:`select_group_slices() ` + stage: + + .. code-block:: python + + # Now apply filters to the flattened collection + match_images_view = images_view.filter_labels(...).match(...) + +Alternatively, you can pass `flat=False` to :meth:`select_group_slices() ` -stage in a view is **not** automatically reflected by the output view, as the -stage looks up unfiltered slice data from the source collection: +to create a grouped view that only contains certain group slices: .. code-block:: python :linenos: - # Filter the active slice to locate groups of interest - match_view = dataset.filter_labels(...).match(...) + no_center_view = dataset.select_group_slices(["left", "right"], flat=False) - # Lookup all image slices for the matching groups - # This view contains *unfiltered* image slices - images_view = match_view.select_group_slices(media_type="image") + assert no_center_view.media_type == "group" + assert no_center_view.group_slices == ["left", "right"] -Instead, you can apply the same (or different) filtering *after* the -:meth:`select_group_slices() ` -stage: +.. _groups-excluding-slices: + +Excluding slices +---------------- + +You can use +:meth:`exclude_group_slices() ` +to create a grouped view that excludes certain slice(s) of a grouped dataset: .. code-block:: python :linenos: - # Now apply filters to the flattened collection - match_images_view = images_view.filter_labels(...).match(...) + no_center_view = dataset.exclude_group_slices("center") + + assert no_center_view.media_type == "group" + assert no_center_view.group_slices == ["left", "right"] .. _groups-aggregations: diff --git a/fiftyone/__public__.py b/fiftyone/__public__.py index f343bdf718..ff4625b9da 100644 --- a/fiftyone/__public__.py +++ b/fiftyone/__public__.py @@ -198,6 +198,7 @@ ExcludeFields, ExcludeFrames, ExcludeGroups, + ExcludeGroupSlices, ExcludeLabels, Exists, FilterField, diff --git a/fiftyone/core/collections.py b/fiftyone/core/collections.py index d53a25fce1..f92186265d 100644 --- a/fiftyone/core/collections.py +++ b/fiftyone/core/collections.py @@ -4618,6 +4618,80 @@ def exclude_groups(self, group_ids): """ return self._add_view_stage(fos.ExcludeGroups(group_ids)) + @view_stage + def exclude_group_slices(self, slices=None, media_type=None): + """Excludes the specified group slice(s) from the grouped collection. + + Examples:: + + import fiftyone as fo + + dataset = fo.Dataset() + dataset.add_group_field("group", default="ego") + + group1 = fo.Group() + group2 = fo.Group() + + dataset.add_samples( + [ + fo.Sample( + filepath="/path/to/left-image1.jpg", + group=group1.element("left"), + ), + fo.Sample( + filepath="/path/to/video1.mp4", + group=group1.element("ego"), + ), + fo.Sample( + filepath="/path/to/right-image1.jpg", + group=group1.element("right"), + ), + fo.Sample( + filepath="/path/to/left-image2.jpg", + group=group2.element("left"), + ), + fo.Sample( + filepath="/path/to/video2.mp4", + group=group2.element("ego"), + ), + fo.Sample( + filepath="/path/to/right-image2.jpg", + group=group2.element("right"), + ), + ] + ) + + # + # Exclude the samples from the "ego" group slice + # + + view = dataset.exclude_group_slices("ego") + + # + # Exclude the samples from the "left" or "right" group slices + # + + view = dataset.exclude_group_slices(["left", "right"]) + + # + # Exclude all image slices + # + + view = dataset.exclude_group_slices(media_type="image") + + Args: + slices (None): a group slice or iterable of group slices to + exclude + media_type (None): a media type or iterable of media types whose + slice(s) to exclude + + Returns: + a :class:`fiftyone.core.view.DatasetView` + """ + return self._add_view_stage( + fos.ExcludeGroupSlices(slices=slices, media_type=media_type) + ) + @view_stage def exclude_labels( self, labels=None, ids=None, tags=None, fields=None, omit_empty=True @@ -6594,19 +6668,24 @@ def select_group_slices( self, slices=None, media_type=None, + flat=True, _allow_mixed=False, _force_mixed=False, ): - """Selects the samples in the group collection from the given slice(s). + """Selects the specified group slice(s) from the grouped collection. + + When ``flat==True``, the returned view is a flattened non-grouped view + containing the samples from the slice(s) of interest. - The returned view is a flattened non-grouped view containing only the - slice(s) of interest. + When ``flat=False``, the returned view is a grouped collection + containing only the slice(s) of interest. .. note:: - This stage performs a ``$lookup`` that pulls the requested slice(s) - for each sample in the input collection from the source dataset. - As a result, this stage always emits *unfiltered samples*. + When ``flat=True``, this stage performs a ``$lookup`` that pulls + the requested slice(s) for each sample in the input collection from + the source dataset. As a result, the stage emits + *unfiltered samples*. Examples:: @@ -6659,6 +6738,12 @@ def select_group_slices( view = dataset.select_group_slices(["left", "right"]) + # + # Select only the "left" and "right" group slices + # + + view = dataset.select_group_slices(["left", "right"], flat=False) + # # Retrieve all image samples # @@ -6669,7 +6754,10 @@ def select_group_slices( slices (None): a group slice or iterable of group slices to select. If neither argument is provided, a flattened list of all samples is returned - media_type (None): a media type whose slice(s) to select + media_type (None): a media type or iterable of media types whose + slice(s) to select + flat (True): whether to return a flattened collection (True) or a + grouped collection (False) Returns: a :class:`fiftyone.core.view.DatasetView` @@ -6678,6 +6766,7 @@ def select_group_slices( fos.SelectGroupSlices( slices=slices, media_type=media_type, + flat=flat, _allow_mixed=_allow_mixed, _force_mixed=_force_mixed, ) @@ -10542,24 +10631,23 @@ def _contains_media_type(self, media_type, any_slice=False): return True if self.media_type == fom.GROUP: - if self.group_media_types is None: + group_media_types = self.group_media_types + if group_media_types is None: return self._dataset.media_type == media_type if any_slice: return any( slice_media_type == media_type - for slice_media_type in self.group_media_types.values() + for slice_media_type in group_media_types.values() ) - return ( - self.group_media_types.get(self.group_slice, None) - == media_type - ) + return group_media_types.get(self.group_slice, None) == media_type if self.media_type == fom.MIXED: + group_media_types = self._get_group_media_types() return any( slice_media_type == media_type - for slice_media_type in self._get_group_media_types().values() + for slice_media_type in group_media_types.values() ) return False diff --git a/fiftyone/core/dataset.py b/fiftyone/core/dataset.py index 9db20ee4c4..98d67b24fa 100644 --- a/fiftyone/core/dataset.py +++ b/fiftyone/core/dataset.py @@ -8508,7 +8508,6 @@ def _clone_collection(sample_collection, name, persistent): slug = _validate_dataset_name(name) contains_videos = sample_collection._contains_videos(any_slice=True) - contains_groups = sample_collection.media_type == fom.GROUP if isinstance(sample_collection, fov.DatasetView): dataset = sample_collection._dataset @@ -8516,6 +8515,9 @@ def _clone_collection(sample_collection, name, persistent): if view.media_type == fom.MIXED: raise ValueError("Cloning mixed views is not allowed") + + if view._is_dynamic_groups: + raise ValueError("Cloning dynamic grouped views is not allowed") else: dataset = sample_collection view = None @@ -8549,10 +8551,9 @@ def _clone_collection(sample_collection, name, persistent): dataset_doc.sample_collection_name = sample_collection_name dataset_doc.frame_collection_name = frame_collection_name dataset_doc.media_type = sample_collection.media_type - if not contains_groups: - dataset_doc.group_field = None - dataset_doc.group_media_types = {} - dataset_doc.default_group_slice = None + dataset_doc.group_field = sample_collection.group_field + dataset_doc.group_media_types = sample_collection.group_media_types + dataset_doc.default_group_slice = sample_collection.default_group_slice for field in dataset_doc.sample_fields: field._set_created_at(now) diff --git a/fiftyone/core/stages.py b/fiftyone/core/stages.py index eb29d5a942..ae016a0a5b 100644 --- a/fiftyone/core/stages.py +++ b/fiftyone/core/stages.py @@ -98,6 +98,18 @@ def outputs_dynamic_groups(self): """ return None + @property + def flattens_groups(self): + """Whether this stage flattens groups into a non-grouped collection. + + The possible return values are: + + - ``True``: this stage *flattens* groups + - ``False``: this stage *does not flatten* groups + - ``None``: this stage does not change group status + """ + return None + def get_edited_fields(self, sample_collection, frames=False): """Returns a list of names of fields or embedded fields that may have been edited by the stage, if any. @@ -210,6 +222,22 @@ def get_group_expr(self, sample_collection): """ return None, None + def get_group_media_types(self, sample_collection): + """Returns the group media types outputted by this stage, if any, when + applied to the given collection, if and only if they are different from + the input collection. + + Args: + sample_collection: the + :class:`fiftyone.core.collections.SampleCollection` to which + the stage is being applied + + Returns: + a dict mapping slice names to media types, or ``None`` if the stage + does not change the types + """ + return None + def load_view(self, sample_collection): """Loads the :class:`fiftyone.core.view.DatasetView` containing the output of the stage. @@ -4550,16 +4578,19 @@ def _params(cls): class SelectGroupSlices(ViewStage): - """Selects the samples in a group collection from the given slice(s). + """Selects the specified group slice(s) from a grouped collection. + + When ``flat==True``, the returned view is a flattened non-grouped view + containing the samples from the slice(s) of interest. - The returned view is a flattened non-grouped view containing only the - slice(s) of interest. + When ``flat=False``, the returned view is a grouped collection containing + only the slice(s) of interest. .. note:: - This stage performs a ``$lookup`` that pulls the requested slice(s) for - each sample in the input collection from the source dataset. As a - result, this stage always emits *unfiltered samples*. + When ``flat=True``, this stage performs a ``$lookup`` that pulls the + requested slice(s) for each sample in the input collection from the + source dataset. As a result, the stage emits *unfiltered samples*. Examples:: @@ -4614,6 +4645,13 @@ class SelectGroupSlices(ViewStage): stage = fo.SelectGroupSlices(["left", "right"]) view = dataset.add_stage(stage) + # + # Select only the "left" and "right" group slices + # + + stage = fo.SelectGroupSlices(["left", "right"], flat=False) + view = dataset.add_stage(stage) + # # Retrieve all image samples # @@ -4625,18 +4663,23 @@ class SelectGroupSlices(ViewStage): slices (None): a group slice or iterable of group slices to select. If neither argument is provided, a flattened list of all samples is returned - media_type (None): a media type whose slice(s) to select + media_type (None): a media type or iterable of media types whose + slice(s) to select + flat (True): whether to return a flattened collection (True) or a + grouped collection (False) """ def __init__( self, slices=None, media_type=None, + flat=True, _allow_mixed=False, _force_mixed=False, ): self._slices = slices self._media_type = media_type + self._flat = flat self._allow_mixed = _allow_mixed self._force_mixed = _force_mixed @@ -4647,10 +4690,22 @@ def slices(self): @property def media_type(self): - """The media type whose slices to select.""" + """The media type(s) whose slices to select.""" return self._media_type + @property + def flat(self): + """Whether to generate a flattened collection.""" + return self._flat + + @property + def flattens_groups(self): + return self._flat + def to_mongo(self, sample_collection): + if not self._flat: + return [] + if isinstance(sample_collection, fod.Dataset) or ( isinstance(sample_collection, fov.DatasetView) and len(sample_collection._stages) == 0 @@ -4677,6 +4732,13 @@ def _make_pipeline(self, sample_collection): name_field = group_field + ".name" slices = self._get_slices(sample_collection) + + # No $lookup needed because active slice is the only one requested + if ( + etau.is_str(slices) and slices == sample_collection.group_slice + ) or (slices is None and len(sample_collection.group_slices) == 1): + return [] + expr = F(id_field) == "$$group_id" if isinstance(slices, list): expr &= F(name_field).is_in(slices) @@ -4697,7 +4759,6 @@ def _make_pipeline(self, sample_collection): {"$replaceRoot": {"newRoot": "$groups"}}, ] - # @note(SelectGroupSlices) # Must re-apply field selection/exclusion after the $lookup if isinstance(sample_collection, fov.DatasetView): selected_fields, excluded_fields = _get_selected_excluded_fields( @@ -4715,6 +4776,9 @@ def _make_pipeline(self, sample_collection): return pipeline def get_media_type(self, sample_collection): + if not self._flat: + return sample_collection.media_type + if self._force_mixed: return fom.MIXED @@ -4776,15 +4840,25 @@ def validate(self, sample_collection): def _get_slices(self, sample_collection): if self._media_type is not None: + if etau.is_str(self._media_type): + media_types = {self._media_type} + else: + media_types = set(self._media_type) + group_media_types = sample_collection.group_media_types slices = [ slice_name for slice_name, media_type in group_media_types.items() - if media_type == self._media_type + if media_type in media_types ] else: slices = self._slices + if slices is None: + group_slices = sample_collection.group_slices + if group_slices != sample_collection._dataset.group_slices: + slices = group_slices + if not etau.is_container(slices): return slices @@ -4795,7 +4869,7 @@ def _get_slices(self, sample_collection): return slices - def _get_group_media_types(self, sample_collection): + def get_group_media_types(self, sample_collection): group_media_types = sample_collection.group_media_types slices = self._get_slices(sample_collection) @@ -4816,6 +4890,7 @@ def _kwargs(self): return [ ["slices", self._slices], ["media_type", self._media_type], + ["flat", self._flat], ["_allow_mixed", self._allow_mixed], ["_force_mixed", self._force_mixed], ] @@ -4831,10 +4906,16 @@ def _params(cls): }, { "name": "media_type", - "type": "NoneType|str", + "type": "NoneType|list|str", "placeholder": "media_type (default=None)", "default": "None", }, + { + "name": "flat", + "type": "bool", + "default": "True", + "placeholder": "flat (default=True)", + }, { "name": "_allow_mixed", "type": "NoneType|bool", @@ -4848,6 +4929,158 @@ def _params(cls): ] +class ExcludeGroupSlices(ViewStage): + """Excludes the specified group slice(s) from a grouped collection. + + Examples:: + + import fiftyone as fo + + dataset = fo.Dataset() + dataset.add_group_field("group", default="ego") + + group1 = fo.Group() + group2 = fo.Group() + + dataset.add_samples( + [ + fo.Sample( + filepath="/path/to/left-image1.jpg", + group=group1.element("left"), + ), + fo.Sample( + filepath="/path/to/video1.mp4", + group=group1.element("ego"), + ), + fo.Sample( + filepath="/path/to/right-image1.jpg", + group=group1.element("right"), + ), + fo.Sample( + filepath="/path/to/left-image2.jpg", + group=group2.element("left"), + ), + fo.Sample( + filepath="/path/to/video2.mp4", + group=group2.element("ego"), + ), + fo.Sample( + filepath="/path/to/right-image2.jpg", + group=group2.element("right"), + ), + ] + ) + + # + # Exclude the "ego" group slice + # + + stage = fo.ExcludeGroupSlices("ego") + view = dataset.add_stage(stage) + + # + # Exclude the "left" and "right" group slices + # + + stage = fo.ExcludeGroupSlices(["left", "right"]) + view = dataset.add_stage(stage) + + # + # Exclude all image slices + # + + stage = fo.ExcludeGroupSlices(media_type="image") + view = dataset.add_stage(stage) + + Args: + slices (None): a group slice or iterable of group slices to exclude. + media_type (None): a media type or iterable of media types whose + slice(s) to exclude + """ + + def __init__(self, slices=None, media_type=None): + self._slices = slices + self._media_type = media_type + + @property + def slices(self): + """The group slice(s) to exclude.""" + return self._slices + + @property + def media_type(self): + """The media type(s) whose slices to exclude.""" + return self._media_type + + def to_mongo(self, sample_collection): + return [] + + def validate(self, sample_collection): + if sample_collection.media_type != fom.GROUP: + raise ValueError("%s has no groups" % type(sample_collection)) + + def _get_slices(self, sample_collection): + if self._media_type is not None: + if etau.is_str(self._media_type): + media_types = {self._media_type} + else: + media_types = set(self._media_type) + + group_media_types = sample_collection.group_media_types + return [ + slice_name + for slice_name, media_type in group_media_types.items() + if media_type not in media_types + ] + + if self._slices is not None: + if etau.is_str(self._slices): + slices = {self._slices} + else: + slices = set(self._slices) + + return [ + slice_name + for slice_name in sample_collection.group_slices + if slice_name not in slices + ] + + return sample_collection.group_slices + + def get_group_media_types(self, sample_collection): + group_media_types = sample_collection.group_media_types + slices = set(self._get_slices(sample_collection)) + + return { + slice_name: media_type + for slice_name, media_type in group_media_types.items() + if slice_name in slices + } + + def _kwargs(self): + return [ + ["slices", self._slices], + ["media_type", self._media_type], + ] + + @classmethod + def _params(cls): + return [ + { + "name": "slices", + "type": "NoneType|list|str", + "placeholder": "slices (default=None)", + "default": "None", + }, + { + "name": "media_type", + "type": "NoneType|list|str", + "placeholder": "media_type (default=None)", + "default": "None", + }, + ] + + class MatchFrames(ViewStage): """Filters the frames in a video collection by the given filter. @@ -8611,6 +8844,7 @@ def repr_ViewExpression(self, expr, level): ExcludeFields, ExcludeFrames, ExcludeGroups, + ExcludeGroupSlices, ExcludeLabels, Exists, FilterField, @@ -8658,6 +8892,7 @@ def repr_ViewExpression(self, expr, level): # View stages that only select documents Exclude, ExcludeBy, + ExcludeGroupSlices, Exists, GeoNear, GeoWithin, @@ -8674,5 +8909,6 @@ def repr_ViewExpression(self, expr, level): # Registry of select stages that should select first _STAGES_THAT_SELECT_FIRST = { + ExcludeGroupSlices, SelectGroupSlices, } diff --git a/fiftyone/core/view.py b/fiftyone/core/view.py index a7ba59c983..cb277ad0b5 100644 --- a/fiftyone/core/view.py +++ b/fiftyone/core/view.py @@ -162,7 +162,7 @@ def _has_slices(self): return False for stage in self._stages: - if isinstance(stage, fost.SelectGroupSlices): + if stage.flattens_groups: return False return True @@ -244,7 +244,7 @@ def group_slice(self): return None if self.__group_slice is not None: - return self.__group_slice + return self.__group_slice or None return self._dataset.group_slice @@ -271,7 +271,7 @@ def group_slices(self): if not self._has_slices: return None - return self._dataset.group_slices + return list(self._get_group_media_types().keys()) @property def group_media_types(self): @@ -281,7 +281,7 @@ def group_media_types(self): if not self._has_slices: return None - return self._dataset.group_media_types + return self._get_group_media_types() @property def default_group_slice(self): @@ -291,6 +291,9 @@ def default_group_slice(self): if not self._has_slices: return None + if self._dataset.default_group_slice not in self.group_slices: + return self.group_slice + return self._dataset.default_group_slice @property @@ -1557,7 +1560,7 @@ def _pipeline( _view = self._base_view _contains_videos = self._dataset._contains_videos(any_slice=True) - _found_select_group_slice = False + _found_flattened_videos = False _attach_frames_idx = None _attach_frames_idx0 = None _attach_frames_idx1 = None @@ -1573,10 +1576,12 @@ def _pipeline( idx = 0 for stage in self._stages: - if isinstance(stage, fost.SelectGroupSlices): - # We might need to reattach frames after `SelectGroupSlices`, - # since it involves a `$lookup` that resets the samples - _found_select_group_slice = True + _pipeline = stage.to_mongo(_view) + + if stage.flattens_groups and _contains_videos and _pipeline: + # We might need to reattach frames after flattening groups + # since this involves a `$lookup` that resets the samples + _found_flattened_videos = True _attach_frames_idx0 = _attach_frames_idx _attach_frames_idx = None @@ -1604,15 +1609,12 @@ def _pipeline( _group_slices.update(_stage_group_slices) - _pipeline = stage.to_mongo(_view) - - # @note(SelectGroupSlices) - # Special case: when selecting group slices of a video dataset that + # Special case: when flattening group slices of a video dataset that # modifies the dataset's schema, frame lookups must be injected in # the middle of the stage's pipeline, after the group slice $lookup # but *before* the $project stage(s) that reapply schema changes if ( - isinstance(stage, fost.SelectGroupSlices) + stage.flattens_groups and _contains_videos and _pipeline and "$project" in _pipeline[-1] @@ -1646,8 +1648,8 @@ def _pipeline( attach_frames = True _pipeline = self._dataset._attach_frames_pipeline(support=support) _pipelines.insert(_attach_frames_idx, _pipeline) - elif _found_select_group_slice and _attach_frames_idx is not None: - # Must manually attach frames after the group selection + elif _found_flattened_videos and _attach_frames_idx is not None: + # Must manually attach frames after the group $lookup attach_frames = None # special syntax: frames already attached _pipeline = self._dataset._attach_frames_pipeline(support=support) _pipelines.insert(_attach_frames_idx, _pipeline) @@ -1699,7 +1701,7 @@ def _pipeline( media_type = self.media_type if group_slice is None and self._dataset.media_type == fom.GROUP: - group_slice = self.__group_slice or self._dataset.group_slice + group_slice = self.__group_slice return self._dataset._pipeline( pipeline=_pipeline, @@ -1812,6 +1814,15 @@ def _add_view_stage(self, stage, validate=True): if media_type is not None: view._set_media_type(media_type) + group_media_types = stage.get_group_media_types(self) + if ( + group_media_types is not None + and view.media_type == fom.GROUP + and view.group_slice not in group_media_types + ): + group_slice = next(iter(group_media_types.keys()), "") + view._set_group_slice(group_slice) + view._set_name(None) return view @@ -1819,6 +1830,9 @@ def _add_view_stage(self, stage, validate=True): def _set_media_type(self, media_type): self.__media_type = media_type + def _set_group_slice(self, slice_name): + self.__group_slice = slice_name + def _set_name(self, name): self.__name = name @@ -1922,11 +1936,17 @@ def _get_missing_fields(self, frames=False): return missing_fields def _get_group_media_types(self): - for stage in reversed(self._stages): - if isinstance(stage, fost.SelectGroupSlices): - return stage._get_group_media_types(self._dataset) + group_media_types = self._dataset.group_media_types + + _view = self._base_view + for stage in self._stages: + gmt = stage.get_group_media_types(_view) + if gmt is not None: + group_media_types = gmt + + _view = _view._add_view_stage(stage, validate=False) - return self._dataset.group_media_types + return group_media_types def make_optimized_select_view( @@ -1953,7 +1973,7 @@ def make_optimized_select_view( match the order of the provided IDs groups (False): whether the IDs are group IDs, not sample IDs flatten (False): whether to flatten group datasets before selecting - sample ids + sample IDs Returns: a :class:`DatasetView` diff --git a/tests/unittests/group_tests.py b/tests/unittests/group_tests.py index fbc66af2ba..b3163fe7a2 100644 --- a/tests/unittests/group_tests.py +++ b/tests/unittests/group_tests.py @@ -588,6 +588,18 @@ def test_field_schemas(self): view = dataset.select_fields() + # Selecting active slice maintains schema changes + video_view = view.select_group_slices("ego") + + self.assertEqual(view.group_slice, "ego") + self.assertEqual(video_view.media_type, "video") + self.assertNotIn("field", video_view.get_field_schema()) + self.assertNotIn("field", video_view.get_frame_field_schema()) + for sample in video_view: + self.assertFalse(sample.has_field("field")) + for frame in sample.frames.values(): + self.assertFalse(frame.has_field("field")) + # Cloning a grouped dataset maintains schema changes group_dataset = view.clone() @@ -611,7 +623,6 @@ def test_field_schemas(self): for sample in image_dataset: self.assertFalse(sample.has_field("field")) - # @note(SelectGroupSlices) # Selecting group slices maintains frame schema changes video_view = view.select_group_slices(media_type="video") @@ -631,6 +642,126 @@ def test_field_schemas(self): for frame in sample.frames.values(): self.assertFalse(frame.has_field("field")) + @drop_datasets + def test_select_exclude_slices(self): + dataset = _make_group_dataset() + + # Select slices by name + view = dataset.select_group_slices(["left", "right"], flat=False) + + self.assertEqual(len(view), 2) + self.assertEqual(view.media_type, "group") + self.assertSetEqual(set(view.group_slices), {"left", "right"}) + self.assertDictEqual( + view.group_media_types, {"left": "image", "right": "image"} + ) + self.assertIn(view.group_slice, ["left", "right"]) + self.assertIn(view.default_group_slice, ["left", "right"]) + + # Select slices by media type + view = dataset.select_group_slices(media_type="image", flat=False) + + self.assertEqual(len(view), 2) + self.assertEqual(view.media_type, "group") + self.assertSetEqual(set(view.group_slices), {"left", "right"}) + self.assertDictEqual( + view.group_media_types, {"left": "image", "right": "image"} + ) + self.assertIn(view.group_slice, ["left", "right"]) + self.assertIn(view.default_group_slice, ["left", "right"]) + + # Exclude slices by name + view = dataset.exclude_group_slices("ego") + + self.assertEqual(len(view), 2) + self.assertEqual(view.media_type, "group") + self.assertSetEqual(set(view.group_slices), {"left", "right"}) + self.assertDictEqual( + view.group_media_types, {"left": "image", "right": "image"} + ) + self.assertIn(view.group_slice, ["left", "right"]) + self.assertIn(view.default_group_slice, ["left", "right"]) + + # Exclude slices by media type + view = dataset.exclude_group_slices(media_type="video") + + self.assertEqual(len(view), 2) + self.assertEqual(view.media_type, "group") + self.assertSetEqual(set(view.group_slices), {"left", "right"}) + self.assertDictEqual( + view.group_media_types, {"left": "image", "right": "image"} + ) + self.assertIn(view.group_slice, ["left", "right"]) + self.assertIn(view.default_group_slice, ["left", "right"]) + + # Empty grouped view + view = dataset.select_group_slices( + ["left", "right"], flat=False + ).exclude_group_slices(media_type="image") + + self.assertEqual(len(view), 0) + self.assertEqual(view.media_type, "group") + self.assertListEqual(view.group_slices, []) + self.assertDictEqual(view.group_media_types, {}) + self.assertIsNone(view.group_slice) + self.assertIsNone(view.default_group_slice) + + # Empty grouped view clone + dataset2 = view.clone() + + self.assertEqual(len(dataset2), 0) + self.assertEqual(dataset2.media_type, "group") + self.assertListEqual(dataset2.group_slices, []) + self.assertDictEqual(dataset2.group_media_types, {}) + self.assertIsNone(dataset2.group_slice) + self.assertIsNone(dataset2.default_group_slice) + + # Select group slices with filtered schema + view = dataset.select_fields().select_group_slices( + media_type="video", flat=False + ) + + self.assertEqual(len(view), 2) + self.assertEqual(view.media_type, "group") + self.assertListEqual(view.group_slices, ["ego"]) + self.assertDictEqual(view.group_media_types, {"ego": "video"}) + self.assertEqual(view.group_slice, "ego") + self.assertEqual(view.default_group_slice, "ego") + + schema = view.get_field_schema() + frame_schema = view.get_frame_field_schema() + + self.assertNotIn("field", schema) + self.assertNotIn("field", frame_schema) + + sample_view = view.first() + frame_view = sample_view.frames.first() + + self.assertFalse(sample_view.has_field("field")) + self.assertFalse(frame_view.has_field("field")) + + # Clone selected group slices with filtered schema + dataset2 = view.clone() + + self.assertEqual(len(dataset2), 2) + self.assertEqual(dataset2.media_type, "group") + self.assertListEqual(dataset2.group_slices, ["ego"]) + self.assertDictEqual(dataset2.group_media_types, {"ego": "video"}) + self.assertEqual(dataset2.group_slice, "ego") + self.assertEqual(dataset2.default_group_slice, "ego") + + schema = dataset2.get_field_schema() + frame_schema = dataset2.get_frame_field_schema() + + self.assertNotIn("field", schema) + self.assertNotIn("field", frame_schema) + + sample2 = dataset2.first() + frame2 = sample2.frames.first() + + self.assertFalse(sample2.has_field("field")) + self.assertFalse(frame2.has_field("field")) + @drop_datasets def test_attached_groups(self): dataset = _make_group_dataset()