Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] add included to .get() & .query() response #2044

Merged
merged 3 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ def _get(
documents=body.get("documents", None),
data=None,
uris=body.get("uris", None),
included=body["included"],
)

@trace_method("FastAPI._delete", OpenTelemetryGranularity.OPERATION)
Expand Down Expand Up @@ -581,6 +582,7 @@ def _query(
documents=body.get("documents", None),
uris=body.get("uris", None),
data=None,
included=body["included"],
)

@trace_method("FastAPI.reset", OpenTelemetryGranularity.ALL)
Expand Down
3 changes: 3 additions & 0 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ def _get(
documents=[] if "documents" in include else None,
uris=[] if "uris" in include else None,
data=[] if "data" in include else None,
included=include,
)

vectors: Sequence[t.VectorEmbeddingRecord] = []
Expand Down Expand Up @@ -574,6 +575,7 @@ def _get(
documents=documents if "documents" in include else None, # type: ignore
uris=uris if "uris" in include else None, # type: ignore
data=None,
included=include,
)

@trace_method("SegmentAPI._delete", OpenTelemetryGranularity.OPERATION)
Expand Down Expand Up @@ -766,6 +768,7 @@ def _query(
documents=documents if documents else None,
uris=uris if uris else None,
data=None,
included=include,
)

@trace_method("SegmentAPI._peek", OpenTelemetryGranularity.OPERATION)
Expand Down
2 changes: 2 additions & 0 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ class GetResult(TypedDict):
uris: Optional[URIs]
data: Optional[Loadable]
metadatas: Optional[List[Metadata]]
included: Include


class QueryResult(TypedDict):
Expand All @@ -167,6 +168,7 @@ class QueryResult(TypedDict):
data: Optional[List[Loadable]]
metadatas: Optional[List[List[Metadata]]]
distances: Optional[List[List[float]]]
included: Include


class IndexMetadata(TypedDict):
Expand Down
2 changes: 2 additions & 0 deletions chromadb/test/property/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def test_empty_filter(api: ServerAPI) -> None:
assert res["embeddings"] == [[]]
assert res["distances"] == [[]]
assert res["metadatas"] == [[]]
assert set(res["included"]) == set(["embeddings", "distances", "metadatas"])

