diff --git a/docs/_src/usage/usage/retriever.md b/docs/_src/usage/usage/retriever.md
index 0b759adb7b..122ef38338 100644
--- a/docs/_src/usage/usage/retriever.md
+++ b/docs/_src/usage/usage/retriever.md
@@ -122,6 +122,16 @@ Indexing using DPR is comparatively expensive in terms of required computation s
 The embeddings that are created in this step can be stored in FAISS, a database optimized for vector similarity.
 DPR can also work with the ElasticsearchDocumentStore or the InMemoryDocumentStore.
 
+<div class="recommendation">
+
+**Tip**
+
+When using DPR, it is recommended that you use the dot product similarity function, since that is how DPR models are trained.
+To do so, simply provide `similarity="dot_product"` when initializing the DocumentStore,
+as is done in the code example below.
+
+</div>
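As an aside (not part of the patch): dot product and cosine similarity coincide only for unit-length vectors, so the choice of `similarity` genuinely changes the scores a retriever sees. A NumPy-only toy illustration with made-up embeddings:

```python
import numpy as np

# Two toy "embeddings", each of length 3.0 (deliberately not unit length)
a = np.array([1.0, 2.0, 2.0])
b = np.array([2.0, 1.0, 2.0])

dot_product = np.dot(a, b)                                            # 8.0
cosine_sim = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))   # ~0.889

print(dot_product, cosine_sim)
```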
+
 There are two design decisions that have made DPR particularly performant.
 
@@ -136,7 +146,7 @@ If you’d like to learn how to set up a DPR based system, have a look at our tu
 ### Initialisation
 
 ```python
-document_store = FAISSDocumentStore()
+document_store = FAISSDocumentStore(similarity="dot_product")
 ...
 retriever = DensePassageRetriever(
     document_store=document_store,
@@ -161,10 +171,20 @@ They are particular suited to cases where your query input is similar in style t
 i.e. when you are searching for most similar documents.
 This is not inherently suited to query based search where the length, language and format of the query usually significantly differs from the searched for text.
 
+<div class="recommendation">
+
+**Tip**
+
+When using Sentence Transformer models, we recommend that you use a cosine similarity function.
+To do so, simply provide `similarity="cosine"` when initializing the DocumentStore
+as is done in the code example below.
+
+</div>
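The same `similarity` switch is also added to the InMemoryDocumentStore further down in this diff (haystack/document_store/memory.py). A minimal pairing sketch for that case, assuming the `deepset/sentence_bert` model used elsewhere in these docs:

```python
from haystack.document_store.memory import InMemoryDocumentStore
from haystack.retriever.dense import EmbeddingRetriever

# cosine, to match the Sentence Transformer style embedding model
document_store = InMemoryDocumentStore(similarity="cosine")
retriever = EmbeddingRetriever(document_store=document_store,
                               embedding_model="deepset/sentence_bert")
```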
+
 ### Initialisation
 
 ```python
-document_store = ElasticsearchDocumentStore()
+document_store = ElasticsearchDocumentStore(similarity="cosine")
 ...
 retriever = EmbeddingRetriever(document_store=document_store,
                                embedding_model="deepset/sentence_bert")
diff --git a/haystack/document_store/base.py b/haystack/document_store/base.py
index a3edbb9bff..1f33b825fe 100644
--- a/haystack/document_store/base.py
+++ b/haystack/document_store/base.py
@@ -12,6 +12,7 @@ class BaseDocumentStore(ABC):
     """
     index: Optional[str]
     label_index: Optional[str]
+    similarity: Optional[str]
 
     @abstractmethod
     def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
diff --git a/haystack/document_store/elasticsearch.py b/haystack/document_store/elasticsearch.py
index 7619151e57..bbd4600f55 100644
--- a/haystack/document_store/elasticsearch.py
+++ b/haystack/document_store/elasticsearch.py
@@ -117,6 +117,7 @@ def __init__(
         self.update_existing_documents = update_existing_documents
         self.refresh_type = refresh_type
 
+        self.similarity = similarity
         if similarity == "cosine":
             self.similarity_fn_name = "cosineSimilarity"
         elif similarity == "dot_product":
@@ -596,7 +597,10 @@ def _convert_es_hit_to_document(
         if score:
             if adapt_score_for_embedding:
                 score -= 1000
-                probability = (score + 1) / 2  # scaling probability from cosine similarity
+                if self.similarity == "cosine":
+                    probability = (score + 1) / 2  # scaling probability from cosine similarity
+                elif self.similarity == "dot_product":
+                    probability = float(expit(np.asarray(score / 100)))  # scaling probability from dot product
             else:
                 probability = float(expit(np.asarray(score / 8)))  # scaling probability from TFIDF/BM25
         else:
diff --git a/haystack/document_store/faiss.py b/haystack/document_store/faiss.py
index 9cbdb24895..60aca27b79 100644
--- a/haystack/document_store/faiss.py
+++ b/haystack/document_store/faiss.py
@@ -9,6 +9,8 @@
 from haystack.document_store.sql import SQLDocumentStore
 from haystack.retriever.base import BaseRetriever
 
+from scipy.special import expit
+
 if platform != 'win32' and platform != 'cygwin':
     import faiss
 else:
@@ -39,6 +41,7 @@ def __init__(
         return_embedding: bool = False,
         update_existing_documents: bool = False,
         index: str = "document",
+        similarity: str = "dot_product",
         **kwargs,
     ):
         """
@@ -82,6 +85,11 @@ def __init__(
         self.index_buffer_size = index_buffer_size
         self.return_embedding = return_embedding
 
+        if similarity == "dot_product":
+            self.similarity = similarity
+        else:
+            raise ValueError("The FAISS document store can currently only support dot_product similarity. "
+                             "Please set similarity=\"dot_product\"")
         super().__init__(
             url=sql_url,
             update_existing_documents=update_existing_documents,
@@ -116,7 +124,8 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O
         # doc + metadata index
         index = index or self.index
 
-        document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
+        field_map = self._create_document_field_map()
+        document_objects = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents]
 
         add_vectors = False if document_objects[0].embedding is None else True
 
@@ -142,6 +151,11 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O
 
         super(FAISSDocumentStore, self).write_documents(docs_to_write_in_sql, index=index)
 
+    def _create_document_field_map(self) -> Dict:
+        return {
+            self.index: "embedding",
+        }
+
     def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None):
         """
         Updates the embeddings in the the document store using the encoding model specified in the retriever.
@@ -273,7 +287,7 @@ def query_by_embedding(self,
         scores_for_vector_ids: Dict[str, float] = {str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])}
         for doc in documents:
             doc.score = scores_for_vector_ids[doc.meta["vector_id"]]
-            doc.probability = (doc.score + 1) / 2
+            doc.probability = float(expit(np.asarray(doc.score / 100)))
             if return_embedding is True:
                 doc.embedding = self.faiss_index.reconstruct(int(doc.meta["vector_id"]))
 
diff --git a/haystack/document_store/memory.py b/haystack/document_store/memory.py
index e5646100f1..7cf223c586 100644
--- a/haystack/document_store/memory.py
+++ b/haystack/document_store/memory.py
@@ -8,6 +8,8 @@
 from haystack.preprocessor.utils import eval_data_from_file
 from haystack.retriever.base import BaseRetriever
 
+from scipy.spatial.distance import cosine
+
 import logging
 
 logger = logging.getLogger(__name__)
@@ -17,13 +19,14 @@ class InMemoryDocumentStore(BaseDocumentStore):
     In-memory document store
     """
 
-    def __init__(self, embedding_field: Optional[str] = "embedding", return_embedding: bool = False):
+    def __init__(self, embedding_field: Optional[str] = "embedding", return_embedding: bool = False, similarity="dot_product"):
         self.indexes: Dict[str, Dict] = defaultdict(dict)
         self.index: str = "document"
         self.label_index: str = "label"
         self.embedding_field: str = embedding_field if embedding_field is not None else "embedding"
         self.embedding_dim: int = 768
         self.return_embedding: bool = return_embedding
+        self.similarity: str = similarity
 
     def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
         """
@@ -41,11 +44,17 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O
         """
         index = index or self.index
 
-        documents_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
+        field_map = self._create_document_field_map()
+        documents_objects = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents]
 
         for document in documents_objects:
             self.indexes[index][document.id] = document
 
+    def _create_document_field_map(self):
+        return {
+            self.embedding_field: "embedding",
+        }
+
     def write_labels(self, labels: Union[List[dict], List[Label]], index: Optional[str] = None):
         """Write annotation labels into document store."""
         index = index or self.label_index
@@ -106,18 +115,24 @@ def query_by_embedding(self,
 
         candidate_docs = []
         for idx, doc in self.indexes[index].items():
+            curr_meta = deepcopy(doc.meta)
             new_document = Document(
                 id=doc.id,
                 text=doc.text,
-                meta=deepcopy(doc.meta)
+                meta=curr_meta,
+                embedding=doc.embedding
             )
             new_document.embedding = doc.embedding if return_embedding is True else None
-            score = dot(query_emb, doc.embedding) / (
-                norm(query_emb) * norm(doc.embedding)
-            )
+
+            if self.similarity == "dot_product":
+                score = dot(query_emb, doc.embedding) / (
+                    norm(query_emb) * norm(doc.embedding)
+                )
+            elif self.similarity == "cosine":
+                # cosine similarity score = 1 - cosine distance
+                score = 1 - cosine(query_emb, doc.embedding)
             new_document.score = score
             new_document.probability = (score + 1) / 2
-
             candidate_docs.append(new_document)
 
         return sorted(candidate_docs, key=lambda x: x.score if x.score is not None else 0.0, reverse=True)[0:top_k]
diff --git a/haystack/document_store/sql.py b/haystack/document_store/sql.py
index 12ceb9c772..55510f5e17 100644
--- a/haystack/document_store/sql.py
+++ b/haystack/document_store/sql.py
@@ -87,6 +87,8 @@ def __init__(
         self.index = index
         self.label_index = label_index
         self.update_existing_documents = update_existing_documents
+        if getattr(self, "similarity", None) is None:
+            self.similarity = None
 
     def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
         """Fetch a document by specifying its text id string"""
diff --git a/haystack/retriever/dense.py b/haystack/retriever/dense.py
index b37b4ceb8e..404c297200 100644
--- a/haystack/retriever/dense.py
+++ b/haystack/retriever/dense.py
@@ -6,6 +6,7 @@
 from tqdm import tqdm
 
 from haystack.document_store.base import BaseDocumentStore
+from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
 from haystack import Document
 from haystack.retriever.base import BaseRetriever
 
@@ -86,6 +87,11 @@ def __init__(self,
         self.max_seq_len_passage = max_seq_len_passage
         self.max_seq_len_query = max_seq_len_query
 
+        if document_store.similarity != "dot_product":
+            logger.warning(f"You are using a Dense Passage Retriever model with the {document_store.similarity} function. "
+                           "We recommend you use dot_product instead. "
+                           "This can be set when initializing the DocumentStore")
+
         if use_gpu and torch.cuda.is_available():
             self.device = torch.device("cuda")
         else:
@@ -399,6 +405,18 @@ def __init__(
                 embedding_model, task_type="embeddings", extraction_strategy=self.pooling_strategy,
                 extraction_layer=self.emb_extraction_layer, gpu=use_gpu, batch_size=4, max_seq_len=512, num_processes=0
             )
+            # Check that document_store has the right similarity function
+            similarity = document_store.similarity
+            # If we are using a sentence transformer model
+            if "sentence" in embedding_model.lower() and similarity != "cosine":
+                logger.warning(f"You seem to be using a Sentence Transformer with the {similarity} function. "
+                               f"We recommend using cosine instead. "
+                               f"This can be set when initializing the DocumentStore")
+            elif "dpr" in embedding_model.lower() and similarity != "dot_product":
+                logger.warning(f"You seem to be using a DPR model with the {similarity} function. "
+                               f"We recommend using dot_product instead. "
+                               f"This can be set when initializing the DocumentStore")
+
 
         elif model_format == "sentence_transformers":
             try:
@@ -414,6 +432,11 @@ def __init__(
             else:
                 device = "cpu"
             self.embedding_model = SentenceTransformer(embedding_model, device=device)
+            if document_store.similarity != "cosine":
+                logger.warning(
+                    f"You are using a Sentence Transformer with the {document_store.similarity} function. "
+                    f"We recommend using cosine instead. "
+                    f"This can be set when initializing the DocumentStore")
 
         else:
             raise NotImplementedError
diff --git a/tutorials/Tutorial4_FAQ_style_QA.py b/tutorials/Tutorial4_FAQ_style_QA.py
index 5b6f88f459..5a7306392e 100755
--- a/tutorials/Tutorial4_FAQ_style_QA.py
+++ b/tutorials/Tutorial4_FAQ_style_QA.py
@@ -44,7 +44,8 @@
                                             index="document",
                                             embedding_field="question_emb",
                                             embedding_dim=768,
-                                            excluded_meta_data=["question_emb"])
+                                            excluded_meta_data=["question_emb"],
+                                            similarity="cosine")
 
 ### Create a Retriever using embeddings
 # Instead of retrieving via Elasticsearch's plain BM25, we want to use vector similarity of the questions (user question vs. FAQ ones).