From 6dd2d4af0bd34bdcce70c55f6ec52fee59492f3a Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Tue, 5 Sep 2023 20:42:01 +0300 Subject: [PATCH] [ENH]: CIP-4: In and Not In Metadata Filters (#1081) Cherry-picked from #1029 ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Added support for `$in` and `$nin` metadata filters > Note: See CIP in `docs/` or example notebook for more info ## Test plan *How are these changes tested?* - [x] Tests pass locally with `pytest` for python ## Documentation Changes TBD --------- Co-authored-by: Hammad Bashir --- chromadb/api/types.py | 34 +++- chromadb/segment/impl/metadata/sqlite.py | 78 +++++++-- chromadb/test/property/strategies.py | 31 +++- chromadb/test/property/test_filtering.py | 13 +- chromadb/types.py | 6 +- clients/js/test/client.test.ts | 2 +- docs/CIP_4_In_Nin_Metadata_Filters.md | 61 +++++++ .../in_not_in_filtering.ipynb | 149 ++++++++++++++++++ 8 files changed, 354 insertions(+), 20 deletions(-) create mode 100644 docs/CIP_4_In_Nin_Metadata_Filters.md create mode 100644 examples/basic_functionality/in_not_in_filtering.ipynb diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 4b8e8863863..7979dba624e 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -207,6 +207,8 @@ def validate_where(where: Where) -> Where: if ( key != "$and" and key != "$or" + and key != "$in" + and key != "$nin" and not isinstance(value, (str, int, float, dict)) ): raise ValueError( @@ -238,15 +240,37 @@ def validate_where(where: Where) -> Where: raise ValueError( f"Expected operand value to be an int or a float for operator {operator}, got {operand}" ) - - if operator not in ["$gt", "$gte", "$lt", "$lte", "$ne", "$eq"]: + if operator in ["$in", "$nin"]: + if not isinstance(operand, list): + raise ValueError( + f"Expected operand value to be an list for operator {operator}, got {operand}" + ) + if operator not in [ + "$gt", + "$gte", + "$lt", + "$lte", + "$ne", + "$eq", + "$in", + "$nin", + ]: raise ValueError( - f"Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, got {operator}" + f"Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, $in, $nin, " + f"got {operator}" ) - if not isinstance(operand, (str, int, float)): + if not isinstance(operand, (str, int, float, list)): + raise ValueError( + f"Expected where operand value to be a str, int, float, or list of those type, got {operand}" + ) + if isinstance(operand, list) and ( + len(operand) == 0 + or not all(isinstance(x, type(operand[0])) for x in operand) + ): raise ValueError( - f"Expected where operand value to be a str, int, or float, got {operand}" + f"Expected where operand value to be a non-empty list, and all values to obe of the same type " + f"got {operand}" ) return where diff --git a/chromadb/segment/impl/metadata/sqlite.py b/chromadb/segment/impl/metadata/sqlite.py index 7f33866ffed..781aed00ba4 100644 --- a/chromadb/segment/impl/metadata/sqlite.py +++ b/chromadb/segment/impl/metadata/sqlite.py @@ -1,8 +1,8 @@ -from typing import Optional, Sequence, Any, Tuple, cast, Generator, Union, Dict +from typing import Optional, Sequence, Any, Tuple, cast, Generator, Union, Dict, List from chromadb.segment import MetadataReader from chromadb.ingest import Consumer from chromadb.config import System -from chromadb.types import Segment +from chromadb.types import Segment, InclusionExclusionOperator from chromadb.db.impl.sqlite import SqliteDB from overrides import override from chromadb.db.base import ( @@ -146,7 +146,6 @@ def get_metadata( limit = limit or 2**63 - 1 offset = offset or 0 - with self._db.tx() as cur: return list(islice(self._records(cur, q), offset, offset + limit)) @@ -405,7 +404,6 @@ def _where_map_criterion( self, q: QueryBuilder, where: Where, embeddings_t: Table, metadata_t: Table ) -> Criterion: clause: list[Criterion] = [] - for k, v in where.items(): if k == "$and": criteria = [ @@ -419,8 +417,32 @@ def _where_map_criterion( for w in cast(Sequence[Where], v) ] clause.append(reduce(lambda x, y: x | y, criteria)) + elif k == "$in": + expr = cast( + Dict[InclusionExclusionOperator, List[LiteralValue]], {k: v} + ) + sq = ( + self._db.querybuilder() + .from_(metadata_t) + .select(metadata_t.id) + .where(metadata_t.key.isin(ParameterValue(k))) + .where(_where_clause(expr, metadata_t)) + ) + clause.append(embeddings_t.id.isin(sq)) + elif k == "$nin": + expr = cast( + Dict[InclusionExclusionOperator, List[LiteralValue]], {k: v} + ) + sq = ( + self._db.querybuilder() + .from_(metadata_t) + .select(metadata_t.id) + .where(metadata_t.key.notin(ParameterValue(k))) + .where(_where_clause(expr, metadata_t)) + ) + clause.append(embeddings_t.id.notin(sq)) else: - expr = cast(Union[LiteralValue, Dict[WhereOperator, LiteralValue]], v) + expr = cast(Union[LiteralValue, Dict[WhereOperator, LiteralValue]], v) # type: ignore sq = ( self._db.querybuilder() .from_(metadata_t) @@ -492,24 +514,31 @@ def _decode_seq_id(seq_id_bytes: bytes) -> SeqId: def _where_clause( - expr: Union[LiteralValue, Dict[WhereOperator, LiteralValue]], + expr: Union[ + LiteralValue, + Dict[WhereOperator, LiteralValue], + Dict[InclusionExclusionOperator, List[LiteralValue]], + ], table: Table, ) -> Criterion: """Given a field name, an expression, and a table, construct a Pypika Criterion""" # Literal value case if isinstance(expr, (str, int, float, bool)): - return _where_clause({"$eq": expr}, table) + return _where_clause({cast(WhereOperator, "$eq"): expr}, table) # Operator dict case operator, value = next(iter(expr.items())) return _value_criterion(value, operator, table) -def _value_criterion(value: LiteralValue, op: WhereOperator, table: Table) -> Criterion: +def _value_criterion( + value: Union[LiteralValue, List[LiteralValue]], + op: Union[WhereOperator, InclusionExclusionOperator], + table: Table, +) -> Criterion: """Return a criterion to compare a value with the appropriate columns given its type and the operation type.""" - if isinstance(value, str): cols = [table.string_value] # isinstance(True, int) evaluates to True, so we need to check for bools separately @@ -519,6 +548,37 @@ def _value_criterion(value: LiteralValue, op: WhereOperator, table: Table) -> Cr cols = [table.int_value] elif isinstance(value, float) and op in ("$eq", "$ne"): cols = [table.float_value] + elif isinstance(value, list) and op in ("$in", "$nin"): + _v = value + if len(_v) == 0: + raise ValueError(f"Empty list for {op} operator") + if isinstance(value[0], str): + col_exprs = [ + table.string_value.isin(_v) + if op == "$in" + else table.str_value.notin(_v) + ] + elif isinstance(value[0], bool): + col_exprs = [ + table.bool_value.isin(_v) if op == "$in" else table.bool_value.notin(_v) + ] + elif isinstance(value[0], int): + col_exprs = [ + table.int_value.isin(_v) if op == "$in" else table.int_value.notin(_v) + ] + elif isinstance(value[0], float): + col_exprs = [ + table.float_value.isin(_v) + if op == "$in" + else table.float_value.notin(_v) + ] + elif isinstance(value, list) and op in ("$in", "$nin"): + col_exprs = [ + table.int_value.isin(value), + table.float_value.isin(value) + if op == "$in" + else table.float_value.notin(value), + ] else: cols = [table.int_value, table.float_value] diff --git a/chromadb/test/property/strategies.py b/chromadb/test/property/strategies.py index 6f855d99e96..e8540ef37aa 100644 --- a/chromadb/test/property/strategies.py +++ b/chromadb/test/property/strategies.py @@ -14,6 +14,7 @@ from dataclasses import dataclass from chromadb.api.types import Documents, Embeddings, Metadata +from chromadb.types import LiteralValue # Set the random seed for reproducibility np.random.seed(0) # unnecessary, hypothesis does this for us @@ -448,6 +449,26 @@ def is_valid(self, rule) -> bool: # type: ignore return True +def opposite_value(value: LiteralValue) -> SearchStrategy[Any]: + """ + Returns a strategy that will generate all valid values except the input value - testing of $nin + """ + if isinstance(value, float): + return st.floats(allow_nan=False, allow_infinity=False).filter( + lambda x: x != value + ) + elif isinstance(value, str): + return safe_text.filter(lambda x: x != value) + elif isinstance(value, bool): + return st.booleans().filter(lambda x: x != value) + elif isinstance(value, int): + return st.integers(min_value=-(2**31), max_value=2**31 - 1).filter( + lambda x: x != value + ) + else: + return st.from_type(type(value)).filter(lambda x: x != value) + + @st.composite def where_clause(draw: st.DrawFn, collection: Collection) -> types.Where: """Generate a filter that could be used in a query against the given collection""" @@ -457,7 +478,7 @@ def where_clause(draw: st.DrawFn, collection: Collection) -> types.Where: key = draw(st.sampled_from(known_keys)) value = collection.known_metadata_keys[key] - legal_ops: List[Optional[str]] = [None, "$eq", "$ne"] + legal_ops: List[Optional[str]] = [None, "$eq", "$ne", "$in", "$nin"] if not isinstance(value, str) and not isinstance(value, bool): legal_ops.extend(["$gt", "$lt", "$lte", "$gte"]) if isinstance(value, float): @@ -468,6 +489,14 @@ def where_clause(draw: st.DrawFn, collection: Collection) -> types.Where: if op is None: return {key: value} + elif op == "$in": + if isinstance(value, str) and not value: + return {} + return {key: {op: [value, *[draw(opposite_value(value)) for _ in range(3)]]}} + elif op == "$nin": + if isinstance(value, str) and not value: + return {} + return {key: {op: [draw(opposite_value(value)) for _ in range(3)]}} else: return {key: {op: value}} diff --git a/chromadb/test/property/test_filtering.py b/chromadb/test/property/test_filtering.py index d9ca874bf45..ddcdefb0ed3 100644 --- a/chromadb/test/property/test_filtering.py +++ b/chromadb/test/property/test_filtering.py @@ -42,11 +42,16 @@ def _filter_where_clause(clause: Where, metadata: Metadata) -> bool: if key == "$or": assert isinstance(expr, list) return any(_filter_where_clause(clause, metadata) for clause in expr) + if key == "$in": + assert isinstance(expr, list) + return metadata[key] in expr if key in metadata else False + if key == "$nin": + assert isinstance(expr, list) + return metadata[key] not in expr # expr is an operator expression assert isinstance(expr, dict) op, val = list(expr.items())[0] - assert isinstance(metadata, dict) if key not in metadata: return False @@ -55,6 +60,10 @@ def _filter_where_clause(clause: Where, metadata: Metadata) -> bool: return key in metadata and metadata_key == val elif op == "$ne": return key in metadata and metadata_key != val + elif op == "$in": + return key in metadata and metadata_key in val + elif op == "$nin": + return key in metadata and metadata_key not in val # The following conditions only make sense for numeric values assert isinstance(metadata_key, int) or isinstance(metadata_key, float) @@ -132,7 +141,6 @@ def _filter_embedding_set( ) if not _filter_where_doc_clause(filter["where_document"], documents[i]): ids.discard(normalized_record_set["ids"][i]) - return list(ids) @@ -174,7 +182,6 @@ def test_filterable_metadata_get( return coll.add(**record_set) - for filter in filters: result_ids = coll.get(**filter)["ids"] expected_ids = _filter_embedding_set(record_set, filter) diff --git a/chromadb/types.py b/chromadb/types.py index fd5c3709045..713cab7757c 100644 --- a/chromadb/types.py +++ b/chromadb/types.py @@ -122,7 +122,11 @@ class VectorQueryResult(TypedDict): Literal["$ne"], Literal["$eq"], ] -OperatorExpression = Dict[Union[WhereOperator, LogicalOperator], LiteralValue] +InclusionExclusionOperator = Union[Literal["$in"], Literal["$nin"]] +OperatorExpression = Union[ + Dict[Union[WhereOperator, LogicalOperator], LiteralValue], + Dict[InclusionExclusionOperator, List[LiteralValue]], +] Where = Dict[ Union[str, LogicalOperator], Union[LiteralValue, OperatorExpression, List["Where"]] diff --git a/clients/js/test/client.test.ts b/clients/js/test/client.test.ts index 5fbf2b09b8e..512237a2457 100644 --- a/clients/js/test/client.test.ts +++ b/clients/js/test/client.test.ts @@ -191,5 +191,5 @@ test('wrong code returns an error', async () => { // @ts-ignore - supposed to fail const results = await collection.get({ where: { "test": { "$contains": "hello" } } }); expect(results.error).toBeDefined() - expect(results.error).toBe("ValueError('Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, got $contains')") + expect(results.error).toContain("ValueError('Expected where operator") }) diff --git a/docs/CIP_4_In_Nin_Metadata_Filters.md b/docs/CIP_4_In_Nin_Metadata_Filters.md new file mode 100644 index 00000000000..e9a0911e69e --- /dev/null +++ b/docs/CIP_4_In_Nin_Metadata_Filters.md @@ -0,0 +1,61 @@ +# CIP-4: In and Not In Metadata Filters Proposal + +## Status + +Current Status: `Under Discussion` + +## **Motivation** + +Currently, Chroma does not provide a way to filter metadata through `in` and `not in`. This appears to be a frequent ask +from community members. + +## **Public Interfaces** + +The changes will affect the following public interfaces: + +- `Where` and `OperatorExpression` + classes - https://github.com/chroma-core/chroma/blob/48700dd07f14bcfd8b206dc3b2e2795d5531094d/chromadb/types.py#L125-L129 +- `collection.get()` +- `collection.query()` + +## **Proposed Changes** + +We suggest the introduction of two new operators `$in` and `$nin` that will be used to filter metadata. We call these +operators `InclusionExclusionOperator`. + +We suggest the following new operator definition: + +```python +InclusionExclusionOperator = Union[Literal["$in"], Literal["$nin"]] +``` + +Additionally, we suggest that those operators are added to `OperatorExpression` for seamless integration with +existing `Where` semantics: + +```python +OperatorExpression = Union[ + Dict[Union[WhereOperator, LogicalOperator], LiteralValue], + Dict[InclusionExclusionOperator, List[LiteralValue]], +] +``` + +An example of a query using the new operators would be: + +```python +collection.query(query_texts=query, + where={"$and": [{"author": {'$in': ['john', 'jill']}}, {"article_type": {"$eq": "blog"}}]}, + n_results=3) +``` + +## **Compatibility, Deprecation, and Migration Plan** + +The change is compatible with existing release 0.4.x. + +## **Test Plan** + +Property tests will be updated to ensure boundary conditions are covered as well as interoperability with existing `Where` +operators. + +## **Rejected Alternatives** + +N/A diff --git a/examples/basic_functionality/in_not_in_filtering.ipynb b/examples/basic_functionality/in_not_in_filtering.ipynb new file mode 100644 index 00000000000..3076d4a3585 --- /dev/null +++ b/examples/basic_functionality/in_not_in_filtering.ipynb @@ -0,0 +1,149 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2023-08-30T12:48:38.227653Z", + "start_time": "2023-08-30T12:48:27.744069Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Number of requested results 10 is greater than number of elements in index 3, updating n_results = 3\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'ids': [['1', '3']], 'distances': [[0.28824201226234436, 1.017508625984192]], 'metadatas': [[{'author': 'john'}, {'author': 'jill'}]], 'embeddings': None, 'documents': [['Article by john', 'Article by Jill']]}\n", + "{'ids': ['1', '3'], 'embeddings': None, 'metadatas': [{'author': 'john'}, {'author': 'jill'}], 'documents': ['Article by john', 'Article by Jill']}\n" + ] + } + ], + "source": [ + "import chromadb\n", + "\n", + "from chromadb.utils import embedding_functions\n", + "\n", + "sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=\"all-MiniLM-L6-v2\")\n", + "\n", + "\n", + "client = chromadb.Client()\n", + "# client.heartbeat()\n", + "# client.reset()\n", + "collection = client.get_or_create_collection(\"test-where-list\", embedding_function=sentence_transformer_ef)\n", + "collection.add(documents=[\"Article by john\", \"Article by Jack\", \"Article by Jill\"],\n", + " metadatas=[{\"author\": \"john\"}, {\"author\": \"jack\"}, {\"author\": \"jill\"}], ids=[\"1\", \"2\", \"3\"])\n", + "\n", + "query = [\"Give me articles by john\"]\n", + "res = collection.query(query_texts=query,where={'author': {'$in': ['john', 'jill']}}, n_results=10)\n", + "print(res)\n", + "\n", + "res_get = collection.get(where={'author': {'$in': ['john', 'jill']}})\n", + "print(res_get)\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Interactions with existing Where operators" + ], + "metadata": { + "collapsed": false + }, + "id": "752cef843ba2f900" + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [ + { + "data": { + "text/plain": "{'ids': [['1']],\n 'distances': [[0.28824201226234436]],\n 'metadatas': [[{'article_type': 'blog', 'author': 'john'}]],\n 'embeddings': None,\n 'documents': [['Article by john']]}" + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "collection.upsert(documents=[\"Article by john\", \"Article by Jack\", \"Article by Jill\"],\n", + " metadatas=[{\"author\": \"john\",\"article_type\":\"blog\"}, {\"author\": \"jack\",\"article_type\":\"social\"}, {\"author\": \"jill\",\"article_type\":\"paper\"}], ids=[\"1\", \"2\", \"3\"])\n", + "\n", + "collection.query(query_texts=query,where={\"$and\":[{\"author\": {'$in': ['john', 'jill']}},{\"article_type\":{\"$eq\":\"blog\"}}]}, n_results=3)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-08-30T12:48:49.974353Z", + "start_time": "2023-08-30T12:48:49.938985Z" + } + }, + "id": "ca56cda318f9e94d" + }, + { + "cell_type": "code", + "execution_count": 3, + "outputs": [ + { + "data": { + "text/plain": "{'ids': [['1', '3']],\n 'distances': [[0.28824201226234436, 1.017508625984192]],\n 'metadatas': [[{'article_type': 'blog', 'author': 'john'},\n {'article_type': 'paper', 'author': 'jill'}]],\n 'embeddings': None,\n 'documents': [['Article by john', 'Article by Jill']]}" + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "collection.query(query_texts=query,where={\"$or\":[{\"author\": {'$in': ['john']}},{\"article_type\":{\"$in\":[\"paper\"]}}]}, n_results=3)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-08-30T12:48:53.501431Z", + "start_time": "2023-08-30T12:48:53.481571Z" + } + }, + "id": "f10e79ec90c797c1" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + }, + "id": "d97b8b6dd96261d0" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}