Skip to content

Commit

Permalink
[ENH]: CIP-4: In and Not In Metadata Filters (#1081)
Browse files Browse the repository at this point in the history
Cherry-picked from #1029

## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 - Added support for `$in` and `$nin` metadata filters

> Note: See the CIP in `docs/` or the example notebook for more information.

## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for python

## Documentation Changes
TBD

---------

Co-authored-by: Hammad Bashir <[email protected]>
  • Loading branch information
tazarov and HammadB authored Sep 5, 2023
1 parent 750f2ed commit 6dd2d4a
Show file tree
Hide file tree
Showing 8 changed files with 354 additions and 20 deletions.
34 changes: 29 additions & 5 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ def validate_where(where: Where) -> Where:
if (
key != "$and"
and key != "$or"
and key != "$in"
and key != "$nin"
and not isinstance(value, (str, int, float, dict))
):
raise ValueError(
Expand Down Expand Up @@ -238,15 +240,37 @@ def validate_where(where: Where) -> Where:
raise ValueError(
f"Expected operand value to be an int or a float for operator {operator}, got {operand}"
)

if operator not in ["$gt", "$gte", "$lt", "$lte", "$ne", "$eq"]:
if operator in ["$in", "$nin"]:
if not isinstance(operand, list):
raise ValueError(
f"Expected operand value to be an list for operator {operator}, got {operand}"
)
if operator not in [
"$gt",
"$gte",
"$lt",
"$lte",
"$ne",
"$eq",
"$in",
"$nin",
]:
raise ValueError(
f"Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, got {operator}"
f"Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, $in, $nin, "
f"got {operator}"
)

if not isinstance(operand, (str, int, float)):
if not isinstance(operand, (str, int, float, list)):
raise ValueError(
f"Expected where operand value to be a str, int, float, or list of those type, got {operand}"
)
if isinstance(operand, list) and (
len(operand) == 0
or not all(isinstance(x, type(operand[0])) for x in operand)
):
raise ValueError(
f"Expected where operand value to be a str, int, or float, got {operand}"
f"Expected where operand value to be a non-empty list, and all values to obe of the same type "
f"got {operand}"
)
return where

Expand Down
78 changes: 69 additions & 9 deletions chromadb/segment/impl/metadata/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional, Sequence, Any, Tuple, cast, Generator, Union, Dict
from typing import Optional, Sequence, Any, Tuple, cast, Generator, Union, Dict, List
from chromadb.segment import MetadataReader
from chromadb.ingest import Consumer
from chromadb.config import System
from chromadb.types import Segment
from chromadb.types import Segment, InclusionExclusionOperator
from chromadb.db.impl.sqlite import SqliteDB
from overrides import override
from chromadb.db.base import (
Expand Down Expand Up @@ -146,7 +146,6 @@ def get_metadata(

limit = limit or 2**63 - 1
offset = offset or 0

with self._db.tx() as cur:
return list(islice(self._records(cur, q), offset, offset + limit))

Expand Down Expand Up @@ -405,7 +404,6 @@ def _where_map_criterion(
self, q: QueryBuilder, where: Where, embeddings_t: Table, metadata_t: Table
) -> Criterion:
clause: list[Criterion] = []

for k, v in where.items():
if k == "$and":
criteria = [
Expand All @@ -419,8 +417,32 @@ def _where_map_criterion(
for w in cast(Sequence[Where], v)
]
clause.append(reduce(lambda x, y: x | y, criteria))
elif k == "$in":
expr = cast(
Dict[InclusionExclusionOperator, List[LiteralValue]], {k: v}
)
sq = (
self._db.querybuilder()
.from_(metadata_t)
.select(metadata_t.id)
.where(metadata_t.key.isin(ParameterValue(k)))
.where(_where_clause(expr, metadata_t))
)
clause.append(embeddings_t.id.isin(sq))
elif k == "$nin":
expr = cast(
Dict[InclusionExclusionOperator, List[LiteralValue]], {k: v}
)
sq = (
self._db.querybuilder()
.from_(metadata_t)
.select(metadata_t.id)
.where(metadata_t.key.notin(ParameterValue(k)))
.where(_where_clause(expr, metadata_t))
)
clause.append(embeddings_t.id.notin(sq))
else:
expr = cast(Union[LiteralValue, Dict[WhereOperator, LiteralValue]], v)
expr = cast(Union[LiteralValue, Dict[WhereOperator, LiteralValue]], v) # type: ignore
sq = (
self._db.querybuilder()
.from_(metadata_t)
Expand Down Expand Up @@ -492,24 +514,31 @@ def _decode_seq_id(seq_id_bytes: bytes) -> SeqId:


def _where_clause(
expr: Union[LiteralValue, Dict[WhereOperator, LiteralValue]],
expr: Union[
LiteralValue,
Dict[WhereOperator, LiteralValue],
Dict[InclusionExclusionOperator, List[LiteralValue]],
],
table: Table,
) -> Criterion:
"""Given a field name, an expression, and a table, construct a Pypika Criterion"""

# Literal value case
if isinstance(expr, (str, int, float, bool)):
return _where_clause({"$eq": expr}, table)
return _where_clause({cast(WhereOperator, "$eq"): expr}, table)

# Operator dict case
operator, value = next(iter(expr.items()))
return _value_criterion(value, operator, table)


def _value_criterion(value: LiteralValue, op: WhereOperator, table: Table) -> Criterion:
def _value_criterion(
value: Union[LiteralValue, List[LiteralValue]],
op: Union[WhereOperator, InclusionExclusionOperator],
table: Table,
) -> Criterion:
"""Return a criterion to compare a value with the appropriate columns given its type
and the operation type."""

if isinstance(value, str):
cols = [table.string_value]
# isinstance(True, int) evaluates to True, so we need to check for bools separately
Expand All @@ -519,6 +548,37 @@ def _value_criterion(value: LiteralValue, op: WhereOperator, table: Table) -> Cr
cols = [table.int_value]
elif isinstance(value, float) and op in ("$eq", "$ne"):
cols = [table.float_value]
elif isinstance(value, list) and op in ("$in", "$nin"):
_v = value
if len(_v) == 0:
raise ValueError(f"Empty list for {op} operator")
if isinstance(value[0], str):
col_exprs = [
table.string_value.isin(_v)
if op == "$in"
else table.str_value.notin(_v)
]
elif isinstance(value[0], bool):
col_exprs = [
table.bool_value.isin(_v) if op == "$in" else table.bool_value.notin(_v)
]
elif isinstance(value[0], int):
col_exprs = [
table.int_value.isin(_v) if op == "$in" else table.int_value.notin(_v)
]
elif isinstance(value[0], float):
col_exprs = [
table.float_value.isin(_v)
if op == "$in"
else table.float_value.notin(_v)
]
elif isinstance(value, list) and op in ("$in", "$nin"):
col_exprs = [
table.int_value.isin(value),
table.float_value.isin(value)
if op == "$in"
else table.float_value.notin(value),
]
else:
cols = [table.int_value, table.float_value]

Expand Down
31 changes: 30 additions & 1 deletion chromadb/test/property/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dataclasses import dataclass

from chromadb.api.types import Documents, Embeddings, Metadata
from chromadb.types import LiteralValue

# Set the random seed for reproducibility
np.random.seed(0) # unnecessary, hypothesis does this for us
Expand Down Expand Up @@ -448,6 +449,26 @@ def is_valid(self, rule) -> bool: # type: ignore
return True


def opposite_value(value: LiteralValue) -> SearchStrategy[Any]:
    """
    Build a strategy producing values of the same kind as ``value`` that are
    never equal to ``value`` itself - used when testing ``$nin``.

    The base strategy is selected from the runtime type of ``value`` and a
    single inequality filter is applied on top of it.
    """
    if isinstance(value, float):
        candidates = st.floats(allow_nan=False, allow_infinity=False)
    elif isinstance(value, str):
        candidates = safe_text
    # bool must be tested before int because isinstance(True, int) is True
    elif isinstance(value, bool):
        candidates = st.booleans()
    elif isinstance(value, int):
        candidates = st.integers(min_value=-(2**31), max_value=2**31 - 1)
    else:
        candidates = st.from_type(type(value))
    return candidates.filter(lambda candidate: candidate != value)


@st.composite
def where_clause(draw: st.DrawFn, collection: Collection) -> types.Where:
"""Generate a filter that could be used in a query against the given collection"""
Expand All @@ -457,7 +478,7 @@ def where_clause(draw: st.DrawFn, collection: Collection) -> types.Where:
key = draw(st.sampled_from(known_keys))
value = collection.known_metadata_keys[key]

legal_ops: List[Optional[str]] = [None, "$eq", "$ne"]
legal_ops: List[Optional[str]] = [None, "$eq", "$ne", "$in", "$nin"]
if not isinstance(value, str) and not isinstance(value, bool):
legal_ops.extend(["$gt", "$lt", "$lte", "$gte"])
if isinstance(value, float):
Expand All @@ -468,6 +489,14 @@ def where_clause(draw: st.DrawFn, collection: Collection) -> types.Where:

if op is None:
return {key: value}
elif op == "$in":
if isinstance(value, str) and not value:
return {}
return {key: {op: [value, *[draw(opposite_value(value)) for _ in range(3)]]}}
elif op == "$nin":
if isinstance(value, str) and not value:
return {}
return {key: {op: [draw(opposite_value(value)) for _ in range(3)]}}
else:
return {key: {op: value}}

Expand Down
13 changes: 10 additions & 3 deletions chromadb/test/property/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,16 @@ def _filter_where_clause(clause: Where, metadata: Metadata) -> bool:
if key == "$or":
assert isinstance(expr, list)
return any(_filter_where_clause(clause, metadata) for clause in expr)
if key == "$in":
assert isinstance(expr, list)
return metadata[key] in expr if key in metadata else False
if key == "$nin":
assert isinstance(expr, list)
return metadata[key] not in expr

# expr is an operator expression
assert isinstance(expr, dict)
op, val = list(expr.items())[0]

assert isinstance(metadata, dict)
if key not in metadata:
return False
Expand All @@ -55,6 +60,10 @@ def _filter_where_clause(clause: Where, metadata: Metadata) -> bool:
return key in metadata and metadata_key == val
elif op == "$ne":
return key in metadata and metadata_key != val
elif op == "$in":
return key in metadata and metadata_key in val
elif op == "$nin":
return key in metadata and metadata_key not in val

# The following conditions only make sense for numeric values
assert isinstance(metadata_key, int) or isinstance(metadata_key, float)
Expand Down Expand Up @@ -132,7 +141,6 @@ def _filter_embedding_set(
)
if not _filter_where_doc_clause(filter["where_document"], documents[i]):
ids.discard(normalized_record_set["ids"][i])

return list(ids)


Expand Down Expand Up @@ -174,7 +182,6 @@ def test_filterable_metadata_get(
return

coll.add(**record_set)

for filter in filters:
result_ids = coll.get(**filter)["ids"]
expected_ids = _filter_embedding_set(record_set, filter)
Expand Down
6 changes: 5 additions & 1 deletion chromadb/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,11 @@ class VectorQueryResult(TypedDict):
Literal["$ne"],
Literal["$eq"],
]
OperatorExpression = Dict[Union[WhereOperator, LogicalOperator], LiteralValue]
InclusionExclusionOperator = Union[Literal["$in"], Literal["$nin"]]
OperatorExpression = Union[
Dict[Union[WhereOperator, LogicalOperator], LiteralValue],
Dict[InclusionExclusionOperator, List[LiteralValue]],
]

Where = Dict[
Union[str, LogicalOperator], Union[LiteralValue, OperatorExpression, List["Where"]]
Expand Down
2 changes: 1 addition & 1 deletion clients/js/test/client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -191,5 +191,5 @@ test('wrong code returns an error', async () => {
// @ts-ignore - supposed to fail
const results = await collection.get({ where: { "test": { "$contains": "hello" } } });
expect(results.error).toBeDefined()
expect(results.error).toBe("ValueError('Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, got $contains')")
expect(results.error).toContain("ValueError('Expected where operator")
})
61 changes: 61 additions & 0 deletions docs/CIP_4_In_Nin_Metadata_Filters.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# CIP-4: In and Not In Metadata Filters Proposal

## Status

Current Status: `Under Discussion`

## **Motivation**

Currently, Chroma does not provide a way to filter metadata using `in` and `not in` semantics, i.e. matching a
metadata value against a list of candidate values. This is a frequent request from community members.

## **Public Interfaces**

The changes will affect the following public interfaces:

- `Where` and `OperatorExpression`
classes - https://github.com/chroma-core/chroma/blob/48700dd07f14bcfd8b206dc3b2e2795d5531094d/chromadb/types.py#L125-L129
- `collection.get()`
- `collection.query()`

## **Proposed Changes**

We propose introducing two new operators, `$in` and `$nin`, for filtering metadata. We refer to these operators
collectively as `InclusionExclusionOperator`.

We suggest the following new operator definition:

```python
InclusionExclusionOperator = Union[Literal["$in"], Literal["$nin"]]
```

Additionally, we suggest that those operators are added to `OperatorExpression` for seamless integration with
existing `Where` semantics:

```python
OperatorExpression = Union[
Dict[Union[WhereOperator, LogicalOperator], LiteralValue],
Dict[InclusionExclusionOperator, List[LiteralValue]],
]
```

An example of a query using the new operators would be:

```python
collection.query(query_texts=query,
where={"$and": [{"author": {'$in': ['john', 'jill']}}, {"article_type": {"$eq": "blog"}}]},
n_results=3)
```

## **Compatibility, Deprecation, and Migration Plan**

The change is backward-compatible with the existing 0.4.x releases.

## **Test Plan**

Property tests will be updated to ensure boundary conditions are covered as well as interoperability with existing `Where`
operators.

## **Rejected Alternatives**

N/A
Loading

0 comments on commit 6dd2d4a

Please sign in to comment.