diff --git a/docs/_src/api/api/document_store.md b/docs/_src/api/api/document_store.md index a36aaa6fff..5de47ae202 100644 --- a/docs/_src/api/api/document_store.md +++ b/docs/_src/api/api/document_store.md @@ -338,7 +338,7 @@ to performance issues. Note that Elasticsearch limits the number of results to 1 #### get\_metadata\_values\_by\_key ```python - | get_metadata_values_by_key(key: str, query: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None) -> List[dict] + | get_metadata_values_by_key(key: str, query: Optional[str] = None, filters: Optional[Dict[str, Any]] = None, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None) -> List[dict] ``` Get values associated with a metadata key. The output is in the format: @@ -348,7 +348,31 @@ Get values associated with a metadata key. The output is in the format: - `key`: the meta key name to get the values for. - `query`: narrow down the scope to documents matching the query string. -- `filters`: narrow down the scope to documents that match the given filters. +- `filters`: Narrow down the scope to documents that match the given filters. + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + Example: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + ``` - `index`: Elasticsearch index where the meta values should be searched. If not supplied, self.index will be used. - `headers`: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) @@ -428,7 +452,7 @@ Update the metadata dictionary of a document by specifying its string id #### get\_document\_count ```python - | get_document_count(filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None, only_documents_without_embedding: bool = False, headers: Optional[Dict[str, str]] = None) -> int + | get_document_count(filters: Optional[Dict[str, Any]] = None, index: Optional[str] = None, only_documents_without_embedding: bool = False, headers: Optional[Dict[str, str]] = None) -> int ``` Return the number of documents in the document store. @@ -446,7 +470,7 @@ Return the number of labels in the document store #### get\_embedding\_count ```python - | get_embedding_count(index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None, headers: Optional[Dict[str, str]] = None) -> int + | get_embedding_count(index: Optional[str] = None, filters: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None) -> int ``` Return the count of embeddings in the document store. @@ -455,7 +479,7 @@ Return the count of embeddings in the document store. #### get\_all\_documents ```python - | get_all_documents(index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None, return_embedding: Optional[bool] = None, batch_size: int = 10_000, headers: Optional[Dict[str, str]] = None) -> List[Document] + | get_all_documents(index: Optional[str] = None, filters: Optional[Dict[str, Any]] = None, return_embedding: Optional[bool] = None, batch_size: int = 10_000, headers: Optional[Dict[str, str]] = None) -> List[Document] ``` Get documents from the document store. @@ -465,7 +489,30 @@ Get documents from the document store. - `index`: Name of the index to get the documents from. If None, the DocumentStore's default index (self.index) will be used. - `filters`: Optional filters to narrow down the documents to return. - Example: {"name": ["some", "more"], "category": ["only_one"]} + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + Example: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + ``` - `return_embedding`: Whether to return the document embeddings. - `batch_size`: When working with large number of documents, batching can help reduce memory footprint. - `headers`: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) @@ -475,7 +522,7 @@ Get documents from the document store. #### get\_all\_documents\_generator ```python - | get_all_documents_generator(index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None, return_embedding: Optional[bool] = None, batch_size: int = 10_000, headers: Optional[Dict[str, str]] = None) -> Generator[Document, None, None] + | get_all_documents_generator(index: Optional[str] = None, filters: Optional[Dict[str, Any]] = None, return_embedding: Optional[bool] = None, batch_size: int = 10_000, headers: Optional[Dict[str, str]] = None) -> Generator[Document, None, None] ``` Get documents from the document store. Under-the-hood, documents are fetched in batches from the @@ -487,7 +534,30 @@ a large number of documents without having to load all documents in memory. - `index`: Name of the index to get the documents from. If None, the DocumentStore's default index (self.index) will be used. - `filters`: Optional filters to narrow down the documents to return. - Example: {"name": ["some", "more"], "category": ["only_one"]} + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + Example: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + ``` - `return_embedding`: Whether to return the document embeddings. - `batch_size`: When working with large number of documents, batching can help reduce memory footprint. - `headers`: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) @@ -497,7 +567,7 @@ a large number of documents without having to load all documents in memory. #### get\_all\_labels ```python - | get_all_labels(index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None, headers: Optional[Dict[str, str]] = None, batch_size: int = 10_000) -> List[Label] + | get_all_labels(index: Optional[str] = None, filters: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, batch_size: int = 10_000) -> List[Label] ``` Return all labels in the document store @@ -506,7 +576,7 @@ Return all labels in the document store #### query ```python - | query(query: Optional[str], filters: Optional[Dict[str, List[str]]] = None, top_k: int = 10, custom_query: Optional[str] = None, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None) -> List[Document] + | query(query: Optional[str], filters: Optional[Dict[str, Any]] = None, top_k: int = 10, custom_query: Optional[str] = None, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None) -> List[Document] ``` Scan through documents in DocumentStore and return a small number documents @@ -515,7 +585,69 @@ that are most relevant to the query as defined by the BM25 algorithm. **Arguments**: - `query`: The query -- `filters`: A dictionary where the keys specify a metadata field and the value is a list of accepted values for that field +- `filters`: Optional filters to narrow down the search space to documents whose metadata fulfill certain + conditions. + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + Example: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + # or simpler using default operators + filters = { + "type": "article", + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": ["economy", "politics"], + "publisher": "nytimes" + } + } + ``` + + To use the same logical operator multiple times on the same level, logical operators take + optionally a list of dictionaries as value. + + Example: + ```python + filters = { + "$or": [ + { + "$and": { + "Type": "News Paper", + "Date": { + "$lt": "2019-01-01" + } + } + }, + { + "$and": { + "Type": "Blog Post", + "Date": { + "$gte": "2019-01-01" + } + } + } + ] + } + ``` - `top_k`: How many documents to return per query. - `custom_query`: query string as per Elasticsearch DSL with a mandatory query placeholder(query). @@ -591,7 +723,7 @@ that are most relevant to the query as defined by the BM25 algorithm. #### query\_by\_embedding ```python - | query_by_embedding(query_emb: np.ndarray, filters: Optional[Dict[str, List[str]]] = None, top_k: int = 10, index: Optional[str] = None, return_embedding: Optional[bool] = None, headers: Optional[Dict[str, str]] = None) -> List[Document] + | query_by_embedding(query_emb: np.ndarray, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, index: Optional[str] = None, return_embedding: Optional[bool] = None, headers: Optional[Dict[str, str]] = None) -> List[Document] ``` Find the document that is most similar to the provided `query_emb` by using a vector similarity metric. @@ -599,8 +731,69 @@ Find the document that is most similar to the provided `query_emb` by using a ve **Arguments**: - `query_emb`: Embedding of the query (e.g. gathered from DPR) -- `filters`: Optional filters to narrow down the search space. - Example: {"name": ["some", "more"], "category": ["only_one"]} +- `filters`: Optional filters to narrow down the search space to documents whose metadata fulfill certain + conditions. + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + Example: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + # or simpler using default operators + filters = { + "type": "article", + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": ["economy", "politics"], + "publisher": "nytimes" + } + } + ``` + + To use the same logical operator multiple times on the same level, logical operators take + optionally a list of dictionaries as value. + + Example: + ```python + filters = { + "$or": [ + { + "$and": { + "Type": "News Paper", + "Date": { + "$lt": "2019-01-01" + } + } + }, + { + "$and": { + "Type": "Blog Post", + "Date": { + "$gte": "2019-01-01" + } + } + } + ] + } + ``` - `top_k`: How many documents to return - `index`: Index name for storing the docs and metadata - `return_embedding`: To return document embedding @@ -624,7 +817,7 @@ Return a summary of the documents in the document store #### update\_embeddings ```python - | update_embeddings(retriever, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None, update_existing_embeddings: bool = True, batch_size: int = 10_000, headers: Optional[Dict[str, str]] = None) + | update_embeddings(retriever, index: Optional[str] = None, filters: Optional[Dict[str, Any]] = None, update_existing_embeddings: bool = True, batch_size: int = 10_000, headers: Optional[Dict[str, str]] = None) ``` Updates the embeddings in the the document store using the encoding model specified in the retriever. @@ -639,7 +832,30 @@ This can be useful if want to add or change the embeddings for your documents (e incremental updating of embeddings, wherein, only newly indexed documents get processed. - `filters`: Optional filters to narrow down the documents for which embeddings are to be updated. - Example: {"name": ["some", "more"], "category": ["only_one"]} + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + Example: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + ``` - `batch_size`: When working with large number of documents, batching can help reduce memory footprint. - `headers`: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information. @@ -652,7 +868,7 @@ None #### delete\_all\_documents ```python - | delete_all_documents(index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None, headers: Optional[Dict[str, str]] = None) + | delete_all_documents(index: Optional[str] = None, filters: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None) ``` Delete documents in an index. All documents are deleted if no filters are passed. @@ -661,6 +877,30 @@ Delete documents in an index. All documents are deleted if no filters are passed - `index`: Index name to delete the document from. - `filters`: Optional filters to narrow down the documents to be deleted. + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + Example: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + ``` - `headers`: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information. @@ -672,7 +912,7 @@ None #### delete\_documents ```python - | delete_documents(index: Optional[str] = None, ids: Optional[List[str]] = None, filters: Optional[Dict[str, List[str]]] = None, headers: Optional[Dict[str, str]] = None) + | delete_documents(index: Optional[str] = None, ids: Optional[List[str]] = None, filters: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None) ``` Delete documents in an index. All documents are deleted if no filters are passed. @@ -683,10 +923,34 @@ Delete documents in an index. All documents are deleted if no filters are passed DocumentStore's default index (self.index) will be used - `ids`: Optional list of IDs to narrow down the documents to be deleted. - `filters`: Optional filters to narrow down the documents to be deleted. - Example filters: {"name": ["some", "more"], "category": ["only_one"]}. - If filters are provided along with a list of IDs, this method deletes the - intersection of the two query results (documents that match the filters and - have their ID in the list). + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + Example: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + ``` + + If filters are provided along with a list of IDs, this method deletes the + intersection of the two query results (documents that match the filters and + have their ID in the list). - `headers`: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information. @@ -698,7 +962,7 @@ None #### delete\_labels ```python - | delete_labels(index: Optional[str] = None, ids: Optional[List[str]] = None, filters: Optional[Dict[str, List[str]]] = None, headers: Optional[Dict[str, str]] = None) + | delete_labels(index: Optional[str] = None, ids: Optional[List[str]] = None, filters: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None) ``` Delete labels in an index. All labels are deleted if no filters are passed. @@ -709,7 +973,30 @@ Delete labels in an index. All labels are deleted if no filters are passed. DocumentStore's default label index (self.label_index) will be used - `ids`: Optional list of IDs to narrow down the labels to be deleted. - `filters`: Optional filters to narrow down the labels to be deleted. - Example filters: {"id": ["9a196e41-f7b5-45b4-bd19-5feb7501c159", "9a196e41-f7b5-45b4-bd19-5feb7501c159"]} or {"query": ["question2"]} + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + Example: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + ``` - `headers`: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information. @@ -750,7 +1037,7 @@ the KNN plugin that can scale to a large number of documents. #### query\_by\_embedding ```python - | query_by_embedding(query_emb: np.ndarray, filters: Optional[Dict[str, List[str]]] = None, top_k: int = 10, index: Optional[str] = None, return_embedding: Optional[bool] = None, headers: Optional[Dict[str, str]] = None) -> List[Document] + | query_by_embedding(query_emb: np.ndarray, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, index: Optional[str] = None, return_embedding: Optional[bool] = None, headers: Optional[Dict[str, str]] = None) -> List[Document] ``` Find the document that is most similar to the provided `query_emb` by using a vector similarity metric. @@ -758,8 +1045,69 @@ Find the document that is most similar to the provided `query_emb` by using a ve **Arguments**: - `query_emb`: Embedding of the query (e.g. gathered from DPR) -- `filters`: Optional filters to narrow down the search space. - Example: {"name": ["some", "more"], "category": ["only_one"]} +- `filters`: Optional filters to narrow down the search space to documents whose metadata fulfill certain + conditions. + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + Example: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + # or simpler using default operators + filters = { + "type": "article", + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": ["economy", "politics"], + "publisher": "nytimes" + } + } + ``` + + To use the same logical operator multiple times on the same level, logical operators take + optionally a list of dictionaries as value. + + Example: + ```python + filters = { + "$or": [ + { + "$and": { + "Type": "News Paper", + "Date": { + "$lt": "2019-01-01" + } + } + }, + { + "$and": { + "Type": "Blog Post", + "Date": { + "$gte": "2019-01-01" + } + } + } + ] + } + ``` - `top_k`: How many documents to return - `index`: Index name for storing the docs and metadata - `return_embedding`: To return document embedding diff --git a/haystack/document_stores/elasticsearch.py b/haystack/document_stores/elasticsearch.py index 4e7408f7ee..7240b63215 100644 --- a/haystack/document_stores/elasticsearch.py +++ b/haystack/document_stores/elasticsearch.py @@ -6,10 +6,11 @@ import time from copy import deepcopy from string import Template +from collections import defaultdict + import numpy as np from scipy.special import expit from tqdm.auto import tqdm -import pandas as pd try: from elasticsearch import Elasticsearch, RequestsHttpConnection @@ -23,6 +24,7 @@ from haystack.document_stores import KeywordDocumentStore from haystack.schema import Document, Label from haystack.document_stores.base import get_batches_from_generator +from haystack.document_stores.filter_utils import LogicalFilterClause logger = logging.getLogger(__name__) @@ -476,7 +478,7 @@ def get_metadata_values_by_key( self, key: str, query: Optional[str] = None, - filters: Optional[Dict[str, List[str]]] = None, + filters: Optional[Dict[str, Any]] = None, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None, ) -> List[dict]: @@ -486,7 +488,31 @@ def get_metadata_values_by_key( :param key: the meta key name to get the values for. :param query: narrow down the scope to documents matching the query string. - :param filters: narrow down the scope to documents that match the given filters. + :param filters: Narrow down the scope to documents that match the given filters. + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + Example: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + ``` :param index: Elasticsearch index where the meta values should be searched. If not supplied, self.index will be used. :param headers: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) @@ -508,12 +534,9 @@ def get_metadata_values_by_key( } } if filters: - filter_clause = [] - for key, values in filters.items(): - filter_clause.append({"terms": {key: values}}) if not body.get("query"): body["query"] = {"bool": {}} - body["query"]["bool"].update({"filter": filter_clause}) + body["query"]["bool"].update({"filter": LogicalFilterClause.parse(filters).convert_to_elasticsearch()}) result = self.client.search(body=body, index=index, headers=headers) buckets = result["aggregations"]["metadata_agg"]["buckets"] for bucket in buckets: @@ -630,8 +653,8 @@ def write_labels( if index and not self.client.indices.exists(index=index, headers=headers): self._create_label_index(index, headers=headers) - labels = [Label.from_dict(label) if isinstance(label, dict) else label for label in labels] - duplicate_ids: list = [label.id for label in self._get_duplicate_labels(labels, index=index)] + label_list: List[Label] = [Label.from_dict(label) if isinstance(label, dict) else label for label in labels] + duplicate_ids: list = [label.id for label in self._get_duplicate_labels(label_list, index=index)] if len(duplicate_ids) > 0: logger.warning( f"Duplicate Label IDs: Inserting a Label whose id already exists in this document store." @@ -640,7 +663,7 @@ def write_labels( f" Problematic ids: {','.join(duplicate_ids)}" ) labels_to_index = [] - for label in labels: + for label in label_list: # create timestamps if not available yet if not label.created_at: # type: ignore label.created_at = time.strftime("%Y-%m-%d %H:%M:%S") # type: ignore @@ -649,7 +672,7 @@ def write_labels( _label = { "_op_type": "index" - if self.duplicate_documents == "overwrite" or label.id in duplicate_ids # type: ignore + if self.duplicate_documents == "overwrite" or label.id in duplicate_ids else "create", # type: ignore "_index": index, **label.to_dict(), # type: ignore @@ -682,7 +705,7 @@ def update_document_meta( def get_document_count( self, - filters: Optional[Dict[str, List[str]]] = None, + filters: Optional[Dict[str, Any]] = None, index: Optional[str] = None, only_documents_without_embedding: bool = False, headers: Optional[Dict[str, str]] = None, @@ -697,15 +720,7 @@ def get_document_count( body["query"]["bool"]["must_not"] = [{"exists": {"field": self.embedding_field}}] if filters: - filter_clause = [] - for key, values in filters.items(): - if type(values) != list: - raise ValueError( - f'Wrong filter format for key "{key}": Please provide a list of allowed values for each key. ' - 'Example: {"name": ["some", "more"], "category": ["only_one"]} ' - ) - filter_clause.append({"terms": {key: values}}) - body["query"]["bool"]["filter"] = filter_clause + body["query"]["bool"]["filter"] = LogicalFilterClause.parse(filters).convert_to_elasticsearch() result = self.client.count(index=index, body=body, headers=headers) count = result["count"] @@ -721,7 +736,7 @@ def get_label_count(self, index: Optional[str] = None, headers: Optional[Dict[st def get_embedding_count( self, index: Optional[str] = None, - filters: Optional[Dict[str, List[str]]] = None, + filters: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, ) -> int: """ @@ -732,15 +747,7 @@ def get_embedding_count( body: dict = {"query": {"bool": {"must": [{"exists": {"field": self.embedding_field}}]}}} if filters: - filter_clause = [] - for key, values in filters.items(): - if type(values) != list: - raise ValueError( - f'Wrong filter format for key "{key}": Please provide a list of allowed values for each key. ' - 'Example: {"name": ["some", "more"], "category": ["only_one"]} ' - ) - filter_clause.append({"terms": {key: values}}) - body["query"]["bool"]["filter"] = filter_clause + body["query"]["bool"]["filter"] = LogicalFilterClause.parse(filters).convert_to_elasticsearch() result = self.client.count(index=index, body=body, headers=headers) count = result["count"] @@ -749,7 +756,7 @@ def get_embedding_count( def get_all_documents( self, index: Optional[str] = None, - filters: Optional[Dict[str, List[str]]] = None, + filters: Optional[Dict[str, Any]] = None, return_embedding: Optional[bool] = None, batch_size: int = 10_000, headers: Optional[Dict[str, str]] = None, @@ -760,7 +767,30 @@ def get_all_documents( :param index: Name of the index to get the documents from. If None, the DocumentStore's default index (self.index) will be used. :param filters: Optional filters to narrow down the documents to return. - Example: {"name": ["some", "more"], "category": ["only_one"]} + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + Example: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + ``` :param return_embedding: Whether to return the document embeddings. :param batch_size: When working with large number of documents, batching can help reduce memory footprint. :param headers: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) @@ -775,7 +805,7 @@ def get_all_documents( def get_all_documents_generator( self, index: Optional[str] = None, - filters: Optional[Dict[str, List[str]]] = None, + filters: Optional[Dict[str, Any]] = None, return_embedding: Optional[bool] = None, batch_size: int = 10_000, headers: Optional[Dict[str, str]] = None, @@ -788,7 +818,30 @@ def get_all_documents_generator( :param index: Name of the index to get the documents from. If None, the DocumentStore's default index (self.index) will be used. :param filters: Optional filters to narrow down the documents to return. - Example: {"name": ["some", "more"], "category": ["only_one"]} + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + Example: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + ``` :param return_embedding: Whether to return the document embeddings. :param batch_size: When working with large number of documents, batching can help reduce memory footprint. :param headers: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) @@ -809,7 +862,7 @@ def get_all_documents_generator( def get_all_labels( self, index: Optional[str] = None, - filters: Optional[Dict[str, List[str]]] = None, + filters: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, batch_size: int = 10_000, ) -> List[Label]: @@ -826,7 +879,7 @@ def get_all_labels( def _get_all_documents_in_index( self, index: str, - filters: Optional[Dict[str, List[str]]] = None, + filters: Optional[Dict[str, Any]] = None, batch_size: int = 10_000, only_documents_without_embedding: bool = False, headers: Optional[Dict[str, str]] = None, @@ -837,10 +890,7 @@ def _get_all_documents_in_index( body: dict = {"query": {"bool": {}}} if filters: - filter_clause = [] - for key, values in filters.items(): - filter_clause.append({"terms": {key: values}}) - body["query"]["bool"]["filter"] = filter_clause + body["query"]["bool"]["filter"] = LogicalFilterClause.parse(filters).convert_to_elasticsearch() if only_documents_without_embedding: body["query"]["bool"]["must_not"] = [{"exists": {"field": self.embedding_field}}] @@ -851,7 +901,7 @@ def _get_all_documents_in_index( def query( self, query: Optional[str], - filters: Optional[Dict[str, List[str]]] = None, + filters: Optional[Dict[str, Any]] = None, top_k: int = 10, custom_query: Optional[str] = None, index: Optional[str] = None, @@ -862,7 +912,69 @@ def query( that are most relevant to the query as defined by the BM25 algorithm. :param query: The query - :param filters: A dictionary where the keys specify a metadata field and the value is a list of accepted values for that field + :param filters: Optional filters to narrow down the search space to documents whose metadata fulfill certain + conditions. + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + Example: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + # or simpler using default operators + filters = { + "type": "article", + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": ["economy", "politics"], + "publisher": "nytimes" + } + } + ``` + + To use the same logical operator multiple times on the same level, logical operators take + optionally a list of dictionaries as value. + + Example: + ```python + filters = { + "$or": [ + { + "$and": { + "Type": "News Paper", + "Date": { + "$lt": "2019-01-01" + } + } + }, + { + "$and": { + "Type": "Blog Post", + "Date": { + "$gte": "2019-01-01" + } + } + } + ] + } + ``` :param top_k: How many documents to return per query. :param custom_query: query string as per Elasticsearch DSL with a mandatory query placeholder(query). @@ -942,10 +1054,7 @@ def query( if query is None: body = {"query": {"bool": {"must": {"match_all": {}}}}} # type: Dict[str, Any] if filters: - filter_clause = [] - for key, values in filters.items(): - filter_clause.append({"terms": {key: values}}) - body["query"]["bool"]["filter"] = filter_clause + body["query"]["bool"]["filter"] = LogicalFilterClause.parse(filters).convert_to_elasticsearch() # Retrieval via custom query elif custom_query: # substitute placeholder for query and filters for the custom_query template string @@ -982,15 +1091,7 @@ def query( } if filters: - filter_clause = [] - for key, values in filters.items(): - if type(values) != list: - raise ValueError( - f'Wrong filter format: "{key}": {values}. Provide a list of values for each key. ' - 'Example: {"name": ["some", "more"], "category": ["only_one"]} ' - ) - filter_clause.append({"terms": {key: values}}) - body["query"]["bool"]["filter"] = filter_clause + body["query"]["bool"]["filter"] = LogicalFilterClause.parse(filters).convert_to_elasticsearch() if self.excluded_meta_data: body["_source"] = {"excludes": self.excluded_meta_data} @@ -1004,7 +1105,7 @@ def query( def query_by_embedding( self, query_emb: np.ndarray, - filters: Optional[Dict[str, List[str]]] = None, + filters: Optional[Dict[str, Any]] = None, top_k: int = 10, index: Optional[str] = None, return_embedding: Optional[bool] = None, @@ -1014,8 +1115,69 @@ def query_by_embedding( Find the document that is most similar to the provided `query_emb` by using a vector similarity metric. :param query_emb: Embedding of the query (e.g. gathered from DPR) - :param filters: Optional filters to narrow down the search space. - Example: {"name": ["some", "more"], "category": ["only_one"]} + :param filters: Optional filters to narrow down the search space to documents whose metadata fulfill certain + conditions. + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + Example: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + # or simpler using default operators + filters = { + "type": "article", + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": ["economy", "politics"], + "publisher": "nytimes" + } + } + ``` + + To use the same logical operator multiple times on the same level, logical operators take + optionally a list of dictionaries as value. + + Example: + ```python + filters = { + "$or": [ + { + "$and": { + "Type": "News Paper", + "Date": { + "$lt": "2019-01-01" + } + } + }, + { + "$and": { + "Type": "Blog Post", + "Date": { + "$gte": "2019-01-01" + } + } + } + ] + } + ``` :param top_k: How many documents to return :param index: Index name for storing the docs and metadata :param return_embedding: To return document embedding @@ -1035,15 +1197,9 @@ def query_by_embedding( # +1 in similarity to avoid negative numbers (for cosine sim) body = {"size": top_k, "query": self._get_vector_similarity_query(query_emb, top_k)} if filters: - filter_clause = [] - for key, values in filters.items(): - if type(values) != list: - raise ValueError( - f'Wrong filter format for key "{key}": Please provide a list of allowed values for each key. ' - 'Example: {"name": ["some", "more"], "category": ["only_one"]} ' - ) - filter_clause.append({"terms": {key: values}}) - body["query"]["script_score"]["query"] = {"bool": {"filter": filter_clause}} + body["query"]["script_score"]["query"] = { + "bool": {"filter": LogicalFilterClause.parse(filters).convert_to_elasticsearch()} + } excluded_meta_data: Optional[list] = None @@ -1193,7 +1349,7 @@ def update_embeddings( self, retriever, index: Optional[str] = None, - filters: Optional[Dict[str, List[str]]] = None, + filters: Optional[Dict[str, Any]] = None, update_existing_embeddings: bool = True, batch_size: int = 10_000, headers: Optional[Dict[str, str]] = None, @@ -1209,7 +1365,30 @@ def update_embeddings( incremental updating of embeddings, wherein, only newly indexed documents get processed. :param filters: Optional filters to narrow down the documents for which embeddings are to be updated. - Example: {"name": ["some", "more"], "category": ["only_one"]} + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + Example: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + ``` :param batch_size: When working with large number of documents, batching can help reduce memory footprint. :param headers: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information. @@ -1271,7 +1450,7 @@ def update_embeddings( def delete_all_documents( self, index: Optional[str] = None, - filters: Optional[Dict[str, List[str]]] = None, + filters: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, ): """ @@ -1279,6 +1458,30 @@ def delete_all_documents( :param index: Index name to delete the document from. :param filters: Optional filters to narrow down the documents to be deleted. + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + Example: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + ``` :param headers: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information. :return: None @@ -1295,7 +1498,7 @@ def delete_documents( self, index: Optional[str] = None, ids: Optional[List[str]] = None, - filters: Optional[Dict[str, List[str]]] = None, + filters: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, ): """ @@ -1305,10 +1508,34 @@ def delete_documents( DocumentStore's default index (self.index) will be used :param ids: Optional list of IDs to narrow down the documents to be deleted. :param filters: Optional filters to narrow down the documents to be deleted. - Example filters: {"name": ["some", "more"], "category": ["only_one"]}. - If filters are provided along with a list of IDs, this method deletes the - intersection of the two query results (documents that match the filters and - have their ID in the list). + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + Example: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + ``` + + If filters are provided along with a list of IDs, this method deletes the + intersection of the two query results (documents that match the filters and + have their ID in the list). :param headers: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information. :return: None @@ -1316,10 +1543,7 @@ def delete_documents( index = index or self.index query: Dict[str, Any] = {"query": {}} if filters: - filter_clause = [] - for key, values in filters.items(): - filter_clause.append({"terms": {key: values}}) - query["query"]["bool"] = {"filter": filter_clause} + query["query"]["bool"] = {"filter": LogicalFilterClause.parse(filters).convert_to_elasticsearch()} if ids: query["query"]["bool"]["must"] = {"ids": {"values": ids}} @@ -1337,7 +1561,7 @@ def delete_labels( self, index: Optional[str] = None, ids: Optional[List[str]] = None, - filters: Optional[Dict[str, List[str]]] = None, + filters: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, ): """ @@ -1347,7 +1571,30 @@ def delete_labels( DocumentStore's default label index (self.label_index) will be used :param ids: Optional list of IDs to narrow down the labels to be deleted. :param filters: Optional filters to narrow down the labels to be deleted. - Example filters: {"id": ["9a196e41-f7b5-45b4-bd19-5feb7501c159", "9a196e41-f7b5-45b4-bd19-5feb7501c159"]} or {"query": ["question2"]} + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + Example: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + ``` :param headers: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information. :return: None @@ -1386,7 +1633,7 @@ def __init__(self, verify_certs=False, scheme="https", username="admin", passwor def query_by_embedding( self, query_emb: np.ndarray, - filters: Optional[Dict[str, List[str]]] = None, + filters: Optional[Dict[str, Any]] = None, top_k: int = 10, index: Optional[str] = None, return_embedding: Optional[bool] = None, @@ -1396,8 +1643,69 @@ def query_by_embedding( Find the document that is most similar to the provided `query_emb` by using a vector similarity metric. :param query_emb: Embedding of the query (e.g. gathered from DPR) - :param filters: Optional filters to narrow down the search space. - Example: {"name": ["some", "more"], "category": ["only_one"]} + :param filters: Optional filters to narrow down the search space to documents whose metadata fulfill certain + conditions. + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + Example: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + # or simpler using default operators + filters = { + "type": "article", + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": ["economy", "politics"], + "publisher": "nytimes" + } + } + ``` + + To use the same logical operator multiple times on the same level, logical operators take + optionally a list of dictionaries as value. + + Example: + ```python + filters = { + "$or": [ + { + "$and": { + "Type": "News Paper", + "Date": { + "$lt": "2019-01-01" + } + } + }, + { + "$and": { + "Type": "Blog Post", + "Date": { + "$gte": "2019-01-01" + } + } + } + ] + } + ``` :param top_k: How many documents to return :param index: Index name for storing the docs and metadata :param return_embedding: To return document embedding @@ -1415,17 +1723,12 @@ def query_by_embedding( raise RuntimeError("Please specify arg `embedding_field` in ElasticsearchDocumentStore()") else: # +1 in similarity to avoid negative numbers (for cosine sim) - body = {"size": top_k, "query": {"bool": {"must": [self._get_vector_similarity_query(query_emb, top_k)]}}} + body: Dict[str, Any] = { + "size": top_k, + "query": {"bool": {"must": [self._get_vector_similarity_query(query_emb, top_k)]}}, + } if filters: - filter_clause = [] - for key, values in filters.items(): - if type(values) != list: - raise ValueError( - f'Wrong filter format for key "{key}": Please provide a list of allowed values for each key. ' - 'Example: {"name": ["some", "more"], "category": ["only_one"]} ' - ) - filter_clause.append({"terms": {key: values}}) - body["query"]["bool"]["filter"] = filter_clause # type: ignore + body["query"]["bool"]["filter"] = LogicalFilterClause.parse(filters).convert_to_elasticsearch() excluded_meta_data: Optional[list] = None diff --git a/haystack/document_stores/filter_utils.py b/haystack/document_stores/filter_utils.py new file mode 100644 index 0000000000..a45b98734c --- /dev/null +++ b/haystack/document_stores/filter_utils.py @@ -0,0 +1,293 @@ +from typing import Union, List, Dict +from abc import ABC, abstractmethod +from collections import defaultdict + + +def nested_defaultdict(): + """ + Data structure that recursively adds a dictionary as value if a key does not exist. Advantage: In nested dictionary + structures, we don't need to check if a key already exists (which can become hard to maintain in nested dictionaries + with many levels) but access the existing value if a key exists and create an empty dictionary if a key does not + exist. + """ + return defaultdict(nested_defaultdict) + + +class LogicalFilterClause(ABC): + """ + Class that is able to parse a filter and convert it to the format that the underlying databases of our + DocumentStores require. + + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, `"$gte"`, `"$lt"`, + `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + Example: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + # or simpler using default operators + filters = { + "type": "article", + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": ["economy", "politics"], + "publisher": "nytimes" + } + } + ``` + + To use the same logical operator multiple times on the same level, logical operators take optionally a list of + dictionaries as value. + + Example: + ```python + filters = { + "$or": [ + { + "$and": { + "Type": "News Paper", + "Date": { + "$lt": "2019-01-01" + } + } + }, + { + "$and": { + "Type": "Blog Post", + "Date": { + "$gte": "2019-01-01" + } + } + } + ] + } + ``` + + """ + + def __init__(self, conditions: List["LogicalFilterClause"]): + self.conditions = conditions + + @classmethod + def parse(cls, filter_term: Union[dict, List[dict]]): + """ + Parses a filter dictionary/list and returns a LogicalFilterClause instance. + + :param filter_term: Dictionary or list that contains the filter definition. + """ + conditions = [] + + if isinstance(filter_term, dict): + filter_term = [filter_term] + for item in filter_term: + for key, value in item.items(): + if key == "$not": + conditions.append(NotOperation.parse(value)) + elif key == "$and": + conditions.append(AndOperation.parse(value)) + elif key == "$or": + conditions.append(OrOperation.parse(value)) + # Key needs to be a metadata field + else: + conditions.extend(ComparisonOperation.parse(key, value)) + + if cls == LogicalFilterClause: + if len(conditions) == 1: + return conditions[0] + else: + return AndOperation(conditions) + else: + return cls(conditions) + + @abstractmethod + def convert_to_elasticsearch(self): + """ + Converts the LogicalFilterClause instance to an Elasticsearch filter. + """ + pass + + def _merge_es_range_queries(self, conditions: List[Dict]) -> List[Dict]: + """ + Merges Elasticsearch range queries that perform on the same metadata field. + """ + + range_conditions = [cond["range"] for cond in filter(lambda condition: "range" in condition, conditions)] + if range_conditions: + conditions = [condition for condition in conditions if "range" not in condition] + range_conditions_dict = nested_defaultdict() + for condition in range_conditions: + field_name = list(condition.keys())[0] + operation = list(condition[field_name].keys())[0] + comparison_value = condition[field_name][operation] + range_conditions_dict[field_name][operation] = comparison_value + + for field_name, comparison_operations in range_conditions_dict.items(): + conditions.append({"range": {field_name: comparison_operations}}) + + return conditions + + +class ComparisonOperation(ABC): + def __init__(self, field_name: str, comparison_value: Union[str, float, List]): + self.field_name = field_name + self.comparison_value = comparison_value + + @classmethod + def parse(cls, field_name, comparison_clause: Union[Dict, List, str, float]): + comparison_operations: List[ComparisonOperation] = [] + + if isinstance(comparison_clause, dict): + for comparison_operation, comparison_value in comparison_clause.items(): + if comparison_operation == "$eq": + comparison_operations.append(EqOperation(field_name, comparison_value)) + elif comparison_operation == "$in": + comparison_operations.append(InOperation(field_name, comparison_value)) + elif comparison_operation == "$ne": + comparison_operations.append(NeOperation(field_name, comparison_value)) + elif comparison_operation == "$nin": + comparison_operations.append(NinOperation(field_name, comparison_value)) + elif comparison_operation == "$gt": + comparison_operations.append(GtOperation(field_name, comparison_value)) + elif comparison_operation == "$gte": + comparison_operations.append(GteOperation(field_name, comparison_value)) + elif comparison_operation == "$lt": + comparison_operations.append(LtOperation(field_name, comparison_value)) + elif comparison_operation == "$lte": + comparison_operations.append(LteOperation(field_name, comparison_value)) + + # No comparison operator is given, so we use the default operators "$in" if the comparison value is a list and + # "$eq" in every other case + elif isinstance(comparison_clause, list): + comparison_operations.append(InOperation(field_name, comparison_clause)) + else: + comparison_operations.append((EqOperation(field_name, comparison_clause))) + + return comparison_operations + + @abstractmethod + def convert_to_elasticsearch(self): + """ + Converts the ComparisonOperation instance to an Elasticsearch query. + """ + pass + + +class NotOperation(LogicalFilterClause): + """ + Handles conversion of logical 'NOT' operations. + """ + + def convert_to_elasticsearch(self): + conditions = [condition.convert_to_elasticsearch() for condition in self.conditions] + conditions = self._merge_es_range_queries(conditions) + return {"bool": {"must_not": conditions}} + + +class AndOperation(LogicalFilterClause): + """ + Handles conversion of logical 'AND' operations. + """ + + def convert_to_elasticsearch(self): + conditions = [condition.convert_to_elasticsearch() for condition in self.conditions] + conditions = self._merge_es_range_queries(conditions) + return {"bool": {"must": conditions}} + + +class OrOperation(LogicalFilterClause): + """ + Handles conversion of logical 'OR' operations. + """ + + def convert_to_elasticsearch(self): + conditions = [condition.convert_to_elasticsearch() for condition in self.conditions] + conditions = self._merge_es_range_queries(conditions) + return {"bool": {"should": conditions}} + + +class EqOperation(ComparisonOperation): + """ + Handles conversion of the '$eq' comparison operation. + """ + + def convert_to_elasticsearch(self): + return {"term": {self.field_name: self.comparison_value}} + + +class InOperation(ComparisonOperation): + """ + Handles conversion of the '$in' comparison operation. + """ + + def convert_to_elasticsearch(self): + return {"terms": {self.field_name: self.comparison_value}} + + +class NeOperation(ComparisonOperation): + """ + Handles conversion of the '$ne' comparison operation. + """ + + def convert_to_elasticsearch(self): + return {"bool": {"must_not": {"term": {self.field_name: self.comparison_value}}}} + + +class NinOperation(ComparisonOperation): + """ + Handles conversion of the '$nin' comparison operation. + """ + + def convert_to_elasticsearch(self): + return {"bool": {"must_not": {"terms": {self.field_name: self.comparison_value}}}} + + +class GtOperation(ComparisonOperation): + """ + Handles conversion of the '$gt' comparison operation. + """ + + def convert_to_elasticsearch(self): + return {"range": {self.field_name: {"gt": self.comparison_value}}} + + +class GteOperation(ComparisonOperation): + """ + Handles conversion of the '$gte' comparison operation. + """ + + def convert_to_elasticsearch(self): + return {"range": {self.field_name: {"gte": self.comparison_value}}} + + +class LtOperation(ComparisonOperation): + """ + Handles conversion of the '$lt' comparison operation. + """ + + def convert_to_elasticsearch(self): + return {"range": {self.field_name: {"lt": self.comparison_value}}} + + +class LteOperation(ComparisonOperation): + """ + Handles conversion of the '$lte' comparison operation. + """ + + def convert_to_elasticsearch(self): + return {"range": {self.field_name: {"lte": self.comparison_value}}} diff --git a/test/conftest.py b/test/conftest.py index 7b548e397a..631d4163ca 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -144,14 +144,6 @@ def pytest_collection_modifyitems(config, items): item.add_marker(skip_docstore) -@pytest.fixture -def tmpdir(tmpdir): - """ - Makes pytest's tmpdir fixture fully compatible with pathlib - """ - return Path(tmpdir) - - @pytest.fixture(scope="function", autouse=True) def gc_cleanup(request): """ @@ -344,12 +336,30 @@ def de_to_en_translator(): def test_docs_xs(): return [ # current "dict" format for a document - {"content": "My name is Carla and I live in Berlin", "meta": {"meta_field": "test1", "name": "filename1"}}, + { + "content": "My name is Carla and I live in Berlin", + "meta": {"meta_field": "test1", "name": "filename1", "date_field": "2020-03-01", "numeric_field": 5.5}, + }, # metafield at the top level for backward compatibility - {"content": "My name is Paul and I live in New York", "meta_field": "test2", "name": "filename2"}, + { + "content": "My name is Paul and I live in New York", + "meta_field": "test2", + "name": "filename2", + "date_field": "2019-10-01", + "numeric_field": 5, + }, # Document object for a doc Document( - content="My name is Christelle and I live in Paris", meta={"meta_field": "test3", "name": "filename3"} + content="My name is Christelle and I live in Paris", + meta={"meta_field": "test3", "name": "filename3", "date_field": "2018-10-01", "numeric_field": 4.5}, + ), + Document( + content="My name is Camila and I live in Madrid", + meta={"meta_field": "test4", "name": "filename4", "date_field": "2021-02-01", "numeric_field": 3}, + ), + Document( + content="My name is Matteo and I live in Rome", + meta={"meta_field": "test5", "name": "filename5", "date_field": "2019-01-01", "numeric_field": 0}, ), ] @@ -543,6 +553,16 @@ def document_store_with_docs(request, test_docs_xs, tmp_path): document_store = get_document_store( document_store_type=request.param, embedding_dim=embedding_dim.args[0], tmp_path=tmp_path ) + # TODO: remove the following part once we allow numbers as metadatfield value in WeaviateDocumentStore + if request.param == "weaviate": + for doc in test_docs_xs: + if isinstance(doc, Document): + doc.meta["numeric_field"] = str(doc.meta["numeric_field"]) + else: + if "meta" in doc: + doc["meta"]["numeric_field"] = str(doc["meta"]["numeric_field"]) + else: + doc["numeric_field"] = str(doc["numeric_field"]) document_store.write_documents(test_docs_xs) yield document_store document_store.delete_documents() diff --git a/test/test_document_store.py b/test/test_document_store.py index 1c978f41db..46f8a74782 100644 --- a/test/test_document_store.py +++ b/test/test_document_store.py @@ -151,9 +151,9 @@ def test_write_with_duplicate_doc_ids_custom_index(document_store): def test_get_all_documents_without_filters(document_store_with_docs): documents = document_store_with_docs.get_all_documents() assert all(isinstance(d, Document) for d in documents) - assert len(documents) == 3 - assert {d.meta["name"] for d in documents} == {"filename1", "filename2", "filename3"} - assert {d.meta["meta_field"] for d in documents} == {"test1", "test2", "test3"} + assert len(documents) == 5 + assert {d.meta["name"] for d in documents} == {"filename1", "filename2", "filename3", "filename4", "filename5"} + assert {d.meta["meta_field"] for d in documents} == {"test1", "test2", "test3", "test4", "test5"} def test_get_all_document_filter_duplicate_text_value(document_store): @@ -215,6 +215,107 @@ def test_get_all_documents_with_incorrect_filter_value(document_store_with_docs) assert len(documents) == 0 +@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True) +def test_extended_filter(document_store_with_docs): + # Test comparison operators individually + documents = document_store_with_docs.get_all_documents(filters={"meta_field": {"$eq": "test1"}}) + assert len(documents) == 1 + documents = document_store_with_docs.get_all_documents(filters={"meta_field": "test1"}) + assert len(documents) == 1 + + documents = document_store_with_docs.get_all_documents(filters={"meta_field": {"$in": ["test1", "test2", "n.a."]}}) + assert len(documents) == 2 + documents = document_store_with_docs.get_all_documents(filters={"meta_field": ["test1", "test2", "n.a."]}) + assert len(documents) == 2 + + documents = document_store_with_docs.get_all_documents(filters={"meta_field": {"$ne": "test1"}}) + assert len(documents) == 4 + + documents = document_store_with_docs.get_all_documents(filters={"meta_field": {"$nin": ["test1", "test2", "n.a."]}}) + assert len(documents) == 3 + + documents = document_store_with_docs.get_all_documents(filters={"numeric_field": {"$gt": 3}}) + assert len(documents) == 3 + + documents = document_store_with_docs.get_all_documents(filters={"numeric_field": {"$gte": 3}}) + assert len(documents) == 4 + + documents = document_store_with_docs.get_all_documents(filters={"numeric_field": {"$lt": 3}}) + assert len(documents) == 1 + + documents = document_store_with_docs.get_all_documents(filters={"numeric_field": {"$lte": 3}}) + assert len(documents) == 2 + + # Test compound filters + filters = {"date_field": {"$lte": "2020-12-31", "$gte": "2019-01-01"}} + documents = document_store_with_docs.get_all_documents(filters=filters) + assert len(documents) == 3 + + filters = { + "$and": { + "date_field": {"$lte": "2020-12-31", "$gte": "2019-01-01"}, + "name": {"$in": ["filename5", "filename3"]}, + } + } + documents = document_store_with_docs.get_all_documents(filters=filters) + assert len(documents) == 1 + filters_simplified = { + "date_field": {"$lte": "2020-12-31", "$gte": "2019-01-01"}, + "name": ["filename5", "filename3"], + } + documents_simplified_filter = document_store_with_docs.get_all_documents(filters=filters_simplified) + assert documents == documents_simplified_filter + + filters = { + "$and": { + "date_field": {"$lte": "2020-12-31", "$gte": "2019-01-01"}, + "$or": {"name": {"$in": ["filename5", "filename3"]}, "numeric_field": {"$lte": 5}}, + } + } + documents = document_store_with_docs.get_all_documents(filters=filters) + assert len(documents) == 2 + filters_simplified = { + "date_field": {"$lte": "2020-12-31", "$gte": "2019-01-01"}, + "$or": {"name": ["filename5", "filename3"], "numeric_field": {"$lte": 5}}, + } + documents_simplified_filter = document_store_with_docs.get_all_documents(filters=filters_simplified) + assert documents == documents_simplified_filter + + filters = { + "$and": { + "date_field": {"$lte": "2020-12-31", "$gte": "2019-01-01"}, + "$or": { + "name": {"$in": ["filename5", "filename3"]}, + "$and": {"numeric_field": {"$lte": 5}, "$not": {"meta_field": {"$eq": "test2"}}}, + }, + } + } + documents = document_store_with_docs.get_all_documents(filters=filters) + assert len(documents) == 1 + filters_simplified = { + "date_field": {"$lte": "2020-12-31", "$gte": "2019-01-01"}, + "$or": { + "name": ["filename5", "filename3"], + "$and": {"numeric_field": {"$lte": 5}, "$not": {"meta_field": "test2"}}, + }, + } + documents_simplified_filter = document_store_with_docs.get_all_documents(filters=filters_simplified) + assert documents == documents_simplified_filter + + # Test same logical operator twice on same level + filters = { + "$or": [ + {"$and": {"meta_field": {"$in": ["test1", "test2"]}, "date_field": {"$gte": "2020-01-01"}}}, + {"$and": {"meta_field": {"$in": ["test3", "test4"]}, "date_field": {"$lt": "2020-01-01"}}}, + ] + } + documents = document_store_with_docs.get_all_documents(filters=filters) + docs_meta = [doc.meta["meta_field"] for doc in documents] + assert len(documents) == 2 + assert "test1" in docs_meta + assert "test3" in docs_meta + + def test_get_document_by_id(document_store_with_docs): documents = document_store_with_docs.get_all_documents() doc = document_store_with_docs.get_document_by_id(documents[0].id) @@ -543,7 +644,7 @@ def test_update_embeddings_table_text_retriever(document_store, retriever): def test_delete_all_documents(document_store_with_docs): - assert len(document_store_with_docs.get_all_documents()) == 3 + assert len(document_store_with_docs.get_all_documents()) == 5 document_store_with_docs.delete_documents() documents = document_store_with_docs.get_all_documents() @@ -551,7 +652,7 @@ def test_delete_all_documents(document_store_with_docs): def test_delete_documents(document_store_with_docs): - assert len(document_store_with_docs.get_all_documents()) == 3 + assert len(document_store_with_docs.get_all_documents()) == 5 document_store_with_docs.delete_documents() documents = document_store_with_docs.get_all_documents() @@ -559,14 +660,16 @@ def test_delete_documents(document_store_with_docs): def test_delete_documents_with_filters(document_store_with_docs): - document_store_with_docs.delete_documents(filters={"meta_field": ["test1", "test2"]}) + document_store_with_docs.delete_documents(filters={"meta_field": ["test1", "test2", "test4", "test5"]}) documents = document_store_with_docs.get_all_documents() assert len(documents) == 1 assert documents[0].meta["meta_field"] == "test3" def test_delete_documents_by_id(document_store_with_docs): - docs_to_delete = document_store_with_docs.get_all_documents(filters={"meta_field": ["test1", "test2"]}) + docs_to_delete = document_store_with_docs.get_all_documents( + filters={"meta_field": ["test1", "test2", "test4", "test5"]} + ) docs_not_to_delete = document_store_with_docs.get_all_documents(filters={"meta_field": ["test3"]}) document_store_with_docs.delete_documents(ids=[doc.id for doc in docs_to_delete]) @@ -585,7 +688,7 @@ def test_delete_documents_by_id_with_filters(document_store_with_docs): document_store_with_docs.delete_documents(ids=[doc.id for doc in docs_to_delete], filters={"meta_field": ["test1"]}) all_docs_left = document_store_with_docs.get_all_documents() - assert len(all_docs_left) == 2 + assert len(all_docs_left) == 4 assert all(doc.meta["meta_field"] != "test1" for doc in all_docs_left) all_ids_left = [doc.id for doc in all_docs_left] @@ -1074,7 +1177,9 @@ def test_similarity_score(document_store_with_docs): pipeline = DocumentSearchPipeline(retriever) prediction = pipeline.run("Paul lives in New York") scores = [document.score for document in prediction["documents"]] - assert scores == pytest.approx([0.9102500000000191, 0.6491700000000264, 0.6321699999999737], abs=1e-3) + assert scores == pytest.approx( + [0.9102507941407827, 0.6937791467877008, 0.6491682889305038, 0.6321622491318529, 0.5909129441370939], abs=1e-3 + ) @pytest.mark.parametrize( @@ -1090,7 +1195,9 @@ def test_similarity_score_dot_product(document_store_dot_product_with_docs): pipeline = DocumentSearchPipeline(retriever) prediction = pipeline.run("Paul lives in New York") scores = [document.score for document in prediction["documents"]] - assert scores == pytest.approx([0.5526493562767626, 0.5189836204008691, 0.5179697571274173], abs=1e-3) + assert scores == pytest.approx( + [0.5526494403409358, 0.5247784342375555, 0.5189836829440964, 0.5179697273254912, 0.5112024928228626], abs=1e-3 + ) def test_custom_headers(document_store_with_docs: BaseDocumentStore): diff --git a/test/test_eval.py b/test/test_eval.py index 0c79f145d8..27e6ff2f18 100644 --- a/test/test_eval.py +++ b/test/test_eval.py @@ -43,7 +43,7 @@ def test_generativeqa_calculate_metrics( assert metrics["Retriever"]["map"] == 0.5 assert metrics["Retriever"]["recall_multi_hit"] == 0.5 assert metrics["Retriever"]["recall_single_hit"] == 0.5 - assert metrics["Retriever"]["precision"] == 1.0 / 6 + assert metrics["Retriever"]["precision"] == 0.1 assert metrics["Retriever"]["ndcg"] == 0.5 assert metrics["Generator"]["exact_match"] == 0.0 assert metrics["Generator"]["f1"] == 1.0 / 3 @@ -70,13 +70,13 @@ def test_summarizer_calculate_metrics( assert metrics["Retriever"]["map"] == 0.5 assert metrics["Retriever"]["recall_multi_hit"] == 0.5 assert metrics["Retriever"]["recall_single_hit"] == 0.5 - assert metrics["Retriever"]["precision"] == 1.0 / 6 + assert metrics["Retriever"]["precision"] == 0.1 assert metrics["Retriever"]["ndcg"] == 0.5 assert metrics["Summarizer"]["mrr"] == 0.5 assert metrics["Summarizer"]["map"] == 0.5 assert metrics["Summarizer"]["recall_multi_hit"] == 0.5 assert metrics["Summarizer"]["recall_single_hit"] == 0.5 - assert metrics["Summarizer"]["precision"] == 1.0 / 6 + assert metrics["Summarizer"]["precision"] == 0.1 assert metrics["Summarizer"]["ndcg"] == 0.5 @@ -325,7 +325,7 @@ def test_extractive_qa_eval(reader, retriever_with_docs, tmp_path): assert metrics["Retriever"]["mrr"] == 1.0 assert metrics["Retriever"]["recall_multi_hit"] == 1.0 assert metrics["Retriever"]["recall_single_hit"] == 1.0 - assert metrics["Retriever"]["precision"] == 1.0 / 3 + assert metrics["Retriever"]["precision"] == 0.2 assert metrics["Retriever"]["map"] == 1.0 assert metrics["Retriever"]["ndcg"] == 1.0 @@ -346,7 +346,7 @@ def test_extractive_qa_eval(reader, retriever_with_docs, tmp_path): assert metrics["Retriever"]["mrr"] == 1.0 assert metrics["Retriever"]["recall_multi_hit"] == 1.0 assert metrics["Retriever"]["recall_single_hit"] == 1.0 - assert metrics["Retriever"]["precision"] == 1.0 / 3 + assert metrics["Retriever"]["precision"] == 0.2 assert metrics["Retriever"]["map"] == 1.0 assert metrics["Retriever"]["ndcg"] == 1.0 @@ -390,7 +390,7 @@ def test_extractive_qa_eval_multiple_queries(reader, retriever_with_docs, tmp_pa assert metrics["Retriever"]["map"] == 0.5 assert metrics["Retriever"]["recall_multi_hit"] == 0.5 assert metrics["Retriever"]["recall_single_hit"] == 0.5 - assert metrics["Retriever"]["precision"] == 1.0 / 6 + assert metrics["Retriever"]["precision"] == 0.1 assert metrics["Retriever"]["ndcg"] == 0.5 eval_result.save(tmp_path) @@ -419,7 +419,7 @@ def test_extractive_qa_eval_multiple_queries(reader, retriever_with_docs, tmp_pa assert metrics["Retriever"]["map"] == 0.5 assert metrics["Retriever"]["recall_multi_hit"] == 0.5 assert metrics["Retriever"]["recall_single_hit"] == 0.5 - assert metrics["Retriever"]["precision"] == 1.0 / 6 + assert metrics["Retriever"]["precision"] == 0.1 assert metrics["Retriever"]["ndcg"] == 0.5 @@ -441,7 +441,7 @@ def test_extractive_qa_eval_sas(reader, retriever_with_docs): assert metrics["Retriever"]["map"] == 0.5 assert metrics["Retriever"]["recall_multi_hit"] == 0.5 assert metrics["Retriever"]["recall_single_hit"] == 0.5 - assert metrics["Retriever"]["precision"] == 1.0 / 6 + assert metrics["Retriever"]["precision"] == 0.1 assert metrics["Retriever"]["ndcg"] == 0.5 assert "sas" in metrics["Reader"] assert metrics["Reader"]["sas"] == pytest.approx(1.0) @@ -462,7 +462,7 @@ def test_extractive_qa_eval_doc_relevance_col(reader, retriever_with_docs): assert metrics["Retriever"]["map"] == 0.75 assert metrics["Retriever"]["recall_multi_hit"] == 0.75 assert metrics["Retriever"]["recall_single_hit"] == 1.0 - assert metrics["Retriever"]["precision"] == 1.0 / 3 + assert metrics["Retriever"]["precision"] == 0.2 assert metrics["Retriever"]["ndcg"] == pytest.approx(0.8066, 1e-4) @@ -485,7 +485,7 @@ def test_extractive_qa_eval_simulated_top_k_reader(reader, retriever_with_docs): assert metrics_top_1["Retriever"]["map"] == 0.5 assert metrics_top_1["Retriever"]["recall_multi_hit"] == 0.5 assert metrics_top_1["Retriever"]["recall_single_hit"] == 0.5 - assert metrics_top_1["Retriever"]["precision"] == 1.0 / 6 + assert metrics_top_1["Retriever"]["precision"] == 0.1 assert metrics_top_1["Retriever"]["ndcg"] == 0.5 metrics_top_2 = eval_result.calculate_metrics(simulated_top_k_reader=2) @@ -497,7 +497,7 @@ def test_extractive_qa_eval_simulated_top_k_reader(reader, retriever_with_docs): assert metrics_top_2["Retriever"]["map"] == 0.5 assert metrics_top_2["Retriever"]["recall_multi_hit"] == 0.5 assert metrics_top_2["Retriever"]["recall_single_hit"] == 0.5 - assert metrics_top_2["Retriever"]["precision"] == 1.0 / 6 + assert metrics_top_2["Retriever"]["precision"] == 0.1 assert metrics_top_2["Retriever"]["ndcg"] == 0.5 metrics_top_3 = eval_result.calculate_metrics(simulated_top_k_reader=3) @@ -509,7 +509,7 @@ def test_extractive_qa_eval_simulated_top_k_reader(reader, retriever_with_docs): assert metrics_top_3["Retriever"]["map"] == 0.5 assert metrics_top_3["Retriever"]["recall_multi_hit"] == 0.5 assert metrics_top_3["Retriever"]["recall_single_hit"] == 0.5 - assert metrics_top_3["Retriever"]["precision"] == 1.0 / 6 + assert metrics_top_3["Retriever"]["precision"] == 0.1 assert metrics_top_3["Retriever"]["ndcg"] == 0.5 @@ -527,7 +527,7 @@ def test_extractive_qa_eval_simulated_top_k_retriever(reader, retriever_with_doc assert metrics_top_10["Retriever"]["map"] == 0.5 assert metrics_top_10["Retriever"]["recall_multi_hit"] == 0.5 assert metrics_top_10["Retriever"]["recall_single_hit"] == 0.5 - assert metrics_top_10["Retriever"]["precision"] == 1.0 / 6 + assert metrics_top_10["Retriever"]["precision"] == 0.1 assert metrics_top_10["Retriever"]["ndcg"] == 0.5 metrics_top_1 = eval_result.calculate_metrics(simulated_top_k_retriever=1) @@ -578,7 +578,7 @@ def test_extractive_qa_eval_simulated_top_k_reader_and_retriever(reader, retriev assert metrics_top_10["Retriever"]["map"] == 0.5 assert metrics_top_10["Retriever"]["recall_multi_hit"] == 0.5 assert metrics_top_10["Retriever"]["recall_single_hit"] == 0.5 - assert metrics_top_10["Retriever"]["precision"] == 1.0 / 6 + assert metrics_top_10["Retriever"]["precision"] == 0.1 assert metrics_top_10["Retriever"]["ndcg"] == 0.5 metrics_top_1 = eval_result.calculate_metrics(simulated_top_k_reader=1, simulated_top_k_retriever=1) @@ -634,7 +634,7 @@ def test_extractive_qa_eval_isolated(reader, retriever_with_docs): assert metrics_top_1["Retriever"]["map"] == 0.5 assert metrics_top_1["Retriever"]["recall_multi_hit"] == 0.5 assert metrics_top_1["Retriever"]["recall_single_hit"] == 0.5 - assert metrics_top_1["Retriever"]["precision"] == 1.0 / 6 + assert metrics_top_1["Retriever"]["precision"] == 1.0 / 10 assert metrics_top_1["Retriever"]["ndcg"] == 0.5 metrics_top_1 = eval_result.calculate_metrics(simulated_top_k_reader=1, eval_mode="isolated") @@ -768,7 +768,7 @@ def test_document_search_calculate_metrics(retriever_with_docs): assert metrics["Retriever"]["map"] == 0.5 assert metrics["Retriever"]["recall_multi_hit"] == 0.5 assert metrics["Retriever"]["recall_single_hit"] == 0.5 - assert metrics["Retriever"]["precision"] == 1.0 / 6 + assert metrics["Retriever"]["precision"] == 0.1 assert metrics["Retriever"]["ndcg"] == 0.5 @@ -788,7 +788,7 @@ def test_faq_calculate_metrics(retriever_with_docs): assert metrics["Retriever"]["map"] == 0.5 assert metrics["Retriever"]["recall_multi_hit"] == 0.5 assert metrics["Retriever"]["recall_single_hit"] == 0.5 - assert metrics["Retriever"]["precision"] == 1.0 / 6 + assert metrics["Retriever"]["precision"] == 0.1 assert metrics["Retriever"]["ndcg"] == 0.5 assert metrics["Docs2Answers"]["exact_match"] == 0.0 assert metrics["Docs2Answers"]["f1"] == 0.0 @@ -816,7 +816,7 @@ def test_extractive_qa_eval_translation(reader, retriever_with_docs, de_to_en_tr assert metrics["Retriever"]["map"] == 0.5 assert metrics["Retriever"]["recall_multi_hit"] == 0.5 assert metrics["Retriever"]["recall_single_hit"] == 0.5 - assert metrics["Retriever"]["precision"] == 1.0 / 6 + assert metrics["Retriever"]["precision"] == 0.1 assert metrics["Retriever"]["ndcg"] == 0.5 assert metrics["OutputTranslator"]["exact_match"] == 1.0 @@ -825,7 +825,7 @@ def test_extractive_qa_eval_translation(reader, retriever_with_docs, de_to_en_tr assert metrics["OutputTranslator"]["map"] == 0.5 assert metrics["OutputTranslator"]["recall_multi_hit"] == 0.5 assert metrics["OutputTranslator"]["recall_single_hit"] == 0.5 - assert metrics["OutputTranslator"]["precision"] == 1.0 / 6 + assert metrics["OutputTranslator"]["precision"] == 0.1 assert metrics["OutputTranslator"]["ndcg"] == 0.5 @@ -846,14 +846,14 @@ def test_question_generation_eval(retriever_with_docs, question_generator): assert metrics["Retriever"]["map"] == 0.5 assert metrics["Retriever"]["recall_multi_hit"] == 0.5 assert metrics["Retriever"]["recall_single_hit"] == 0.5 - assert metrics["Retriever"]["precision"] == 1.0 / 6 + assert metrics["Retriever"]["precision"] == 0.1 assert metrics["Retriever"]["ndcg"] == 0.5 assert metrics["Question Generator"]["mrr"] == 0.5 assert metrics["Question Generator"]["map"] == 0.5 assert metrics["Question Generator"]["recall_multi_hit"] == 0.5 assert metrics["Question Generator"]["recall_single_hit"] == 0.5 - assert metrics["Question Generator"]["precision"] == 1.0 / 6 + assert metrics["Question Generator"]["precision"] == 0.1 assert metrics["Question Generator"]["ndcg"] == 0.5 @@ -907,14 +907,14 @@ def test_qa_multi_retriever_pipeline_eval(document_store_with_docs, reader): assert metrics["DPRRetriever"]["map"] == 0.5 assert metrics["DPRRetriever"]["recall_multi_hit"] == 0.5 assert metrics["DPRRetriever"]["recall_single_hit"] == 0.5 - assert metrics["DPRRetriever"]["precision"] == 1.0 / 6 + assert metrics["DPRRetriever"]["precision"] == 0.1 assert metrics["DPRRetriever"]["ndcg"] == 0.5 assert metrics["ESRetriever"]["mrr"] == 1.0 assert metrics["ESRetriever"]["map"] == 1.0 assert metrics["ESRetriever"]["recall_multi_hit"] == 1.0 assert metrics["ESRetriever"]["recall_single_hit"] == 1.0 - assert metrics["ESRetriever"]["precision"] == 1.0 / 3 + assert metrics["ESRetriever"]["precision"] == 0.2 assert metrics["ESRetriever"]["ndcg"] == 1.0 assert metrics["QAReader"]["exact_match"] == 1.0 @@ -969,14 +969,14 @@ def test_multi_retriever_pipeline_eval(document_store_with_docs, reader): assert metrics["DPRRetriever"]["map"] == 0.5 assert metrics["DPRRetriever"]["recall_multi_hit"] == 0.5 assert metrics["DPRRetriever"]["recall_single_hit"] == 0.5 - assert metrics["DPRRetriever"]["precision"] == 1.0 / 6 + assert metrics["DPRRetriever"]["precision"] == 0.1 assert metrics["DPRRetriever"]["ndcg"] == 0.5 assert metrics["ESRetriever"]["mrr"] == 1.0 assert metrics["ESRetriever"]["map"] == 1.0 assert metrics["ESRetriever"]["recall_multi_hit"] == 1.0 assert metrics["ESRetriever"]["recall_single_hit"] == 1.0 - assert metrics["ESRetriever"]["precision"] == 1.0 / 3 + assert metrics["ESRetriever"]["precision"] == 0.2 assert metrics["ESRetriever"]["ndcg"] == 1.0 @@ -1031,14 +1031,14 @@ def test_multi_retriever_pipeline_with_asymmetric_qa_eval(document_store_with_do assert metrics["DPRRetriever"]["map"] == 0.5 assert metrics["DPRRetriever"]["recall_multi_hit"] == 0.5 assert metrics["DPRRetriever"]["recall_single_hit"] == 0.5 - assert metrics["DPRRetriever"]["precision"] == 1.0 / 6 + assert metrics["DPRRetriever"]["precision"] == 0.1 assert metrics["DPRRetriever"]["ndcg"] == 0.5 assert metrics["ESRetriever"]["mrr"] == 1.0 assert metrics["ESRetriever"]["map"] == 1.0 assert metrics["ESRetriever"]["recall_multi_hit"] == 1.0 assert metrics["ESRetriever"]["recall_single_hit"] == 1.0 - assert metrics["ESRetriever"]["precision"] == 1.0 / 3 + assert metrics["ESRetriever"]["precision"] == 0.2 assert metrics["ESRetriever"]["ndcg"] == 1.0 assert metrics["QAReader"]["exact_match"] == 1.0 diff --git a/test/test_pipeline.py b/test/test_pipeline.py index 4ef3921b69..ebb235510e 100644 --- a/test/test_pipeline.py +++ b/test/test_pipeline.py @@ -441,7 +441,7 @@ def test_documentsearch_es_authentication(retriever_with_docs, document_store_wi params={"Retriever": {"top_k": 10, "headers": auth_headers}}, ) assert prediction is not None - assert len(prediction["documents"]) == 3 + assert len(prediction["documents"]) == 5 mock_client.search.assert_called_once() args, kwargs = mock_client.search.call_args assert "headers" in kwargs @@ -470,7 +470,7 @@ def test_documentsearch_document_store_authentication(retriever_with_docs, docum params={"Retriever": {"top_k": 10, "headers": auth_headers}}, ) assert prediction is not None - assert len(prediction["documents"]) == 3 + assert len(prediction["documents"]) == 5 mock_client.count.assert_called_once() args, kwargs = mock_client.count.call_args assert "headers" in kwargs diff --git a/test/test_reader.py b/test/test_reader.py index 2ac16ab0e4..971336eb9b 100644 --- a/test/test_reader.py +++ b/test/test_reader.py @@ -28,7 +28,7 @@ def test_output(prediction): def test_no_answer_output(no_answer_prediction): assert no_answer_prediction is not None assert no_answer_prediction["query"] == "What is the meaning of life?" - assert math.isclose(no_answer_prediction["no_ans_gap"], -13.048564434051514, rel_tol=0.0001) + assert math.isclose(no_answer_prediction["no_ans_gap"], -11.847594738006592, rel_tol=0.0001) assert no_answer_prediction["answers"][0].answer == "" assert no_answer_prediction["answers"][0].offsets_in_context[0].start == 0 assert no_answer_prediction["answers"][0].offsets_in_context[0].end == 0 @@ -89,7 +89,9 @@ def test_context_window_size(reader, test_docs_xs, window_size): if len(answer.answer) <= window_size: assert len(answer.context) in [window_size, window_size - 1] else: - assert len(answer.answer) == len(answer.context) + # If the extracted answer is larger than the context window and is odd in length, + # the resulting context window is one more than the answer length + assert len(answer.context) in [len(answer.answer), len(answer.answer) + 1] reader.inferencer.model.prediction_heads[0].context_window_size = old_window_size diff --git a/test/test_retriever.py b/test/test_retriever.py index e48d454f92..02fce8df51 100644 --- a/test/test_retriever.py +++ b/test/test_retriever.py @@ -74,7 +74,7 @@ def test_retrieval(retriever_with_docs, document_store_with_docs): # test without filters res = retriever_with_docs.retrieve(query="Who lives in Berlin?") assert res[0].content == "My name is Carla and I live in Berlin" - assert len(res) == 3 + assert len(res) == 5 assert res[0].meta["name"] == "filename1" # test with filters diff --git a/test/test_standard_pipelines.py b/test/test_standard_pipelines.py index 5ae8248c40..e14267b3b6 100644 --- a/test/test_standard_pipelines.py +++ b/test/test_standard_pipelines.py @@ -156,7 +156,7 @@ def test_join_document_pipeline(document_store_dot_product_with_docs, reader): p.add_node(component=dpr, name="R2", inputs=["Query"]) p.add_node(component=join_node, name="Join", inputs=["R1", "R2"]) results = p.run(query=query) - assert len(results["documents"]) == 3 + assert len(results["documents"]) == 5 # test merge with weights join_node = JoinDocuments(join_mode="merge", weights=[1000, 1], top_k_join=2) @@ -165,7 +165,7 @@ def test_join_document_pipeline(document_store_dot_product_with_docs, reader): p.add_node(component=dpr, name="R2", inputs=["Query"]) p.add_node(component=join_node, name="Join", inputs=["R1", "R2"]) results = p.run(query=query) - assert math.isclose(results["documents"][0].score, 0.5350644373470798, rel_tol=0.0001) + assert math.isclose(results["documents"][0].score, 0.5481393431183286, rel_tol=0.0001) assert len(results["documents"]) == 2 # test concatenate @@ -175,7 +175,7 @@ def test_join_document_pipeline(document_store_dot_product_with_docs, reader): p.add_node(component=dpr, name="R2", inputs=["Query"]) p.add_node(component=join_node, name="Join", inputs=["R1", "R2"]) results = p.run(query=query) - assert len(results["documents"]) == 3 + assert len(results["documents"]) == 5 # test concatenate with top_k_join parameter join_node = JoinDocuments(join_mode="concatenate")