From bf5ce515686dcb7593a3ee4a213367fa740cc582 Mon Sep 17 00:00:00 2001 From: Patrick Deutschmann Date: Tue, 17 May 2022 09:51:22 +0200 Subject: [PATCH 01/10] Implement MDR --- docs/_src/api/api/retriever.md | 323 ++++++++++++++++++ .../haystack-pipeline-master.schema.json | 146 ++++++++ haystack/modeling/model/prediction_head.py | 4 +- haystack/nodes/__init__.py | 1 + haystack/nodes/retriever/__init__.py | 7 +- haystack/nodes/retriever/dense.py | 312 +++++++++++++++++ test/conftest.py | 14 +- test/nodes/test_retriever.py | 11 +- 8 files changed, 814 insertions(+), 4 deletions(-) diff --git a/docs/_src/api/api/retriever.md b/docs/_src/api/api/retriever.md index 5e9f90f71b..19a9553d6e 100644 --- a/docs/_src/api/api/retriever.md +++ b/docs/_src/api/api/retriever.md @@ -876,6 +876,329 @@ def load(cls, load_dir: Union[Path, str], document_store: BaseDocumentStore, max Load DensePassageRetriever from the specified directory. + + +## MultihopDenseRetriever + +```python +class MultihopDenseRetriever(BaseRetriever) +``` + +Retriever that applies iterative retrieval using a shared encoder for query and passage. +See original paper for more details: + +Xiong, Wenhan, et. al. (2020): "Answering complex open-domain questions with multi-hop dense retrieval" +(https://arxiv.org/abs/2009.12756) + + + +#### MultihopDenseRetriever.\_\_init\_\_ + +```python +def __init__(document_store: BaseDocumentStore, embedding_model: Union[Path, str] = "deutschmann/mdr_roberta_q_encoder", model_version: Optional[str] = None, num_iterations: int = 2, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None, scale_score: bool = True) +``` + +Init the Retriever incl. the encoder model from a local or remote model checkpoint. + +The checkpoint format matches huggingface transformers' model format + +**Example:** + + ```python + | # remote model + | MultihopDenseRetriever(document_store=your_doc_store, + | embedding_model="deutschmann/mdr_roberta_q_encoder") + | # or from local path + | MultihopDenseRetriever(document_store=your_doc_store, + | embedding_model="model_directory/encoder") + ``` + +**Arguments**: + +- `document_store`: An instance of DocumentStore from which to retrieve documents. +- `query_embedding_model`: Local path or remote name of encoder checkpoint. The format equals the +one used by hugging-face transformers' modelhub models +Currently available remote names: ``"deutschmann/mdr_roberta_q_encoder", "facebook/dpr-ctx_encoder-single-nq-base"`` +- `model_version`: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash. +- `num_iterations`: The number of times passages are retrieved, i.e., the number of hops (Defaults to 2.) +- `max_seq_len_query`: Longest length of each query sequence. Maximum number of tokens for the query text. Longer ones will be cut down." +- `max_seq_len_passage`: Longest length of each passage/context sequence. Maximum number of tokens for the passage text. Longer ones will be cut down." +- `top_k`: How many documents to return per query. +- `use_gpu`: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available. 
+- `batch_size`: Number of questions or passages to encode at once. In case of multiple gpus, this will be the total batch size. +- `embed_title`: Whether to concatenate title and passage to a text pair that is then used to create the embedding. +This is the approach used in the original paper and is likely to improve performance if your +titles contain meaningful information for retrieval (topic, entities etc.) . +The title is expected to be present in doc.meta["name"] and can be supplied in the documents +before writing them to the DocumentStore like this: +{"text": "my text", "meta": {"name": "my title"}}. +- `use_fast_tokenizers`: Whether to use fast Rust tokenizers +- `infer_tokenizer_classes`: Whether to infer tokenizer class from the model config / name. +If `False`, the class always loads `RobertaTokenizer`. +- `similarity_function`: Which function to apply for calculating the similarity of query and passage embeddings during training. +Options: `dot_product` (Default) or `cosine` +- `global_loss_buffer_size`: Buffer size for all_gather() in DDP. +Increase if errors like "encoded data exceeds max_size ..." come up +- `progress_bar`: Whether to show a tqdm progress bar or not. +Can be helpful to disable in production deployments to keep the logs clean. +- `devices`: List of GPU (or CPU) devices, to limit inference to certain GPUs and not use all available ones +These strings will be converted into pytorch devices, so use the string notation described here: +https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device +(e.g. ["cuda:0"]). +- `use_auth_token`: API token used to download private models from Huggingface. If this parameter is set to `True`, +the local token will be used, which must be previously created via `transformer-cli login`. +Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained +- `scale_score`: Whether to scale the similarity score to the unit interval (range of [0,1]). +If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant. +Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. + + + +#### MultihopDenseRetriever.retrieve + +```python +def retrieve(query: str, filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None, top_k: Optional[int] = None, index: str = None, headers: Optional[Dict[str, str]] = None, scale_score: bool = None) -> List[Document] +``` + +Scan through documents in DocumentStore and return a small number documents + +that are most relevant to the query. + +**Arguments**: + +- `query`: The query +- `filters`: Optional filters to narrow down the search space to documents whose metadata fulfill certain +conditions. +Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical +operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, +`"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. +Logical operator keys take a dictionary of metadata field names and/or logical operators as +value. Metadata field names take a dictionary of comparison operators as value. Comparison +operator keys take a single value or (in case of `"$in"`) a list of values as value. +If no logical operator is provided, `"$and"` is used as default operation. 
If no comparison +operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default +operation. + + __Example__: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + # or simpler using default operators + filters = { + "type": "article", + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": ["economy", "politics"], + "publisher": "nytimes" + } + } + ``` + + To use the same logical operator multiple times on the same level, logical operators take + optionally a list of dictionaries as value. + + __Example__: + ```python + filters = { + "$or": [ + { + "$and": { + "Type": "News Paper", + "Date": { + "$lt": "2019-01-01" + } + } + }, + { + "$and": { + "Type": "Blog Post", + "Date": { + "$gte": "2019-01-01" + } + } + } + ] + } + ``` +- `top_k`: How many documents to return per query. +- `index`: The name of the index in the DocumentStore from which to retrieve documents +- `scale_score`: Whether to scale the similarity score to the unit interval (range of [0,1]). +If true similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant. +Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. + + + +#### MultihopDenseRetriever.retrieve\_batch + +```python +def retrieve_batch(queries: Union[str, List[str]], filters: Optional[ + Union[ + Dict[str, Union[Dict, List, str, int, float, bool]], + List[Dict[str, Union[Dict, List, str, int, float, bool]]], + ] + ] = None, top_k: Optional[int] = None, index: str = None, headers: Optional[Dict[str, str]] = None, batch_size: Optional[int] = None, scale_score: bool = None) -> Union[List[Document], List[List[Document]]] +``` + +Scan through documents in DocumentStore and return a small number documents + +that are most relevant to the supplied queries. + +If you supply a single query, a single list of Documents is returned. If you supply a list of queries, a list of +lists of Documents (one per query) is returned. + +**Arguments**: + +- `queries`: Single query string or list of queries. +- `filters`: Optional filters to narrow down the search space to documents whose metadata fulfill certain +conditions. Can be a single filter that will be applied to each query or a list of filters +(one filter per query). + +Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical +operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, +`"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. +Logical operator keys take a dictionary of metadata field names and/or logical operators as +value. Metadata field names take a dictionary of comparison operators as value. Comparison +operator keys take a single value or (in case of `"$in"`) a list of values as value. +If no logical operator is provided, `"$and"` is used as default operation. If no comparison +operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default +operation. 
+ + __Example__: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + # or simpler using default operators + filters = { + "type": "article", + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": ["economy", "politics"], + "publisher": "nytimes" + } + } + ``` + + To use the same logical operator multiple times on the same level, logical operators take + optionally a list of dictionaries as value. + + __Example__: + ```python + filters = { + "$or": [ + { + "$and": { + "Type": "News Paper", + "Date": { + "$lt": "2019-01-01" + } + } + }, + { + "$and": { + "Type": "Blog Post", + "Date": { + "$gte": "2019-01-01" + } + } + } + ] + } + ``` +- `top_k`: How many documents to return per query. +- `index`: The name of the index in the DocumentStore from which to retrieve documents +- `batch_size`: Number of queries to embed at a time. +- `scale_score`: Whether to scale the similarity score to the unit interval (range of [0,1]). +If true similarity scores (e.g. cosine or dot_product) which naturally have a different +value range will be scaled to a range of [0,1], where 1 means extremely relevant. +Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. + + + +#### MultihopDenseRetriever.embed\_queries + +```python +def embed_queries(queries: List[str], contexts: List[List[Document]]) -> List[np.ndarray] +``` + +Create embeddings for a list of queries using the query encoder + +**Arguments**: + +- `queries`: Queries to embed +- `contexts`: Context documents + +**Returns**: + +Embeddings, one per input queries + + + +#### MultihopDenseRetriever.embed\_documents + +```python +def embed_documents(docs: List[Document]) -> List[np.ndarray] +``` + +Create embeddings for a list of documents using the passage encoder + +**Arguments**: + +- `docs`: List of Document objects used to represent documents / passages in a standardized way within Haystack. + +**Returns**: + +Embeddings of documents / passages shape (batch_size, embedding_dim) + + + +#### MultihopDenseRetriever.save + +```python +def save(save_dir: Union[Path, str], encoder_dir: str = "encoder") +``` + +Save MultihopDenseRetriever to the specified directory. + +**Arguments**: + +- `save_dir`: Directory to save to. +- `encoder_dir`: Directory in save_dir that contains encoder model. + +**Returns**: + +None + + + +#### MultihopDenseRetriever.load + +```python +@classmethod +def load(cls, load_dir: Union[Path, str], document_store: BaseDocumentStore, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, similarity_function: str = "dot_product", encoder_dir: str = "encoder", infer_tokenizer_classes: bool = False) +``` + +Load MultihopDenseRetriever from the specified directory. 
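To complement the instantiation snippet above, the following end-to-end sketch indexes a handful of passages and runs a two-hop query. The in-memory document store, the toy passages, and the question are illustrative assumptions rather than part of this patch; the model name and parameters come from the defaults documented above. Later commits in this series rename the class to `MultihopEmbeddingRetriever` and align its signature with `EmbeddingRetriever`, so treat this as a sketch against the version documented here.

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import MultihopDenseRetriever
from haystack.schema import Document

# Index a few passages; the title goes into meta["name"] as described above.
document_store = InMemoryDocumentStore(embedding_dim=768, similarity="dot_product")
document_store.write_documents(
    [
        Document(content="Alice was born in Vienna.", meta={"name": "Alice"}),
        Document(content="Vienna is the capital of Austria.", meta={"name": "Vienna"}),
    ]
)

retriever = MultihopDenseRetriever(
    document_store=document_store,
    embedding_model="deutschmann/mdr_roberta_q_encoder",
    num_iterations=2,
    top_k=5,
)

# Precompute passage embeddings before querying.
document_store.update_embeddings(retriever)

# Each hop appends the best passage found so far to the query before re-embedding it.
docs = retriever.retrieve(query="In which country was Alice born?")
```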
+ ## TableTextRetriever diff --git a/haystack/json-schemas/haystack-pipeline-master.schema.json b/haystack/json-schemas/haystack-pipeline-master.schema.json index ff55116564..b629b1491f 100644 --- a/haystack/json-schemas/haystack-pipeline-master.schema.json +++ b/haystack/json-schemas/haystack-pipeline-master.schema.json @@ -124,6 +124,9 @@ { "$ref": "#/definitions/MarkdownConverterComponent" }, + { + "$ref": "#/definitions/MultihopDenseRetrieverComponent" + }, { "$ref": "#/definitions/PDFToTextConverterComponent" }, @@ -3083,6 +3086,149 @@ ], "additionalProperties": false }, + "MultihopDenseRetrieverComponent": { + "type": "object", + "properties": { + "name": { + "title": "Name", + "description": "Custom name for the component. Helpful for visualization and debugging.", + "type": "string" + }, + "type": { + "title": "Type", + "description": "Haystack Class name for the component.", + "type": "string", + "const": "MultihopDenseRetriever" + }, + "params": { + "title": "Parameters", + "type": "object", + "properties": { + "document_store": { + "title": "Document Store", + "type": "string" + }, + "embedding_model": { + "title": "Embedding Model", + "default": "deutschmann/mdr_roberta_q_encoder", + "anyOf": [ + { + "type": "string", + "format": "path" + }, + { + "type": "string" + } + ] + }, + "model_version": { + "title": "Model Version", + "type": "string" + }, + "num_iterations": { + "title": "Num Iterations", + "default": 2, + "type": "integer" + }, + "max_seq_len_query": { + "title": "Max Seq Len Query", + "default": 64, + "type": "integer" + }, + "max_seq_len_passage": { + "title": "Max Seq Len Passage", + "default": 256, + "type": "integer" + }, + "top_k": { + "title": "Top K", + "default": 10, + "type": "integer" + }, + "use_gpu": { + "title": "Use Gpu", + "default": true, + "type": "boolean" + }, + "batch_size": { + "title": "Batch Size", + "default": 16, + "type": "integer" + }, + "embed_title": { + "title": "Embed Title", + "default": true, + "type": "boolean" + }, + "use_fast_tokenizers": { + "title": "Use Fast Tokenizers", + "default": true, + "type": "boolean" + }, + "infer_tokenizer_classes": { + "title": "Infer Tokenizer Classes", + "default": false, + "type": "boolean" + }, + "similarity_function": { + "title": "Similarity Function", + "default": "dot_product", + "type": "string" + }, + "global_loss_buffer_size": { + "title": "Global Loss Buffer Size", + "default": 150000, + "type": "integer" + }, + "progress_bar": { + "title": "Progress Bar", + "default": true, + "type": "boolean" + }, + "devices": { + "title": "Devices", + "type": "array", + "items": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "string" + } + ] + } + }, + "use_auth_token": { + "title": "Use Auth Token", + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "string" + } + ] + }, + "scale_score": { + "title": "Scale Score", + "default": true, + "type": "boolean" + } + }, + "required": [ + "document_store" + ], + "additionalProperties": false, + "description": "Each parameter can reference other components defined in the same YAML file." 
+ } + }, + "required": [ + "type", + "name" + ], + "additionalProperties": false + }, "PDFToTextConverterComponent": { "type": "object", "properties": { diff --git a/haystack/modeling/model/prediction_head.py b/haystack/modeling/model/prediction_head.py index b66891b17f..6654dae80f 100644 --- a/haystack/modeling/model/prediction_head.py +++ b/haystack/modeling/model/prediction_head.py @@ -971,7 +971,9 @@ def get_similarity_function(self): f"The similarity function can only be 'dot_product' or 'cosine', not '{self.similarity_function}'" ) - def forward(self, query_vectors: torch.Tensor, passage_vectors: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward( + self, query_vectors: torch.Tensor, passage_vectors: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Only packs the embeddings from both language models into a tuple. No further modification. The similarity calculation is handled later to enable distributed training (DDP) diff --git a/haystack/nodes/__init__.py b/haystack/nodes/__init__.py index 6b51be67fc..9f861293fc 100644 --- a/haystack/nodes/__init__.py +++ b/haystack/nodes/__init__.py @@ -34,6 +34,7 @@ BM25Retriever, ElasticsearchRetriever, FilterRetriever, + MultihopDenseRetriever, ElasticsearchFilterOnlyRetriever, TfidfRetriever, Text2SparqlRetriever, diff --git a/haystack/nodes/retriever/__init__.py b/haystack/nodes/retriever/__init__.py index 547537dad0..d3ff1976ba 100644 --- a/haystack/nodes/retriever/__init__.py +++ b/haystack/nodes/retriever/__init__.py @@ -1,5 +1,10 @@ from haystack.nodes.retriever.base import BaseRetriever -from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetriever, TableTextRetriever +from haystack.nodes.retriever.dense import ( + DensePassageRetriever, + EmbeddingRetriever, + MultihopDenseRetriever, + TableTextRetriever, +) from haystack.nodes.retriever.sparse import ( BM25Retriever, ElasticsearchRetriever, diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py index 834cd46011..a5a003a33d 100644 --- a/haystack/nodes/retriever/dense.py +++ b/haystack/nodes/retriever/dense.py @@ -1898,3 +1898,315 @@ def save(self, save_dir: Union[Path, str]) -> None: :type save_dir: Union[Path, str] """ self.embedding_encoder.save(save_dir=save_dir) + + +class MultihopDenseRetriever(EmbeddingRetriever): + """ + Retriever that applies iterative retrieval using a shared encoder for query and passage. + See original paper for more details: + + Xiong, Wenhan, et. al. (2020): "Answering complex open-domain questions with multi-hop dense retrieval" + (https://arxiv.org/abs/2009.12756) + """ + + def __init__( + self, + document_store: BaseDocumentStore, + embedding_model: str, + model_version: Optional[str] = None, + num_iterations: int = 2, + use_gpu: bool = True, + batch_size: int = 32, + max_seq_len: int = 512, + model_format: str = "farm", + pooling_strategy: str = "reduce_mean", + emb_extraction_layer: int = -1, + top_k: int = 10, + progress_bar: bool = True, + devices: Optional[List[Union[str, torch.device]]] = None, + use_auth_token: Optional[Union[str, bool]] = None, + scale_score: bool = True, + embed_meta_fields: List[str] = [], + ): + """ + Same parameters as `EmbeddingRetriever` except + + :param num_iterations: The number of times passages are retrieved, i.e., the number of hops (Defaults to 2.) 
+ """ + super().__init__( + document_store, + embedding_model, + model_version, + use_gpu, + batch_size, + max_seq_len, + model_format, + pooling_strategy, + emb_extraction_layer, + top_k, + progress_bar, + devices, + use_auth_token, + scale_score, + embed_meta_fields, + ) + self.num_iterations = num_iterations + + def _merge_query_and_context(self, query: str, context: List[Document], sep: str = " "): + return sep.join([query] + [doc.content for doc in context]) + + def retrieve( + self, + query: str, + filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None, + top_k: Optional[int] = None, + index: str = None, + headers: Optional[Dict[str, str]] = None, + scale_score: bool = None, + ) -> List[Document]: + """ + Scan through documents in DocumentStore and return a small number documents + that are most relevant to the query. + + :param query: The query + :param filters: Optional filters to narrow down the search space to documents whose metadata fulfill certain + conditions. + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + __Example__: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + # or simpler using default operators + filters = { + "type": "article", + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": ["economy", "politics"], + "publisher": "nytimes" + } + } + ``` + + To use the same logical operator multiple times on the same level, logical operators take + optionally a list of dictionaries as value. + + __Example__: + ```python + filters = { + "$or": [ + { + "$and": { + "Type": "News Paper", + "Date": { + "$lt": "2019-01-01" + } + } + }, + { + "$and": { + "Type": "Blog Post", + "Date": { + "$gte": "2019-01-01" + } + } + } + ] + } + ``` + :param top_k: How many documents to return per query. + :param index: The name of the index in the DocumentStore from which to retrieve documents + :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]). + If true similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant. + Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. 
+ """ + return self.retrieve_batch( # type: ignore + queries=query, + filters=[filters] if filters is not None else None, + top_k=top_k, + index=index, + headers=headers, + scale_score=scale_score, + batch_size=1, + ) + + def retrieve_batch( + self, + queries: Union[str, List[str]], + filters: Optional[ + Union[ + Dict[str, Union[Dict, List, str, int, float, bool]], + List[Dict[str, Union[Dict, List, str, int, float, bool]]], + ] + ] = None, + top_k: Optional[int] = None, + index: str = None, + headers: Optional[Dict[str, str]] = None, + batch_size: Optional[int] = None, + scale_score: bool = None, + ) -> Union[List[Document], List[List[Document]]]: + """ + Scan through documents in DocumentStore and return a small number documents + that are most relevant to the supplied queries. + + If you supply a single query, a single list of Documents is returned. If you supply a list of queries, a list of + lists of Documents (one per query) is returned. + + :param queries: Single query string or list of queries. + :param filters: Optional filters to narrow down the search space to documents whose metadata fulfill certain + conditions. Can be a single filter that will be applied to each query or a list of filters + (one filter per query). + + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + + __Example__: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + # or simpler using default operators + filters = { + "type": "article", + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": ["economy", "politics"], + "publisher": "nytimes" + } + } + ``` + + To use the same logical operator multiple times on the same level, logical operators take + optionally a list of dictionaries as value. + + __Example__: + ```python + filters = { + "$or": [ + { + "$and": { + "Type": "News Paper", + "Date": { + "$lt": "2019-01-01" + } + } + }, + { + "$and": { + "Type": "Blog Post", + "Date": { + "$gte": "2019-01-01" + } + } + } + ] + } + ``` + :param top_k: How many documents to return per query. + :param index: The name of the index in the DocumentStore from which to retrieve documents + :param batch_size: Number of queries to embed at a time. + :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]). + If true similarity scores (e.g. cosine or dot_product) which naturally have a different + value range will be scaled to a range of [0,1], where 1 means extremely relevant. + Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. 
+ """ + + if top_k is None: + top_k = self.top_k + + if batch_size is None: + batch_size = self.batch_size + + single_query = False + if isinstance(queries, str): + queries = [queries] + single_query = True + + if isinstance(filters, list): + if len(filters) != len(queries): + raise HaystackError( + "Number of filters does not match number of queries. Please provide as many filters" + " as queries or a single filter that will be applied to each query." + ) + else: + filters = [{}] * len(queries) + + if index is None: + index = self.document_store.index + if scale_score is None: + scale_score = self.scale_score + if not self.document_store: + logger.error( + "Cannot perform retrieve_batch() since MultihopDenseRetriever initialized with document_store=None" + ) + if single_query: + single_result: List[Document] = [] + return single_result + else: + result: List[List[Document]] = [[] * len(queries)] + return result + + documents = [] + batches = self._get_batches(queries=queries, batch_size=batch_size) + # TODO: Currently filters are applied both for final and context documents. + # maybe they should only apply for final docs? or make it configurable with a param? + for batch, cur_filters in zip(batches, filters): + context_docs: List[List[Document]] = [[] for _ in range(len(batch))] + for it in range(self.num_iterations): + texts = [self._merge_query_and_context(q, c) for q, c in zip(batch, context_docs)] + query_embs = self.embed_queries(texts) + for idx, emb in enumerate(query_embs): + cur_docs = self.document_store.query_by_embedding( + query_emb=emb, + top_k=top_k, + filters=cur_filters, + index=index, + headers=headers, + scale_score=scale_score, + ) + if it < self.num_iterations - 1: + # add doc with highest score to context + if len(cur_docs) > 0: + context_docs[idx].append(cur_docs[0]) + else: + # documents in the last iteration are final results + documents.append(cur_docs) + + if single_query: + return documents[0] + else: + return documents diff --git a/test/conftest.py b/test/conftest.py index 3fd78c197f..305b20a1ad 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -55,7 +55,12 @@ from haystack.nodes.ranker import SentenceTransformersRanker from haystack.nodes.document_classifier.transformers import TransformersDocumentClassifier from haystack.nodes.retriever.sparse import FilterRetriever, BM25Retriever, TfidfRetriever -from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetriever, TableTextRetriever +from haystack.nodes.retriever.dense import ( + DensePassageRetriever, + EmbeddingRetriever, + MultihopDenseRetriever, + TableTextRetriever, +) from haystack.nodes.reader.farm import FARMReader from haystack.nodes.reader.transformers import TransformersReader from haystack.nodes.reader.table import TableReader, RCIReader @@ -672,6 +677,13 @@ def get_retriever(retriever_type, document_store): use_gpu=False, embed_title=True, ) + elif retriever_type == "mdr": + retriever = MultihopDenseRetriever( + document_store=document_store, + embedding_model="deutschmann/mdr_roberta_q_encoder", # or "facebook/dpr-ctx_encoder-single-nq-base" + use_gpu=False, + embed_title=True, + ) elif retriever_type == "tfidf": retriever = TfidfRetriever(document_store=document_store) retriever.fit() diff --git a/test/nodes/test_retriever.py b/test/nodes/test_retriever.py index b3fd71ff2f..e3e903336e 100644 --- a/test/nodes/test_retriever.py +++ b/test/nodes/test_retriever.py @@ -15,7 +15,12 @@ from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore from 
haystack.document_stores.faiss import FAISSDocumentStore from haystack.document_stores import MilvusDocumentStore -from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetriever, TableTextRetriever +from haystack.nodes.retriever.dense import ( + DensePassageRetriever, + EmbeddingRetriever, + TableTextRetriever, + MultihopDenseRetriever, +) from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast, PreTrainedTokenizerFast @@ -26,6 +31,10 @@ @pytest.mark.parametrize( "retriever_with_docs,document_store_with_docs", [ + ("mdr", "elasticsearch"), + ("mdr", "faiss"), + ("mdr", "memory"), + ("mdr", "milvus1"), ("dpr", "elasticsearch"), ("dpr", "faiss"), ("dpr", "memory"), From 05f33745b8f5f5eb53ad74cb3d26edc974325bc8 Mon Sep 17 00:00:00 2001 From: Patrick Deutschmann Date: Thu, 23 Jun 2022 14:44:31 +0200 Subject: [PATCH 02/10] Adapt conftest to new MDR signature --- test/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/conftest.py b/test/conftest.py index 305b20a1ad..f8ad87b7c8 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -682,7 +682,6 @@ def get_retriever(retriever_type, document_store): document_store=document_store, embedding_model="deutschmann/mdr_roberta_q_encoder", # or "facebook/dpr-ctx_encoder-single-nq-base" use_gpu=False, - embed_title=True, ) elif retriever_type == "tfidf": retriever = TfidfRetriever(document_store=document_store) From c4c1dc598b1bb17da07938f837093640e9b84061 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 23 Jun 2022 12:52:14 +0000 Subject: [PATCH 03/10] Update Documentation & Code Style --- docs/_src/api/api/retriever.md | 531 +++++++----------- .../haystack-pipeline-master.schema.json | 77 +-- 2 files changed, 239 insertions(+), 369 deletions(-) diff --git a/docs/_src/api/api/retriever.md b/docs/_src/api/api/retriever.md index 19a9553d6e..f5422ec318 100644 --- a/docs/_src/api/api/retriever.md +++ b/docs/_src/api/api/retriever.md @@ -876,329 +876,6 @@ def load(cls, load_dir: Union[Path, str], document_store: BaseDocumentStore, max Load DensePassageRetriever from the specified directory. - - -## MultihopDenseRetriever - -```python -class MultihopDenseRetriever(BaseRetriever) -``` - -Retriever that applies iterative retrieval using a shared encoder for query and passage. -See original paper for more details: - -Xiong, Wenhan, et. al. (2020): "Answering complex open-domain questions with multi-hop dense retrieval" -(https://arxiv.org/abs/2009.12756) - - - -#### MultihopDenseRetriever.\_\_init\_\_ - -```python -def __init__(document_store: BaseDocumentStore, embedding_model: Union[Path, str] = "deutschmann/mdr_roberta_q_encoder", model_version: Optional[str] = None, num_iterations: int = 2, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None, scale_score: bool = True) -``` - -Init the Retriever incl. the encoder model from a local or remote model checkpoint. 
- -The checkpoint format matches huggingface transformers' model format - -**Example:** - - ```python - | # remote model - | MultihopDenseRetriever(document_store=your_doc_store, - | embedding_model="deutschmann/mdr_roberta_q_encoder") - | # or from local path - | MultihopDenseRetriever(document_store=your_doc_store, - | embedding_model="model_directory/encoder") - ``` - -**Arguments**: - -- `document_store`: An instance of DocumentStore from which to retrieve documents. -- `query_embedding_model`: Local path or remote name of encoder checkpoint. The format equals the -one used by hugging-face transformers' modelhub models -Currently available remote names: ``"deutschmann/mdr_roberta_q_encoder", "facebook/dpr-ctx_encoder-single-nq-base"`` -- `model_version`: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash. -- `num_iterations`: The number of times passages are retrieved, i.e., the number of hops (Defaults to 2.) -- `max_seq_len_query`: Longest length of each query sequence. Maximum number of tokens for the query text. Longer ones will be cut down." -- `max_seq_len_passage`: Longest length of each passage/context sequence. Maximum number of tokens for the passage text. Longer ones will be cut down." -- `top_k`: How many documents to return per query. -- `use_gpu`: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available. -- `batch_size`: Number of questions or passages to encode at once. In case of multiple gpus, this will be the total batch size. -- `embed_title`: Whether to concatenate title and passage to a text pair that is then used to create the embedding. -This is the approach used in the original paper and is likely to improve performance if your -titles contain meaningful information for retrieval (topic, entities etc.) . -The title is expected to be present in doc.meta["name"] and can be supplied in the documents -before writing them to the DocumentStore like this: -{"text": "my text", "meta": {"name": "my title"}}. -- `use_fast_tokenizers`: Whether to use fast Rust tokenizers -- `infer_tokenizer_classes`: Whether to infer tokenizer class from the model config / name. -If `False`, the class always loads `RobertaTokenizer`. -- `similarity_function`: Which function to apply for calculating the similarity of query and passage embeddings during training. -Options: `dot_product` (Default) or `cosine` -- `global_loss_buffer_size`: Buffer size for all_gather() in DDP. -Increase if errors like "encoded data exceeds max_size ..." come up -- `progress_bar`: Whether to show a tqdm progress bar or not. -Can be helpful to disable in production deployments to keep the logs clean. -- `devices`: List of GPU (or CPU) devices, to limit inference to certain GPUs and not use all available ones -These strings will be converted into pytorch devices, so use the string notation described here: -https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device -(e.g. ["cuda:0"]). -- `use_auth_token`: API token used to download private models from Huggingface. If this parameter is set to `True`, -the local token will be used, which must be previously created via `transformer-cli login`. -Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained -- `scale_score`: Whether to scale the similarity score to the unit interval (range of [0,1]). -If true (default) similarity scores (e.g. 
cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant. -Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. - - - -#### MultihopDenseRetriever.retrieve - -```python -def retrieve(query: str, filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None, top_k: Optional[int] = None, index: str = None, headers: Optional[Dict[str, str]] = None, scale_score: bool = None) -> List[Document] -``` - -Scan through documents in DocumentStore and return a small number documents - -that are most relevant to the query. - -**Arguments**: - -- `query`: The query -- `filters`: Optional filters to narrow down the search space to documents whose metadata fulfill certain -conditions. -Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical -operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, -`"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. -Logical operator keys take a dictionary of metadata field names and/or logical operators as -value. Metadata field names take a dictionary of comparison operators as value. Comparison -operator keys take a single value or (in case of `"$in"`) a list of values as value. -If no logical operator is provided, `"$and"` is used as default operation. If no comparison -operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default -operation. - - __Example__: - ```python - filters = { - "$and": { - "type": {"$eq": "article"}, - "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, - "rating": {"$gte": 3}, - "$or": { - "genre": {"$in": ["economy", "politics"]}, - "publisher": {"$eq": "nytimes"} - } - } - } - # or simpler using default operators - filters = { - "type": "article", - "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, - "rating": {"$gte": 3}, - "$or": { - "genre": ["economy", "politics"], - "publisher": "nytimes" - } - } - ``` - - To use the same logical operator multiple times on the same level, logical operators take - optionally a list of dictionaries as value. - - __Example__: - ```python - filters = { - "$or": [ - { - "$and": { - "Type": "News Paper", - "Date": { - "$lt": "2019-01-01" - } - } - }, - { - "$and": { - "Type": "Blog Post", - "Date": { - "$gte": "2019-01-01" - } - } - } - ] - } - ``` -- `top_k`: How many documents to return per query. -- `index`: The name of the index in the DocumentStore from which to retrieve documents -- `scale_score`: Whether to scale the similarity score to the unit interval (range of [0,1]). -If true similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant. -Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. - - - -#### MultihopDenseRetriever.retrieve\_batch - -```python -def retrieve_batch(queries: Union[str, List[str]], filters: Optional[ - Union[ - Dict[str, Union[Dict, List, str, int, float, bool]], - List[Dict[str, Union[Dict, List, str, int, float, bool]]], - ] - ] = None, top_k: Optional[int] = None, index: str = None, headers: Optional[Dict[str, str]] = None, batch_size: Optional[int] = None, scale_score: bool = None) -> Union[List[Document], List[List[Document]]] -``` - -Scan through documents in DocumentStore and return a small number documents - -that are most relevant to the supplied queries. 
- -If you supply a single query, a single list of Documents is returned. If you supply a list of queries, a list of -lists of Documents (one per query) is returned. - -**Arguments**: - -- `queries`: Single query string or list of queries. -- `filters`: Optional filters to narrow down the search space to documents whose metadata fulfill certain -conditions. Can be a single filter that will be applied to each query or a list of filters -(one filter per query). - -Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical -operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, -`"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. -Logical operator keys take a dictionary of metadata field names and/or logical operators as -value. Metadata field names take a dictionary of comparison operators as value. Comparison -operator keys take a single value or (in case of `"$in"`) a list of values as value. -If no logical operator is provided, `"$and"` is used as default operation. If no comparison -operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default -operation. - - __Example__: - ```python - filters = { - "$and": { - "type": {"$eq": "article"}, - "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, - "rating": {"$gte": 3}, - "$or": { - "genre": {"$in": ["economy", "politics"]}, - "publisher": {"$eq": "nytimes"} - } - } - } - # or simpler using default operators - filters = { - "type": "article", - "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, - "rating": {"$gte": 3}, - "$or": { - "genre": ["economy", "politics"], - "publisher": "nytimes" - } - } - ``` - - To use the same logical operator multiple times on the same level, logical operators take - optionally a list of dictionaries as value. - - __Example__: - ```python - filters = { - "$or": [ - { - "$and": { - "Type": "News Paper", - "Date": { - "$lt": "2019-01-01" - } - } - }, - { - "$and": { - "Type": "Blog Post", - "Date": { - "$gte": "2019-01-01" - } - } - } - ] - } - ``` -- `top_k`: How many documents to return per query. -- `index`: The name of the index in the DocumentStore from which to retrieve documents -- `batch_size`: Number of queries to embed at a time. -- `scale_score`: Whether to scale the similarity score to the unit interval (range of [0,1]). -If true similarity scores (e.g. cosine or dot_product) which naturally have a different -value range will be scaled to a range of [0,1], where 1 means extremely relevant. -Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. - - - -#### MultihopDenseRetriever.embed\_queries - -```python -def embed_queries(queries: List[str], contexts: List[List[Document]]) -> List[np.ndarray] -``` - -Create embeddings for a list of queries using the query encoder - -**Arguments**: - -- `queries`: Queries to embed -- `contexts`: Context documents - -**Returns**: - -Embeddings, one per input queries - - - -#### MultihopDenseRetriever.embed\_documents - -```python -def embed_documents(docs: List[Document]) -> List[np.ndarray] -``` - -Create embeddings for a list of documents using the passage encoder - -**Arguments**: - -- `docs`: List of Document objects used to represent documents / passages in a standardized way within Haystack. 
- -**Returns**: - -Embeddings of documents / passages shape (batch_size, embedding_dim) - - - -#### MultihopDenseRetriever.save - -```python -def save(save_dir: Union[Path, str], encoder_dir: str = "encoder") -``` - -Save MultihopDenseRetriever to the specified directory. - -**Arguments**: - -- `save_dir`: Directory to save to. -- `encoder_dir`: Directory in save_dir that contains encoder model. - -**Returns**: - -None - - - -#### MultihopDenseRetriever.load - -```python -@classmethod -def load(cls, load_dir: Union[Path, str], document_store: BaseDocumentStore, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, similarity_function: str = "dot_product", encoder_dir: str = "encoder", infer_tokenizer_classes: bool = False) -``` - -Load MultihopDenseRetriever from the specified directory. - ## TableTextRetriever @@ -1795,6 +1472,214 @@ Save the model to the given directory - `save_dir` (`Union[Path, str]`): The directory where the model will be saved + + +## MultihopDenseRetriever + +```python +class MultihopDenseRetriever(EmbeddingRetriever) +``` + +Retriever that applies iterative retrieval using a shared encoder for query and passage. +See original paper for more details: + +Xiong, Wenhan, et. al. (2020): "Answering complex open-domain questions with multi-hop dense retrieval" +(https://arxiv.org/abs/2009.12756) + + + +#### MultihopDenseRetriever.\_\_init\_\_ + +```python +def __init__(document_store: BaseDocumentStore, embedding_model: str, model_version: Optional[str] = None, num_iterations: int = 2, use_gpu: bool = True, batch_size: int = 32, max_seq_len: int = 512, model_format: str = "farm", pooling_strategy: str = "reduce_mean", emb_extraction_layer: int = -1, top_k: int = 10, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None, scale_score: bool = True, embed_meta_fields: List[str] = []) +``` + +Same parameters as `EmbeddingRetriever` except + +**Arguments**: + +- `num_iterations`: The number of times passages are retrieved, i.e., the number of hops (Defaults to 2.) + + + +#### MultihopDenseRetriever.retrieve + +```python +def retrieve(query: str, filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None, top_k: Optional[int] = None, index: str = None, headers: Optional[Dict[str, str]] = None, scale_score: bool = None) -> List[Document] +``` + +Scan through documents in DocumentStore and return a small number documents + +that are most relevant to the query. + +**Arguments**: + +- `query`: The query +- `filters`: Optional filters to narrow down the search space to documents whose metadata fulfill certain +conditions. +Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical +operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, +`"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. +Logical operator keys take a dictionary of metadata field names and/or logical operators as +value. Metadata field names take a dictionary of comparison operators as value. Comparison +operator keys take a single value or (in case of `"$in"`) a list of values as value. +If no logical operator is provided, `"$and"` is used as default operation. If no comparison +operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default +operation. 
+ + __Example__: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + # or simpler using default operators + filters = { + "type": "article", + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": ["economy", "politics"], + "publisher": "nytimes" + } + } + ``` + + To use the same logical operator multiple times on the same level, logical operators take + optionally a list of dictionaries as value. + + __Example__: + ```python + filters = { + "$or": [ + { + "$and": { + "Type": "News Paper", + "Date": { + "$lt": "2019-01-01" + } + } + }, + { + "$and": { + "Type": "Blog Post", + "Date": { + "$gte": "2019-01-01" + } + } + } + ] + } + ``` +- `top_k`: How many documents to return per query. +- `index`: The name of the index in the DocumentStore from which to retrieve documents +- `scale_score`: Whether to scale the similarity score to the unit interval (range of [0,1]). +If true similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant. +Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. + + + +#### MultihopDenseRetriever.retrieve\_batch + +```python +def retrieve_batch(queries: Union[str, List[str]], filters: Optional[ + Union[ + Dict[str, Union[Dict, List, str, int, float, bool]], + List[Dict[str, Union[Dict, List, str, int, float, bool]]], + ] + ] = None, top_k: Optional[int] = None, index: str = None, headers: Optional[Dict[str, str]] = None, batch_size: Optional[int] = None, scale_score: bool = None) -> Union[List[Document], List[List[Document]]] +``` + +Scan through documents in DocumentStore and return a small number documents + +that are most relevant to the supplied queries. + +If you supply a single query, a single list of Documents is returned. If you supply a list of queries, a list of +lists of Documents (one per query) is returned. + +**Arguments**: + +- `queries`: Single query string or list of queries. +- `filters`: Optional filters to narrow down the search space to documents whose metadata fulfill certain +conditions. Can be a single filter that will be applied to each query or a list of filters +(one filter per query). + +Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical +operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, +`"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. +Logical operator keys take a dictionary of metadata field names and/or logical operators as +value. Metadata field names take a dictionary of comparison operators as value. Comparison +operator keys take a single value or (in case of `"$in"`) a list of values as value. +If no logical operator is provided, `"$and"` is used as default operation. If no comparison +operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default +operation. 
+ + __Example__: + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + # or simpler using default operators + filters = { + "type": "article", + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": ["economy", "politics"], + "publisher": "nytimes" + } + } + ``` + + To use the same logical operator multiple times on the same level, logical operators take + optionally a list of dictionaries as value. + + __Example__: + ```python + filters = { + "$or": [ + { + "$and": { + "Type": "News Paper", + "Date": { + "$lt": "2019-01-01" + } + } + }, + { + "$and": { + "Type": "Blog Post", + "Date": { + "$gte": "2019-01-01" + } + } + } + ] + } + ``` +- `top_k`: How many documents to return per query. +- `index`: The name of the index in the DocumentStore from which to retrieve documents +- `batch_size`: Number of queries to embed at a time. +- `scale_score`: Whether to scale the similarity score to the unit interval (range of [0,1]). +If true similarity scores (e.g. cosine or dot_product) which naturally have a different +value range will be scaled to a range of [0,1], where 1 means extremely relevant. +Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. + # Module text2sparql diff --git a/haystack/json-schemas/haystack-pipeline-master.schema.json b/haystack/json-schemas/haystack-pipeline-master.schema.json index b629b1491f..3ff8b231b6 100644 --- a/haystack/json-schemas/haystack-pipeline-master.schema.json +++ b/haystack/json-schemas/haystack-pipeline-master.schema.json @@ -3110,16 +3110,7 @@ }, "embedding_model": { "title": "Embedding Model", - "default": "deutschmann/mdr_roberta_q_encoder", - "anyOf": [ - { - "type": "string", - "format": "path" - }, - { - "type": "string" - } - ] + "type": "string" }, "model_version": { "title": "Model Version", @@ -3130,21 +3121,6 @@ "default": 2, "type": "integer" }, - "max_seq_len_query": { - "title": "Max Seq Len Query", - "default": 64, - "type": "integer" - }, - "max_seq_len_passage": { - "title": "Max Seq Len Passage", - "default": 256, - "type": "integer" - }, - "top_k": { - "title": "Top K", - "default": 10, - "type": "integer" - }, "use_gpu": { "title": "Use Gpu", "default": true, @@ -3152,32 +3128,32 @@ }, "batch_size": { "title": "Batch Size", - "default": 16, + "default": 32, "type": "integer" }, - "embed_title": { - "title": "Embed Title", - "default": true, - "type": "boolean" - }, - "use_fast_tokenizers": { - "title": "Use Fast Tokenizers", - "default": true, - "type": "boolean" + "max_seq_len": { + "title": "Max Seq Len", + "default": 512, + "type": "integer" }, - "infer_tokenizer_classes": { - "title": "Infer Tokenizer Classes", - "default": false, - "type": "boolean" + "model_format": { + "title": "Model Format", + "default": "farm", + "type": "string" }, - "similarity_function": { - "title": "Similarity Function", - "default": "dot_product", + "pooling_strategy": { + "title": "Pooling Strategy", + "default": "reduce_mean", "type": "string" }, - "global_loss_buffer_size": { - "title": "Global Loss Buffer Size", - "default": 150000, + "emb_extraction_layer": { + "title": "Emb Extraction Layer", + "default": -1, + "type": "integer" + }, + "top_k": { + "title": "Top K", + "default": 10, "type": "integer" }, "progress_bar": { @@ -3214,10 +3190,19 @@ "title": "Scale Score", "default": true, 
"type": "boolean" + }, + "embed_meta_fields": { + "title": "Embed Meta Fields", + "default": [], + "type": "array", + "items": { + "type": "string" + } } }, "required": [ - "document_store" + "document_store", + "embedding_model" ], "additionalProperties": false, "description": "Each parameter can reference other components defined in the same YAML file." From 7f50a6fad8fdde0a5a61130d509cf445ceeaa8a8 Mon Sep 17 00:00:00 2001 From: Patrick Deutschmann Date: Thu, 23 Jun 2022 15:09:41 +0200 Subject: [PATCH 04/10] Change signature of queries param in batch methods of MDR like in #2575 --- haystack/nodes/retriever/dense.py | 28 ++++++++-------------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py index a5a003a33d..33df06f2c9 100644 --- a/haystack/nodes/retriever/dense.py +++ b/haystack/nodes/retriever/dense.py @@ -2038,19 +2038,19 @@ def retrieve( If true similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant. Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. """ - return self.retrieve_batch( # type: ignore - queries=query, + return self.retrieve_batch( + queries=[query], filters=[filters] if filters is not None else None, top_k=top_k, index=index, headers=headers, scale_score=scale_score, batch_size=1, - ) + )[0] def retrieve_batch( self, - queries: Union[str, List[str]], + queries: List[str], filters: Optional[ Union[ Dict[str, Union[Dict, List, str, int, float, bool]], @@ -2062,7 +2062,7 @@ def retrieve_batch( headers: Optional[Dict[str, str]] = None, batch_size: Optional[int] = None, scale_score: bool = None, - ) -> Union[List[Document], List[List[Document]]]: + ) -> List[List[Document]]: """ Scan through documents in DocumentStore and return a small number documents that are most relevant to the supplied queries. @@ -2151,11 +2151,6 @@ def retrieve_batch( if batch_size is None: batch_size = self.batch_size - single_query = False - if isinstance(queries, str): - queries = [queries] - single_query = True - if isinstance(filters, list): if len(filters) != len(queries): raise HaystackError( @@ -2173,12 +2168,8 @@ def retrieve_batch( logger.error( "Cannot perform retrieve_batch() since MultihopDenseRetriever initialized with document_store=None" ) - if single_query: - single_result: List[Document] = [] - return single_result - else: - result: List[List[Document]] = [[] * len(queries)] - return result + result: List[List[Document]] = [[] * len(queries)] + return result documents = [] batches = self._get_batches(queries=queries, batch_size=batch_size) @@ -2206,7 +2197,4 @@ def retrieve_batch( # documents in the last iteration are final results documents.append(cur_docs) - if single_query: - return documents[0] - else: - return documents + return documents From 40710ad7868e798f126bd449b09246d41fa1da67 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 23 Jun 2022 13:12:22 +0000 Subject: [PATCH 05/10] Update Documentation & Code Style --- docs/_src/api/api/retriever.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/_src/api/api/retriever.md b/docs/_src/api/api/retriever.md index f5422ec318..e7e32c6e9e 100644 --- a/docs/_src/api/api/retriever.md +++ b/docs/_src/api/api/retriever.md @@ -1589,12 +1589,12 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. 
#### MultihopDenseRetriever.retrieve\_batch ```python -def retrieve_batch(queries: Union[str, List[str]], filters: Optional[ +def retrieve_batch(queries: List[str], filters: Optional[ Union[ Dict[str, Union[Dict, List, str, int, float, bool]], List[Dict[str, Union[Dict, List, str, int, float, bool]]], ] - ] = None, top_k: Optional[int] = None, index: str = None, headers: Optional[Dict[str, str]] = None, batch_size: Optional[int] = None, scale_score: bool = None) -> Union[List[Document], List[List[Document]]] + ] = None, top_k: Optional[int] = None, index: str = None, headers: Optional[Dict[str, str]] = None, batch_size: Optional[int] = None, scale_score: bool = None) -> List[List[Document]] ``` Scan through documents in DocumentStore and return a small number documents From 8a7382825d02f788015c56e0e28f477c0284c96c Mon Sep 17 00:00:00 2001 From: Patrick Deutschmann Date: Tue, 5 Jul 2022 08:34:05 +0200 Subject: [PATCH 06/10] Rename MultihopDenseRetriever to MultihopEmbeddingRetriever --- docs/_src/api/api/retriever.md | 18 +++++++++--------- .../haystack-pipeline-master.schema.json | 6 +++--- haystack/nodes/__init__.py | 2 +- haystack/nodes/retriever/__init__.py | 2 +- haystack/nodes/retriever/dense.py | 4 ++-- test/conftest.py | 4 ++-- test/nodes/test_retriever.py | 2 +- 7 files changed, 19 insertions(+), 19 deletions(-) diff --git a/docs/_src/api/api/retriever.md b/docs/_src/api/api/retriever.md index e7e32c6e9e..d30c102c88 100644 --- a/docs/_src/api/api/retriever.md +++ b/docs/_src/api/api/retriever.md @@ -1472,12 +1472,12 @@ Save the model to the given directory - `save_dir` (`Union[Path, str]`): The directory where the model will be saved - + -## MultihopDenseRetriever +## MultihopEmbeddingRetriever ```python -class MultihopDenseRetriever(EmbeddingRetriever) +class MultihopEmbeddingRetriever(EmbeddingRetriever) ``` Retriever that applies iterative retrieval using a shared encoder for query and passage. @@ -1486,9 +1486,9 @@ See original paper for more details: Xiong, Wenhan, et. al. (2020): "Answering complex open-domain questions with multi-hop dense retrieval" (https://arxiv.org/abs/2009.12756) - + -#### MultihopDenseRetriever.\_\_init\_\_ +#### MultihopEmbeddingRetriever.\_\_init\_\_ ```python def __init__(document_store: BaseDocumentStore, embedding_model: str, model_version: Optional[str] = None, num_iterations: int = 2, use_gpu: bool = True, batch_size: int = 32, max_seq_len: int = 512, model_format: str = "farm", pooling_strategy: str = "reduce_mean", emb_extraction_layer: int = -1, top_k: int = 10, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None, scale_score: bool = True, embed_meta_fields: List[str] = []) @@ -1500,9 +1500,9 @@ Same parameters as `EmbeddingRetriever` except - `num_iterations`: The number of times passages are retrieved, i.e., the number of hops (Defaults to 2.) - + -#### MultihopDenseRetriever.retrieve +#### MultihopEmbeddingRetriever.retrieve ```python def retrieve(query: str, filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None, top_k: Optional[int] = None, index: str = None, headers: Optional[Dict[str, str]] = None, scale_score: bool = None) -> List[Document] @@ -1584,9 +1584,9 @@ operation. If true similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant. Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. 
- + -#### MultihopDenseRetriever.retrieve\_batch +#### MultihopEmbeddingRetriever.retrieve\_batch ```python def retrieve_batch(queries: List[str], filters: Optional[ diff --git a/haystack/json-schemas/haystack-pipeline-master.schema.json b/haystack/json-schemas/haystack-pipeline-master.schema.json index 3ff8b231b6..b3efe05eff 100644 --- a/haystack/json-schemas/haystack-pipeline-master.schema.json +++ b/haystack/json-schemas/haystack-pipeline-master.schema.json @@ -125,7 +125,7 @@ "$ref": "#/definitions/MarkdownConverterComponent" }, { - "$ref": "#/definitions/MultihopDenseRetrieverComponent" + "$ref": "#/definitions/MultihopEmbeddingRetrieverComponent" }, { "$ref": "#/definitions/PDFToTextConverterComponent" @@ -3086,7 +3086,7 @@ ], "additionalProperties": false }, - "MultihopDenseRetrieverComponent": { + "MultihopEmbeddingRetrieverComponent": { "type": "object", "properties": { "name": { @@ -3098,7 +3098,7 @@ "title": "Type", "description": "Haystack Class name for the component.", "type": "string", - "const": "MultihopDenseRetriever" + "const": "MultihopEmbeddingRetriever" }, "params": { "title": "Parameters", diff --git a/haystack/nodes/__init__.py b/haystack/nodes/__init__.py index 9f861293fc..1d512a16c6 100644 --- a/haystack/nodes/__init__.py +++ b/haystack/nodes/__init__.py @@ -34,7 +34,7 @@ BM25Retriever, ElasticsearchRetriever, FilterRetriever, - MultihopDenseRetriever, + MultihopEmbeddingRetriever, ElasticsearchFilterOnlyRetriever, TfidfRetriever, Text2SparqlRetriever, diff --git a/haystack/nodes/retriever/__init__.py b/haystack/nodes/retriever/__init__.py index d3ff1976ba..f1805f6f6b 100644 --- a/haystack/nodes/retriever/__init__.py +++ b/haystack/nodes/retriever/__init__.py @@ -2,7 +2,7 @@ from haystack.nodes.retriever.dense import ( DensePassageRetriever, EmbeddingRetriever, - MultihopDenseRetriever, + MultihopEmbeddingRetriever, TableTextRetriever, ) from haystack.nodes.retriever.sparse import ( diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py index 33df06f2c9..5d40bdb3e5 100644 --- a/haystack/nodes/retriever/dense.py +++ b/haystack/nodes/retriever/dense.py @@ -1900,7 +1900,7 @@ def save(self, save_dir: Union[Path, str]) -> None: self.embedding_encoder.save(save_dir=save_dir) -class MultihopDenseRetriever(EmbeddingRetriever): +class MultihopEmbeddingRetriever(EmbeddingRetriever): """ Retriever that applies iterative retrieval using a shared encoder for query and passage. 
See original paper for more details: @@ -2166,7 +2166,7 @@ def retrieve_batch( scale_score = self.scale_score if not self.document_store: logger.error( - "Cannot perform retrieve_batch() since MultihopDenseRetriever initialized with document_store=None" + "Cannot perform retrieve_batch() since MultihopEmbeddingRetriever initialized with document_store=None" ) result: List[List[Document]] = [[] * len(queries)] return result diff --git a/test/conftest.py b/test/conftest.py index f8ad87b7c8..37aac8cdac 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -58,7 +58,7 @@ from haystack.nodes.retriever.dense import ( DensePassageRetriever, EmbeddingRetriever, - MultihopDenseRetriever, + MultihopEmbeddingRetriever, TableTextRetriever, ) from haystack.nodes.reader.farm import FARMReader @@ -678,7 +678,7 @@ def get_retriever(retriever_type, document_store): embed_title=True, ) elif retriever_type == "mdr": - retriever = MultihopDenseRetriever( + retriever = MultihopEmbeddingRetriever( document_store=document_store, embedding_model="deutschmann/mdr_roberta_q_encoder", # or "facebook/dpr-ctx_encoder-single-nq-base" use_gpu=False, diff --git a/test/nodes/test_retriever.py b/test/nodes/test_retriever.py index e3e903336e..f652f9647f 100644 --- a/test/nodes/test_retriever.py +++ b/test/nodes/test_retriever.py @@ -19,7 +19,7 @@ DensePassageRetriever, EmbeddingRetriever, TableTextRetriever, - MultihopDenseRetriever, + MultihopEmbeddingRetriever, ) from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast, PreTrainedTokenizerFast From 856029a33fc123b25a51d0c7568469d50d908e07 Mon Sep 17 00:00:00 2001 From: Patrick Deutschmann Date: Tue, 5 Jul 2022 08:36:41 +0200 Subject: [PATCH 07/10] Fix filters in retrieve_batch --- haystack/nodes/retriever/dense.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py index 5d40bdb3e5..d7a90c0dd7 100644 --- a/haystack/nodes/retriever/dense.py +++ b/haystack/nodes/retriever/dense.py @@ -2158,7 +2158,7 @@ def retrieve_batch( " as queries or a single filter that will be applied to each query." ) else: - filters = [{}] * len(queries) + filters = [filters] * len(queries) if filters is not None else [{}] * len(queries) if index is None: index = self.document_store.index From 74c9c2f53a2121e9f9afa1438edc429f22ebd6f5 Mon Sep 17 00:00:00 2001 From: Patrick Deutschmann Date: Tue, 5 Jul 2022 08:38:22 +0200 Subject: [PATCH 08/10] Add docstring for MultihopEmbeddingRetriever.__init__ --- haystack/nodes/retriever/dense.py | 43 +++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py index d7a90c0dd7..63a74e9cc2 100644 --- a/haystack/nodes/retriever/dense.py +++ b/haystack/nodes/retriever/dense.py @@ -1929,9 +1929,48 @@ def __init__( embed_meta_fields: List[str] = [], ): """ - Same parameters as `EmbeddingRetriever` except - + :param document_store: An instance of DocumentStore from which to retrieve documents. + :param embedding_model: Local path or name of model in Hugging Face's model hub such as ``'sentence-transformers/all-MiniLM-L6-v2'`` + :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash. :param num_iterations: The number of times passages are retrieved, i.e., the number of hops (Defaults to 2.) 
+ :param use_gpu: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available. + :param batch_size: Number of documents to encode at once. + :param max_seq_len: Longest length of each document sequence. Maximum number of tokens for the document text. Longer ones will be cut down. + :param model_format: Name of framework that was used for saving the model or model type. If no model_format is + provided, it will be inferred automatically from the model configuration files. + Options: + + - ``'farm'`` (will use `_DefaultEmbeddingEncoder` as embedding encoder) + - ``'transformers'`` (will use `_DefaultEmbeddingEncoder` as embedding encoder) + - ``'sentence_transformers'`` (will use `_SentenceTransformersEmbeddingEncoder` as embedding encoder) + - ``'retribert'`` (will use `_RetribertEmbeddingEncoder` as embedding encoder) + :param pooling_strategy: Strategy for combining the embeddings from the model (for farm / transformers models only). + Options: + + - ``'cls_token'`` (sentence vector) + - ``'reduce_mean'`` (sentence vector) + - ``'reduce_max'`` (sentence vector) + - ``'per_token'`` (individual token vectors) + :param emb_extraction_layer: Number of layer from which the embeddings shall be extracted (for farm / transformers models only). + Default: -1 (very last layer). + :param top_k: How many documents to return per query. + :param progress_bar: If true displays progress bar during embedding. + :param devices: List of GPU (or CPU) devices, to limit inference to certain GPUs and not use all available ones + These strings will be converted into pytorch devices, so use the string notation described here: + https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device + (e.g. ["cuda:0"]). Note: As multi-GPU training is currently not implemented for EmbeddingRetriever, + training will only use the first device provided in this list. + :param use_auth_token: API token used to download private models from Huggingface. If this parameter is set to `True`, + the local token will be used, which must be previously created via `transformer-cli login`. + Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained + :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]). + If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant. + Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. + :param embed_meta_fields: Concatenate the provided meta fields and text passage / table to a text pair that is + then used to create the embedding. + This approach is also used in the TableTextRetriever paper and is likely to improve + performance if your titles contain meaningful information for retrieval + (topic, entities etc.). 
""" super().__init__( document_store, From 919b63baf4c2d6c5b321f710167aa70a09cca297 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 5 Jul 2022 06:43:22 +0000 Subject: [PATCH 09/10] Update Documentation & Code Style --- docs/_src/api/api/retriever.md | 43 ++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/docs/_src/api/api/retriever.md b/docs/_src/api/api/retriever.md index d30c102c88..a600983127 100644 --- a/docs/_src/api/api/retriever.md +++ b/docs/_src/api/api/retriever.md @@ -1494,11 +1494,50 @@ Xiong, Wenhan, et. al. (2020): "Answering complex open-domain questions with mul def __init__(document_store: BaseDocumentStore, embedding_model: str, model_version: Optional[str] = None, num_iterations: int = 2, use_gpu: bool = True, batch_size: int = 32, max_seq_len: int = 512, model_format: str = "farm", pooling_strategy: str = "reduce_mean", emb_extraction_layer: int = -1, top_k: int = 10, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None, scale_score: bool = True, embed_meta_fields: List[str] = []) ``` -Same parameters as `EmbeddingRetriever` except - **Arguments**: +- `document_store`: An instance of DocumentStore from which to retrieve documents. +- `embedding_model`: Local path or name of model in Hugging Face's model hub such as ``'sentence-transformers/all-MiniLM-L6-v2'`` +- `model_version`: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash. - `num_iterations`: The number of times passages are retrieved, i.e., the number of hops (Defaults to 2.) +- `use_gpu`: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available. +- `batch_size`: Number of documents to encode at once. +- `max_seq_len`: Longest length of each document sequence. Maximum number of tokens for the document text. Longer ones will be cut down. +- `model_format`: Name of framework that was used for saving the model or model type. If no model_format is +provided, it will be inferred automatically from the model configuration files. +Options: + +- ``'farm'`` (will use `_DefaultEmbeddingEncoder` as embedding encoder) +- ``'transformers'`` (will use `_DefaultEmbeddingEncoder` as embedding encoder) +- ``'sentence_transformers'`` (will use `_SentenceTransformersEmbeddingEncoder` as embedding encoder) +- ``'retribert'`` (will use `_RetribertEmbeddingEncoder` as embedding encoder) +- `pooling_strategy`: Strategy for combining the embeddings from the model (for farm / transformers models only). +Options: + +- ``'cls_token'`` (sentence vector) +- ``'reduce_mean'`` (sentence vector) +- ``'reduce_max'`` (sentence vector) +- ``'per_token'`` (individual token vectors) +- `emb_extraction_layer`: Number of layer from which the embeddings shall be extracted (for farm / transformers models only). +Default: -1 (very last layer). +- `top_k`: How many documents to return per query. +- `progress_bar`: If true displays progress bar during embedding. +- `devices`: List of GPU (or CPU) devices, to limit inference to certain GPUs and not use all available ones +These strings will be converted into pytorch devices, so use the string notation described here: +https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device +(e.g. ["cuda:0"]). 
Note: As multi-GPU training is currently not implemented for EmbeddingRetriever, +training will only use the first device provided in this list. +- `use_auth_token`: API token used to download private models from Huggingface. If this parameter is set to `True`, +the local token will be used, which must be previously created via `transformer-cli login`. +Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained +- `scale_score`: Whether to scale the similarity score to the unit interval (range of [0,1]). +If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant. +Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. +- `embed_meta_fields`: Concatenate the provided meta fields and text passage / table to a text pair that is +then used to create the embedding. +This approach is also used in the TableTextRetriever paper and is likely to improve +performance if your titles contain meaningful information for retrieval +(topic, entities etc.). From b584816ea34939406e919b072892507c29012949 Mon Sep 17 00:00:00 2001 From: Patrick Deutschmann Date: Tue, 5 Jul 2022 08:47:51 +0200 Subject: [PATCH 10/10] Revert forward signature of TextSimilarityHead --- haystack/modeling/model/prediction_head.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/haystack/modeling/model/prediction_head.py b/haystack/modeling/model/prediction_head.py index 6654dae80f..b66891b17f 100644 --- a/haystack/modeling/model/prediction_head.py +++ b/haystack/modeling/model/prediction_head.py @@ -971,9 +971,7 @@ def get_similarity_function(self): f"The similarity function can only be 'dot_product' or 'cosine', not '{self.similarity_function}'" ) - def forward( - self, query_vectors: torch.Tensor, passage_vectors: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + def forward(self, query_vectors: torch.Tensor, passage_vectors: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Only packs the embeddings from both language models into a tuple. No further modification. The similarity calculation is handled later to enable distributed training (DDP)
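For reviewers who want to try the new node end to end, here is a minimal usage sketch. It is illustrative only and not part of the diff: it assumes an `InMemoryDocumentStore` with a 768-dimensional embedding space and reuses the `deutschmann/mdr_roberta_q_encoder` checkpoint referenced in `test/conftest.py`; any document store and embedding model supported by `EmbeddingRetriever` should work the same way.

```python
# Illustrative sketch, not part of this patch. Assumes an in-memory store and the
# "deutschmann/mdr_roberta_q_encoder" checkpoint used in test/conftest.py (768-dim embeddings).
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import MultihopEmbeddingRetriever

document_store = InMemoryDocumentStore(embedding_dim=768, similarity="dot_product")
document_store.write_documents(
    [
        {"content": "Multi-hop dense retrieval re-encodes the query together with passages retrieved in earlier hops.", "meta": {"name": "MDR"}},
        {"content": "HotpotQA questions typically require evidence from two different documents.", "meta": {"name": "HotpotQA"}},
    ]
)

retriever = MultihopEmbeddingRetriever(
    document_store=document_store,
    embedding_model="deutschmann/mdr_roberta_q_encoder",
    num_iterations=2,  # number of hops
    top_k=5,
    use_gpu=False,
)
document_store.update_embeddings(retriever)

# Single query -> List[Document]
docs = retriever.retrieve(query="Which dataset needs evidence from more than one document?")

# Batch of queries -> List[List[Document]] (one result list per query)
batch_docs = retriever.retrieve_batch(queries=["question one", "question two"])
```

Because `retrieve` now simply wraps `retrieve_batch` with a single-element list and unpacks the first result, both entry points run the same iterative encode/retrieve loop, with `num_iterations` controlling how many hops are performed before the documents from the last iteration are returned.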