diff --git a/fiftyone/server/view.py b/fiftyone/server/view.py index 98ea785bd7..7a8d4b9464 100644 --- a/fiftyone/server/view.py +++ b/fiftyone/server/view.py @@ -268,7 +268,13 @@ def handle_group_filter( view = view.match( {group_field + ".name": {"$in": filter.slices}} ) - view = view._add_view_stage(stage, validate=False) + + # if selecting a group, filter out select/reorder stages + if ( + not filter.id + or type(stage) not in fosg._STAGES_THAT_SELECT_OR_REORDER + ): + view = view._add_view_stage(stage, validate=False) elif filter.id: view = fov.make_optimized_select_view(view, filter.id, groups=True) diff --git a/tests/unittests/server_group_tests.py b/tests/unittests/server_group_tests.py index b368c57159..3c981ca670 100644 --- a/tests/unittests/server_group_tests.py +++ b/tests/unittests/server_group_tests.py @@ -90,3 +90,34 @@ def test_manual_group_slice(self): GroupElementFilter(), ) self.assertEqual(view._all_stages, [fo.Select(first)]) + + @drop_datasets + def test_group_selection(self): + dataset: fo.Dataset = fo.Dataset() + group = fo.Group() + one = fo.Sample( + filepath="image.png", + group=group.element("one"), + ) + two = fo.Sample( + filepath="image.png", + group=group.element("two"), + ) + + dataset.add_samples([one, two]) + + selection = dataset.select(one.id) + + with_slices, _ = fosv.handle_group_filter( + dataset, + selection, + GroupElementFilter(id=group.id, slices=["one", "two"]), + ) + self.assertEqual(len(with_slices), 2) + + without_slices, _ = fosv.handle_group_filter( + dataset, + selection, + GroupElementFilter(id=group.id, slices=["one", "two"]), + ) + self.assertEqual(len(without_slices), 2)