Skip to content

Commit

Permalink
add selection filtering with slices in handle_group_filter
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminpkane committed Sep 11, 2024
1 parent c699f8d commit 41ed515
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
9 changes: 8 additions & 1 deletion fiftyone/server/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,18 @@ def handle_group_filter(

for stage in stages:
# add stages after flattening and group match

if group_by and isinstance(stage, fosg.GroupBy) and filter.slices:
view = view.match(
{group_field + ".name": {"$in": filter.slices}}
)
view = view._add_view_stage(stage, validate=False)

# if selecting a group, filter out select/reorder stages
if (
not filter.id
or type(stage) not in fosg._STAGES_THAT_SELECT_OR_REORDER
):
view = view._add_view_stage(stage, validate=False)

elif filter.id:
view = fov.make_optimized_select_view(view, filter.id, groups=True)
Expand Down
31 changes: 31 additions & 0 deletions tests/unittests/server_group_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,34 @@ def test_manual_group_slice(self):
GroupElementFilter(),
)
self.assertEqual(view._all_stages, [fo.Select(first)])

@drop_datasets
def test_group_selection(self):
dataset: fo.Dataset = fo.Dataset()
group = fo.Group()
one = fo.Sample(
filepath="image.png",
group=group.element("one"),
)
two = fo.Sample(
filepath="image.png",
group=group.element("two"),
)

dataset.add_samples([one, two])

selection = dataset.select(one.id)

with_slices, _ = fosv.handle_group_filter(
dataset,
selection,
GroupElementFilter(id=group.id, slices=["one", "two"]),
)
self.assertEqual(len(with_slices), 2)

without_slices, _ = fosv.handle_group_filter(
dataset,
selection,
GroupElementFilter(id=group.id, slices=["one", "two"]),
)
self.assertEqual(len(without_slices), 2)

0 comments on commit 41ed515

Please sign in to comment.