diff --git a/app/packages/relay/src/queries/__generated__/lightningQuery.graphql.ts b/app/packages/relay/src/queries/__generated__/lightningQuery.graphql.ts index 4749ee6299..e4510efa61 100644 --- a/app/packages/relay/src/queries/__generated__/lightningQuery.graphql.ts +++ b/app/packages/relay/src/queries/__generated__/lightningQuery.graphql.ts @@ -1,5 +1,5 @@ /** - * @generated SignedSource<> + * @generated SignedSource<> * @lightSyntaxTransform * @nogrep */ @@ -12,6 +12,7 @@ import { ConcreteRequest, Query } from 'relay-runtime'; export type LightningInput = { dataset: string; paths: ReadonlyArray<LightningPathInput>; + slice?: string | null; }; export type LightningPathInput = { exclude?: ReadonlyArray<string> | null; diff --git a/app/packages/state/src/recoil/queryPerformance.ts b/app/packages/state/src/recoil/queryPerformance.ts index 8392074e29..439f167889 100644 --- a/app/packages/state/src/recoil/queryPerformance.ts +++ b/app/packages/state/src/recoil/queryPerformance.ts @@ -11,6 +11,7 @@ import { graphQLSelectorFamily } from "recoil-relay"; import type { ResponseFrom } from "../utils"; import { config } from "./config"; import { getBrowserStorageEffectForKey } from "./customEffects"; +import { groupSlice } from "./groups"; import { isLabelPath } from "./labels"; import { RelayEnvironmentKey } from "./relay"; import * as schemaAtoms from "./schema"; @@ -34,6 +35,7 @@ export const lightningQuery = graphQLSelectorFamily< input: { dataset: get(datasetName), paths, + slice: get(groupSlice), }, }; }, diff --git a/app/schema.graphql b/app/schema.graphql index 2b843f3a7a..f959dd13d2 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -444,6 +444,7 @@ input LabelTagColorInput { input LightningInput { dataset: String! paths: [LightningPathInput!]! 
+ slice: String = null } input LightningPathInput { diff --git a/docs/source/user_guide/app.rst b/docs/source/user_guide/app.rst index 769f3fecb6..40b9bf0f59 100644 --- a/docs/source/user_guide/app.rst +++ b/docs/source/user_guide/app.rst @@ -489,8 +489,6 @@ perform initial filters on: # Note: it is faster to declare indexes before adding samples dataset.add_samples(...) - fo.app_config.default_query_performance = True - session = fo.launch_app(dataset) .. note:: @@ -521,8 +519,6 @@ compound index that includes the group slice name: dataset.create_index("ground_truth.detections.label") dataset.create_index([("group.name", 1), ("ground_truth.detections.label", 1)]) - fo.app_config.default_query_performance = True - session = fo.launch_app(dataset) For datasets with a small number of fields, you can index all fields by adding @@ -538,8 +534,6 @@ a single dataset = foz.load_zoo_dataset("quickstart") dataset.create_index("$**") - fo.app_config.default_query_performance = True - session = fo.launch_app(dataset) .. 
warning:: diff --git a/fiftyone/server/lightning.py b/fiftyone/server/lightning.py index 2b1d22df3d..701588864d 100644 --- a/fiftyone/server/lightning.py +++ b/fiftyone/server/lightning.py @@ -9,7 +9,6 @@ from bson import ObjectId from dataclasses import asdict, dataclass from datetime import date, datetime -import math import typing as t import asyncio @@ -46,6 +45,7 @@ class LightningPathInput: class LightningInput: dataset: str paths: t.List[LightningPathInput] + slice: t.Optional[str] = None @gql.interface @@ -138,7 +138,13 @@ async def lightning_resolver( for collection, sublist in zip(collections, queries) for item in sublist ] - result = await _do_async_pooled_queries(dataset, flattened) + + filter = ( + {f"{dataset.group_field}.name": input.slice} + if dataset.group_field and input.slice + else None + ) + result = await _do_async_pooled_queries(dataset, flattened, filter) results = [] offset = 0 @@ -293,10 +299,11 @@ async def _do_async_pooled_queries( queries: t.List[ t.Tuple[AsyncIOMotorCollection, t.Union[DistinctQuery, t.List[t.Dict]]] ], + filter: t.Optional[t.Mapping[str, str]], ): return await asyncio.gather( *[ - _do_async_query(dataset, collection, query) + _do_async_query(dataset, collection, query, filter) for collection, query in queries ] ) @@ -306,25 +313,31 @@ async def _do_async_query( dataset: fo.Dataset, collection: AsyncIOMotorCollection, query: t.Union[DistinctQuery, t.List[t.Dict]], + filter: t.Optional[t.Mapping[str, str]], ): if isinstance(query, DistinctQuery): if query.has_list and not query.filters: - return await _do_distinct_query(collection, query) + return await _do_distinct_query(collection, query, filter) + + return await _do_distinct_pipeline(dataset, collection, query, filter) - return await _do_distinct_pipeline(dataset, collection, query) + if filter: + query.insert(0, {"$match": filter}) return [i async for i in collection.aggregate(query)] async def _do_distinct_query( - collection: AsyncIOMotorCollection, query: 
DistinctQuery + collection: AsyncIOMotorCollection, + query: DistinctQuery, + filter: t.Optional[t.Mapping[str, str]], ): match = None if query.search: match = query.search try: - result = await collection.distinct(query.path) + result = await collection.distinct(query.path, filter) except: # too many results return None @@ -350,12 +363,16 @@ async def _do_distinct_pipeline( dataset: fo.Dataset, collection: AsyncIOMotorCollection, query: DistinctQuery, + filter: t.Optional[t.Mapping[str, str]], ): pipeline = [] + if filter: + pipeline.append({"$match": filter}) + if query.filters: pipeline += get_view(dataset, filters=query.filters)._pipeline() - pipeline += [{"$sort": {query.path: 1}}] + pipeline.append({"$sort": {query.path: 1}}) if query.search: if query.is_object_id_field: diff --git a/tests/unittests/lightning_tests.py b/tests/unittests/lightning_tests.py index 319315f89b..b631e8cf08 100644 --- a/tests/unittests/lightning_tests.py +++ b/tests/unittests/lightning_tests.py @@ -1053,6 +1053,91 @@ async def test_strings(self, dataset: fo.Dataset): ) +class TestGroupDatasetLightningQueries(unittest.IsolatedAsyncioTestCase): + @drop_async_dataset + async def test_group_dataset(self, dataset: fo.Dataset): + group = fo.Group() + one = fo.Sample( + classifications=fo.Classifications( + classifications=[fo.Classification(label="one")] + ), + filepath="one.png", + group=group.element("one"), + numeric=1, + string="one", + ) + two = fo.Sample( + classifications=fo.Classifications( + classifications=[fo.Classification(label="two")] + ), + filepath="two.png", + group=group.element("two"), + numeric=2, + string="two", + ) + dataset.add_samples([one, two]) + + query = """ + query Query($input: LightningInput!) { + lightning(input: $input) { + ... on IntLightningResult { + path + min + max + } + ... 
on StringLightningResult { + path + values + } + } + } + """ + + # only query "one" slice samples + result = await _execute( + query, + dataset, + (fo.IntField, fo.StringField), + ["classifications.classifications.label", "numeric", "string"], + frames=False, + slice="one", + ) + + self.assertListEqual( + result.data["lightning"], + [ + { + "path": "classifications.classifications.label", + "values": ["one"], + }, + {"path": "numeric", "min": 1.0, "max": 1.0}, + {"path": "string", "values": ["one"]}, + ], + ) + + # only query "two" slice samples + result = await _execute( + query, + dataset, + (fo.IntField, fo.StringField), + ["classifications.classifications.label", "numeric", "string"], + frames=False, + slice="two", + ) + + self.assertListEqual( + result.data["lightning"], + [ + { + "path": "classifications.classifications.label", + "values": ["two"], + }, + {"path": "numeric", "min": 2.0, "max": 2.0}, + {"path": "string", "values": ["two"]}, + ], + ) + + def _add_samples(dataset: fo.Dataset, *sample_data: t.List[t.Dict]): samples = [] keys = set() @@ -1067,7 +1152,12 @@ def _add_samples(dataset: fo.Dataset, *sample_data: t.List[t.Dict]): async def _execute( - query: str, dataset: fo.Dataset, field: fo.Field, keys: t.Set[str] + query: str, + dataset: fo.Dataset, + field: fo.Field, + keys: t.Set[str], + frames=True, + slice: t.Optional[str] = None, ): return await execute( schema, @@ -1076,7 +1166,8 @@ async def _execute( "input": asdict( LightningInput( dataset=dataset.name, - paths=_get_paths(dataset, field, keys), + paths=_get_paths(dataset, field, keys, frames=frames), + slice=slice, ) ) }, @@ -1084,17 +1175,23 @@ async def _execute( def _get_paths( - dataset: fo.Dataset, field_type: t.Type[fo.Field], keys: t.Set[str] + dataset: fo.Dataset, + field_type: t.Type[fo.Field], + keys: t.Set[str], + frames=True, ): field_dict = dataset.get_field_schema(flat=True) - field_dict.update( - **{ - f"frames.{path}": field - for path, field in 
dataset.get_frame_field_schema( - flat=True - ).items() - } - ) + + if frames: + field_dict.update( + **{ + f"frames.{path}": field + for path, field in dataset.get_frame_field_schema( + flat=True + ).items() + } + ) + paths: t.List[LightningPathInput] = [] for path in sorted(field_dict): field = field_dict[path]