Skip to content

Commit

Permalink
cleanup, add slice tests to lightning
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminpkane committed Nov 22, 2024
1 parent 8f1e6c4 commit 4fa08db
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 24 deletions.
7 changes: 4 additions & 3 deletions app/packages/state/src/recoil/queryPerformance.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { graphQLSelectorFamily } from "recoil-relay";
import type { ResponseFrom } from "../utils";
import { config } from "./config";
import { getBrowserStorageEffectForKey } from "./customEffects";
import { groupSlice, groupStatistics } from "./groups";
import { groupSlice } from "./groups";
import { isLabelPath } from "./labels";
import { RelayEnvironmentKey } from "./relay";
import * as schemaAtoms from "./schema";
Expand All @@ -35,8 +35,7 @@ export const lightningQuery = graphQLSelectorFamily<
input: {
dataset: get(datasetName),
paths,
slice:
get(groupStatistics(false)) === "group" ? null : get(groupSlice),
slice: get(groupSlice),
},
};
},
Expand Down Expand Up @@ -86,6 +85,8 @@ const indexesByPath = selector({

const { sampleIndexes: samples, frameIndexes: frames } = get(indexes);

console.log(samples);

const schema = gatherPaths(State.SPACE.SAMPLE);
const frameSchema = gatherPaths(State.SPACE.FRAME).map((p) =>
p.slice("frames.".length)
Expand Down
6 changes: 0 additions & 6 deletions docs/source/user_guide/app.rst
Original file line number Diff line number Diff line change
Expand Up @@ -489,8 +489,6 @@ perform initial filters on:
# Note: it is faster to declare indexes before adding samples
dataset.add_samples(...)
fo.app_config.default_query_performance = True
session = fo.launch_app(dataset)
.. note::
Expand Down Expand Up @@ -521,8 +519,6 @@ compound index that includes the group slice name:
dataset.create_index("ground_truth.detections.label")
dataset.create_index([("group.name", 1), ("ground_truth.detections.label", 1)])
fo.app_config.default_query_performance = True
session = fo.launch_app(dataset)
For datasets with a small number of fields, you can index all fields by adding
Expand All @@ -538,8 +534,6 @@ a single
dataset = foz.load_zoo_dataset("quickstart")
dataset.create_index("$**")
fo.app_config.default_query_performance = True
session = fo.launch_app(dataset)
.. warning::
Expand Down
11 changes: 7 additions & 4 deletions fiftyone/server/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,10 @@ async def _do_async_query(
if query.has_list and not query.filters:
return await _do_distinct_query(collection, query, filter)

return await _do_distinct_pipeline(dataset, collection, filter)
return await _do_distinct_pipeline(dataset, collection, query, filter)

if filter:
query.insert(0, {"$match": filter})

return [i async for i in collection.aggregate(query)]

Expand Down Expand Up @@ -363,12 +366,12 @@ async def _do_distinct_pipeline(
filter: t.Optional[t.Mapping[str, str]],
):
pipeline = []
if query.filters:
pipeline += get_view(dataset, filters=query.filters)._pipeline()

if filter:
pipeline.append({"$match": filter})

if query.filters:
pipeline += get_view(dataset, filters=query.filters)._pipeline()

pipeline.append({"$sort": {query.path: 1}})

if query.search:
Expand Down
119 changes: 108 additions & 11 deletions tests/unittests/lightning_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,6 +1053,91 @@ async def test_strings(self, dataset: fo.Dataset):
)


class TestGroupDatasetLightningQueries(unittest.IsolatedAsyncioTestCase):
    """Lightning queries against a grouped dataset, scoped per group slice."""

    @drop_async_dataset
    async def test_group_dataset(self, dataset: fo.Dataset):
        # Build one group with two slices ("one" and "two"); each slice
        # holds a single sample whose label, numeric, and string values
        # identify the slice, so query results are attributable to
        # exactly one slice.
        group = fo.Group()
        samples = []
        for name, number in (("one", 1), ("two", 2)):
            samples.append(
                fo.Sample(
                    classifications=fo.Classifications(
                        classifications=[fo.Classification(label=name)]
                    ),
                    filepath=f"{name}.png",
                    group=group.element(name),
                    numeric=number,
                    string=name,
                )
            )
        dataset.add_samples(samples)

        query = """
        query Query($input: LightningInput!) {
            lightning(input: $input) {
                ... on IntLightningResult {
                    path
                    min
                    max
                }
                ... on StringLightningResult {
                    path
                    values
                }
            }
        }
        """

        # Querying with a slice must surface only that slice's sample
        for name, number in (("one", 1.0), ("two", 2.0)):
            result = await _execute(
                query,
                dataset,
                (fo.IntField, fo.StringField),
                ["classifications.classifications.label", "numeric", "string"],
                frames=False,
                slice=name,
            )

            self.assertListEqual(
                result.data["lightning"],
                [
                    {
                        "path": "classifications.classifications.label",
                        "values": [name],
                    },
                    {"path": "numeric", "min": number, "max": number},
                    {"path": "string", "values": [name]},
                ],
            )


def _add_samples(dataset: fo.Dataset, *sample_data: t.List[t.Dict]):
samples = []
keys = set()
Expand All @@ -1067,7 +1152,12 @@ def _add_samples(dataset: fo.Dataset, *sample_data: t.List[t.Dict]):


async def _execute(
query: str, dataset: fo.Dataset, field: fo.Field, keys: t.Set[str]
query: str,
dataset: fo.Dataset,
field: fo.Field,
keys: t.Set[str],
frames=True,
slice: t.Optional[str] = None,
):
return await execute(
schema,
Expand All @@ -1076,25 +1166,32 @@ async def _execute(
"input": asdict(
LightningInput(
dataset=dataset.name,
paths=_get_paths(dataset, field, keys),
paths=_get_paths(dataset, field, keys, frames=frames),
slice=slice,
)
)
},
)


def _get_paths(
dataset: fo.Dataset, field_type: t.Type[fo.Field], keys: t.Set[str]
dataset: fo.Dataset,
field_type: t.Type[fo.Field],
keys: t.Set[str],
frames=True,
):
field_dict = dataset.get_field_schema(flat=True)
field_dict.update(
**{
f"frames.{path}": field
for path, field in dataset.get_frame_field_schema(
flat=True
).items()
}
)

if frames:
field_dict.update(
**{
f"frames.{path}": field
for path, field in dataset.get_frame_field_schema(
flat=True
).items()
}
)

paths: t.List[LightningPathInput] = []
for path in sorted(field_dict):
field = field_dict[path]
Expand Down

0 comments on commit 4fa08db

Please sign in to comment.