res = coll.query(
query_embeddings=test_query_embeddings,
Expand All @@ -348,6 +349,7 @@ def test_empty_filter(api: ServerAPI) -> None:
assert res["embeddings"] is None
assert res["distances"] == [[], []]
assert res["metadatas"] == [[], []]
assert set(res["included"]) == set(["metadatas", "documents", "distances"])


def test_boolean_metadata(api: ServerAPI) -> None:
Expand Down
39 changes: 35 additions & 4 deletions chromadb/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def test_persist_index_loading(api_fixture, request):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand Down Expand Up @@ -118,6 +120,8 @@ def __call__(self, input):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand Down Expand Up @@ -146,6 +150,8 @@ def __call__(self, input):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand Down Expand Up @@ -260,6 +266,8 @@ def test_get_from_db(api):
for key in records.keys():
if (key in includes) or (key == "ids"):
assert len(records[key]) == 2
elif key == "included":
assert set(records[key]) == set(includes)
else:
assert records[key] is None

Expand Down Expand Up @@ -290,6 +298,8 @@ def test_get_nearest_neighbors(api):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand All @@ -302,6 +312,8 @@ def test_get_nearest_neighbors(api):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand All @@ -314,6 +326,8 @@ def test_get_nearest_neighbors(api):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 2
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand Down Expand Up @@ -437,6 +451,8 @@ def test_increment_index_on(api):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand Down Expand Up @@ -489,6 +505,8 @@ def test_peek(api):
for key in peek.keys():
if key in ["embeddings", "documents", "metadatas"] or key == "ids":
assert len(peek[key]) == 2
elif key == "included":
assert set(peek[key]) == set(["embeddings", "metadatas", "documents"])
else:
assert peek[key] is None

Expand Down Expand Up @@ -994,22 +1012,26 @@ def test_query_include(api):
collection = api.create_collection("test_query_include")
collection.add(**records)

include = ["metadatas", "documents", "distances"]
items = collection.query(
query_embeddings=[0, 0, 0],
include=["metadatas", "documents", "distances"],
include=include,
n_results=1,
)
assert items["embeddings"] is None
assert items["ids"][0][0] == "id1"
assert items["metadatas"][0][0]["int_value"] == 1
assert set(items["included"]) == set(include)

include = ["embeddings", "documents", "distances"]
items = collection.query(
query_embeddings=[0, 0, 0],
include=["embeddings", "documents", "distances"],
include=include,
n_results=1,
)
assert items["metadatas"] is None
assert items["ids"][0][0] == "id1"
assert set(items["included"]) == set(include)

items = collection.query(
query_embeddings=[[0, 0, 0], [1, 2, 1.2]],
Expand All @@ -1029,22 +1051,27 @@ def test_get_include(api):
collection = api.create_collection("test_get_include")
collection.add(**records)

items = collection.get(include=["metadatas", "documents"], where={"int_value": 1})
include = ["metadatas", "documents"]
items = collection.get(include=include, where={"int_value": 1})
assert items["embeddings"] is None
assert items["ids"][0] == "id1"
assert items["metadatas"][0]["int_value"] == 1
assert items["documents"][0] == "this document is first"
assert set(items["included"]) == set(include)

items = collection.get(include=["embeddings", "documents"])
include = ["embeddings", "documents"]
items = collection.get(include=include)
assert items["metadatas"] is None
assert items["ids"][0] == "id1"
assert approx_equal(items["embeddings"][1][0], 1.2)
assert set(items["included"]) == set(include)

items = collection.get(include=[])
assert items["documents"] is None
assert items["metadatas"] is None
assert items["embeddings"] is None
assert items["ids"][0] == "id1"
assert items["included"] == []

with pytest.raises(ValueError, match="include"):
items = collection.get(include=["metadatas", "undefined"])
Expand Down Expand Up @@ -1172,6 +1199,8 @@ def test_persist_index_loading_params(api, request):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand Down Expand Up @@ -1290,6 +1319,8 @@ def test_get_nearest_neighbors_where_n_results_more_than_element(api):
for key in results.keys():
if key in includes or key == "ids":
assert len(results[key][0]) == 2
elif key == "included":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, for future reference: this file is really a grab bag of tests. We try to pull tests out of it on demand, but we can leave this one here for now.

assert set(results[key]) == set(includes)
else:
assert results[key] is None

Expand Down
12 changes: 7 additions & 5 deletions clients/js/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ type WhereOperator = "$gt" | "$gte" | "$lt" | "$lte" | "$ne" | "$eq";

type OperatorExpression = {
[key in WhereOperator | InclusionOperator | LogicalOperator]?:
| LiteralValue
| ListLiteralValue;
| LiteralValue
| ListLiteralValue;
};

type BaseWhere = {
Expand All @@ -50,9 +50,9 @@ type WhereDocumentOperator = "$contains" | "$not_contains" | LogicalOperator;

export type WhereDocument = {
[key in WhereDocumentOperator]?:
| LiteralValue
| LiteralNumber
| WhereDocument[];
| LiteralValue
| LiteralNumber
| WhereDocument[];
};

export type CollectionType = {
Expand All @@ -67,6 +67,7 @@ export type GetResponse = {
documents: (null | Document)[];
metadatas: (null | Metadata)[];
error: null | string;
included: IncludeEnum[]
};

export type QueryResponse = {
Expand All @@ -75,6 +76,7 @@ export type QueryResponse = {
documents: (null | Document)[][];
metadatas: (null | Metadata)[][];
distances: null | number[][];
included: IncludeEnum[]
};

export type AddResponse = {
Expand Down
1 change: 1 addition & 0 deletions clients/js/test/get.collection.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ test("it should get a collection", async () => {
expect(results.ids.length).toBe(1);
expect(["test1"]).toEqual(expect.arrayContaining(results.ids));
expect(["test2"]).not.toEqual(expect.arrayContaining(results.ids));
expect(results.included).toEqual(expect.arrayContaining(["metadatas", "documents"]))

const results2 = await collection.get({ where: { test: "test1" } });
expect(results2).toBeDefined();
Expand Down
4 changes: 3 additions & 1 deletion clients/js/test/query.collection.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { EMBEDDINGS, IDS, METADATAS, DOCUMENTS } from "./data";
import { IEmbeddingFunction } from "../src/embeddings/IEmbeddingFunction";

export class TestEmbeddingFunction implements IEmbeddingFunction {
constructor() {}
constructor() { }

public async generate(texts: string[]): Promise<number[][]> {
let embeddings: number[][] = [];
Expand All @@ -29,6 +29,7 @@ test("it should query a collection", async () => {
expect(results).toBeInstanceOf(Object);
expect(["test1", "test2"]).toEqual(expect.arrayContaining(results.ids[0]));
expect(["test3"]).not.toEqual(expect.arrayContaining(results.ids[0]));
expect(results.included).toEqual(expect.arrayContaining(["metadatas", "documents"]))
});

// test where_document
Expand Down Expand Up @@ -68,6 +69,7 @@ test("it should get embedding with matching documents", async () => {
// expect(results2.embeddings[0][0]).toBeInstanceOf(Array);
expect(results2.embeddings![0].length).toBe(1);
expect(results2.embeddings![0][0]).toEqual([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
expect(results2.included).toEqual(expect.arrayContaining(["embeddings"]))
});

test("it should exclude documents matching - not_contains", async () => {
Expand Down
Loading