QP sidebar filters to active slice for group datasets #5177

Merged · 3 commits · Nov 22, 2024
Changes from all commits

2 changes: 2 additions & 0 deletions app/packages/state/src/recoil/queryPerformance.ts
@@ -11,6 +11,7 @@ import { graphQLSelectorFamily } from "recoil-relay";
 import type { ResponseFrom } from "../utils";
 import { config } from "./config";
 import { getBrowserStorageEffectForKey } from "./customEffects";
+import { groupSlice } from "./groups";
 import { isLabelPath } from "./labels";
 import { RelayEnvironmentKey } from "./relay";
 import * as schemaAtoms from "./schema";
@@ -34,6 +35,7 @@ export const lightningQuery = graphQLSelectorFamily<
       input: {
         dataset: get(datasetName),
         paths,
+        slice: get(groupSlice),
       },
     };
   },
1 change: 1 addition & 0 deletions app/schema.graphql
@@ -444,6 +444,7 @@ input LabelTagColorInput {
 input LightningInput {
   dataset: String!
   paths: [LightningPathInput!]!
+  slice: String = null
 }

 input LightningPathInput {
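
The new optional `slice` argument lets a client scope a lightning request to a single group slice. As a rough sketch, the request variables would look like the following (the dataset and path names are hypothetical, and `LightningPathInput` is assumed to take a `path` string, as in the tests below):

    # Hedged sketch: variables for a lightning query scoped to one slice.
    # The names here are illustrative, not taken from this PR.
    variables = {
        "input": {
            "dataset": "my-group-dataset",
            "paths": [{"path": "ground_truth.detections.label"}],
            "slice": "left",  # omit (or null) to query across all slices
        }
    }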
6 changes: 0 additions & 6 deletions docs/source/user_guide/app.rst
@@ -489,8 +489,6 @@ perform initial filters on:
     # Note: it is faster to declare indexes before adding samples
     dataset.add_samples(...)

-    fo.app_config.default_query_performance = True
-
     session = fo.launch_app(dataset)

 .. note::
@@ -521,8 +519,6 @@ compound index that includes the group slice name:
     dataset.create_index("ground_truth.detections.label")
     dataset.create_index([("group.name", 1), ("ground_truth.detections.label", 1)])

-    fo.app_config.default_query_performance = True
-
     session = fo.launch_app(dataset)

 For datasets with a small number of fields, you can index all fields by adding
@@ -538,8 +534,6 @@ a single
     dataset = foz.load_zoo_dataset("quickstart")
     dataset.create_index("$**")

-    fo.app_config.default_query_performance = True
-
     session = fo.launch_app(dataset)

 .. warning::
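
To see why the docs recommend a compound index that leads with the group slice name: with this change, a sliced sidebar query prepends a `$match` on the slice before sorting on the queried path (see `_do_distinct_pipeline` below), so the pipeline has roughly this shape (a sketch, assuming the `group` field from the docs example and a hypothetical "left" slice):

    # Hedged sketch of the server-side pipeline shape for a sliced query
    pipeline = [
        {"$match": {"group.name": "left"}},  # prepended when a slice is active
        {"$sort": {"ground_truth.detections.label": 1}},
    ]
    # A compound index on (group.name, ground_truth.detections.label) can
    # satisfy both the match and the sort.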
33 changes: 25 additions & 8 deletions fiftyone/server/lightning.py
@@ -9,7 +9,6 @@
 from bson import ObjectId
 from dataclasses import asdict, dataclass
 from datetime import date, datetime
-import math
 import typing as t

 import asyncio
@@ -46,6 +45,7 @@ class LightningPathInput:
 class LightningInput:
     dataset: str
     paths: t.List[LightningPathInput]
+    slice: t.Optional[str] = None


 @gql.interface
@@ -138,7 +138,13 @@ async def lightning_resolver(
         for collection, sublist in zip(collections, queries)
         for item in sublist
     ]
-    result = await _do_async_pooled_queries(dataset, flattened)
+
+    filter = (
+        {f"{dataset.group_field}.name": input.slice}
+        if dataset.group_field and input.slice
+        else None
+    )
+    result = await _do_async_pooled_queries(dataset, flattened, filter)

     results = []
     offset = 0
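
For a group dataset whose group field is named `group`, the filter computed above is a simple match on the slice name; for non-group datasets, or when no slice is provided, it remains `None`. A hedged illustration:

    # Assuming dataset.group_field == "group" and input.slice == "left"
    # (hypothetical values), the resolver builds:
    filter = {"group.name": "left"}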
@@ -293,10 +299,11 @@ async def _do_async_pooled_queries(
     queries: t.List[
         t.Tuple[AsyncIOMotorCollection, t.Union[DistinctQuery, t.List[t.Dict]]]
     ],
+    filter: t.Optional[t.Mapping[str, str]],
 ):
     return await asyncio.gather(
         *[
-            _do_async_query(dataset, collection, query)
+            _do_async_query(dataset, collection, query, filter)
             for collection, query in queries
         ]
     )
@@ -306,25 +313,31 @@ async def _do_async_query(
     dataset: fo.Dataset,
     collection: AsyncIOMotorCollection,
     query: t.Union[DistinctQuery, t.List[t.Dict]],
+    filter: t.Optional[t.Mapping[str, str]],
Review comment (Contributor):
🛠️ Refactor suggestion

Add type checking before modifying the query list

The current implementation assumes query is always a list when applying the filter. Add type checking to ensure safe operation.

    if filter:
+       if not isinstance(query, list):
+           raise TypeError("Expected query to be a list for filter application")
        query.insert(0, {"$match": filter})

Also applies to: 324-325

 ):
     if isinstance(query, DistinctQuery):
         if query.has_list and not query.filters:
-            return await _do_distinct_query(collection, query)
+            return await _do_distinct_query(collection, query, filter)

+        return await _do_distinct_pipeline(dataset, collection, query, filter)

-        return await _do_distinct_pipeline(dataset, collection, query)
+    if filter:
+        query.insert(0, {"$match": filter})

     return [i async for i in collection.aggregate(query)]
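
When the query is a raw aggregation (a list of stages) rather than a `DistinctQuery`, the slice filter is prepended as a leading `$match` stage. A small illustration, with hypothetical values:

    # Hedged example: prepending the slice filter to a raw aggregation
    query = [{"$group": {"_id": None, "min": {"$min": "$numeric"}}}]
    filter = {"group.name": "left"}
    query.insert(0, {"$match": filter})
    # query == [{"$match": {"group.name": "left"}},
    #           {"$group": {"_id": None, "min": {"$min": "$numeric"}}}]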


 async def _do_distinct_query(
-    collection: AsyncIOMotorCollection, query: DistinctQuery
+    collection: AsyncIOMotorCollection,
+    query: DistinctQuery,
+    filter: t.Optional[t.Mapping[str, str]],
 ):
     match = None
     if query.search:
         match = query.search

     try:
-        result = await collection.distinct(query.path)
+        result = await collection.distinct(query.path, filter)
     except:
         # too many results
         return None
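
Motor's `distinct()` mirrors PyMongo's `Collection.distinct(key, filter=None, ...)`, so passing the slice filter as the second argument restricts the distinct scan to the active slice. A minimal sketch, with illustrative names:

    # Hedged sketch: distinct values of a field, restricted to one slice
    values = await collection.distinct("string", {"group.name": "one"})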
@@ -350,12 +363,16 @@ async def _do_distinct_pipeline(
     dataset: fo.Dataset,
     collection: AsyncIOMotorCollection,
     query: DistinctQuery,
+    filter: t.Optional[t.Mapping[str, str]],
 ):
     pipeline = []
+    if filter:
+        pipeline.append({"$match": filter})
+
     if query.filters:
         pipeline += get_view(dataset, filters=query.filters)._pipeline()

-    pipeline += [{"$sort": {query.path: 1}}]
+    pipeline.append({"$sort": {query.path: 1}})

     if query.search:
         if query.is_object_id_field:
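
For completeness, a hedged sketch of how the reviewer's type-check suggestion above could be folded into `_do_async_query` (this guard is not part of the merged diff):

    # Reviewer-suggested guard before mutating the aggregation, as a sketch
    if filter:
        if not isinstance(query, list):
            raise TypeError("Expected query to be a list for filter application")
        query.insert(0, {"$match": filter})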
119 changes: 108 additions & 11 deletions tests/unittests/lightning_tests.py
@@ -1053,6 +1053,91 @@ async def test_strings(self, dataset: fo.Dataset):
         )


+class TestGroupDatasetLightningQueries(unittest.IsolatedAsyncioTestCase):
+    @drop_async_dataset
+    async def test_group_dataset(self, dataset: fo.Dataset):
+        group = fo.Group()
+        one = fo.Sample(
+            classifications=fo.Classifications(
+                classifications=[fo.Classification(label="one")]
+            ),
+            filepath="one.png",
+            group=group.element("one"),
+            numeric=1,
+            string="one",
+        )
+        two = fo.Sample(
+            classifications=fo.Classifications(
+                classifications=[fo.Classification(label="two")]
+            ),
+            filepath="two.png",
+            group=group.element("two"),
+            numeric=2,
+            string="two",
+        )
+        dataset.add_samples([one, two])
+
+        query = """
+            query Query($input: LightningInput!) {
+                lightning(input: $input) {
+                    ... on IntLightningResult {
+                        path
+                        min
+                        max
+                    }
+                    ... on StringLightningResult {
+                        path
+                        values
+                    }
+                }
+            }
+        """
+
+        # only query "one" slice samples
+        result = await _execute(
+            query,
+            dataset,
+            (fo.IntField, fo.StringField),
+            ["classifications.classifications.label", "numeric", "string"],
+            frames=False,
+            slice="one",
+        )
+
+        self.assertListEqual(
+            result.data["lightning"],
+            [
+                {
+                    "path": "classifications.classifications.label",
+                    "values": ["one"],
+                },
+                {"path": "numeric", "min": 1.0, "max": 1.0},
+                {"path": "string", "values": ["one"]},
+            ],
+        )
+
+        # only query "two" slice samples
+        result = await _execute(
+            query,
+            dataset,
+            (fo.IntField, fo.StringField),
+            ["classifications.classifications.label", "numeric", "string"],
+            frames=False,
+            slice="two",
+        )
+
+        self.assertListEqual(
+            result.data["lightning"],
+            [
+                {
+                    "path": "classifications.classifications.label",
+                    "values": ["two"],
+                },
+                {"path": "numeric", "min": 2.0, "max": 2.0},
+                {"path": "string", "values": ["two"]},
+            ],
+        )


 def _add_samples(dataset: fo.Dataset, *sample_data: t.List[t.Dict]):
     samples = []
     keys = set()
@@ -1067,7 +1152,12 @@ def _add_samples(dataset: fo.Dataset, *sample_data: t.List[t.Dict]):


 async def _execute(
-    query: str, dataset: fo.Dataset, field: fo.Field, keys: t.Set[str]
+    query: str,
+    dataset: fo.Dataset,
+    field: fo.Field,
+    keys: t.Set[str],
+    frames=True,
+    slice: t.Optional[str] = None,
 ):
     return await execute(
         schema,
@@ -1076,25 +1166,32 @@ async def _execute(
             "input": asdict(
                 LightningInput(
                     dataset=dataset.name,
-                    paths=_get_paths(dataset, field, keys),
+                    paths=_get_paths(dataset, field, keys, frames=frames),
+                    slice=slice,
                 )
             )
         },
     )


 def _get_paths(
-    dataset: fo.Dataset, field_type: t.Type[fo.Field], keys: t.Set[str]
+    dataset: fo.Dataset,
+    field_type: t.Type[fo.Field],
+    keys: t.Set[str],
+    frames=True,
 ):
     field_dict = dataset.get_field_schema(flat=True)
-    field_dict.update(
-        **{
-            f"frames.{path}": field
-            for path, field in dataset.get_frame_field_schema(
-                flat=True
-            ).items()
-        }
-    )
+
+    if frames:
+        field_dict.update(
+            **{
+                f"frames.{path}": field
+                for path, field in dataset.get_frame_field_schema(
+                    flat=True
+                ).items()
+            }
+        )

     paths: t.List[LightningPathInput] = []
     for path in sorted(field_dict):
         field = field_dict[path]
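
To run just the new tests locally, something like the following should work, assuming pytest is the project's test runner (the file path comes from the diff; the invocation itself is an assumption):

    pytest tests/unittests/lightning_tests.py -k TestGroupDatasetLightningQueries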