Scale dot product into probabilities #667

Merged
merged 9 commits on Dec 11, 2020
24 changes: 22 additions & 2 deletions docs/_src/usage/usage/retriever.md
@@ -122,6 +122,16 @@ Indexing using DPR is comparatively expensive in terms of required computation s
The embeddings that are created in this step can be stored in FAISS, a database optimized for vector similarity.
DPR can also work with the ElasticsearchDocumentStore or the InMemoryDocumentStore.

<div class="recommendation">

**Tip**

When using DPR, it is recommended that you use the dot product similarity function, since that is how the model is trained.
To do so, simply provide `similarity="dot_product"` when initializing the DocumentStore
as is done in the code example below.

</div>

There are two design decisions that have made DPR particularly performant.


@@ -136,7 +146,7 @@ If you’d like to learn how to set up a DPR based system, have a look at our tu
### Initialisation

```python
document_store = FAISSDocumentStore()
document_store = FAISSDocumentStore(similarity="dot_product")
...
retriever = DensePassageRetriever(
document_store=document_store,
@@ -161,10 +171,20 @@ They are particularly suited to cases where your query input is similar in style t
i.e. when you are searching for most similar documents.
This is not inherently suited to query-based search, where the length, language, and format of the query usually differ significantly from the searched-for text.

<div class="recommendation">

**Tip**

When using Sentence Transformer models, we recommend that you use the cosine similarity function.
To do so, simply provide `similarity="cosine"` when initializing the DocumentStore
as is done in the code example below.

</div>

### Initialisation

```python
document_store = ElasticsearchDocumentStore()
document_store = ElasticsearchDocumentStore(similarity="cosine")
...
retriever = EmbeddingRetriever(document_store=document_store,
embedding_model="deepset/sentence_bert")
1 change: 1 addition & 0 deletions haystack/document_store/base.py
@@ -12,6 +12,7 @@ class BaseDocumentStore(ABC):
"""
index: Optional[str]
label_index: Optional[str]
similarity: Optional[str]

@abstractmethod
def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
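The new class-level attribute means every document store subclass now advertises the similarity function it was configured with, which is what the retriever checks later in this PR rely on. A minimal sketch of how the attribute surfaces, using the InMemoryDocumentStore import path from this PR:

```python
from haystack.document_store.memory import InMemoryDocumentStore

# Every BaseDocumentStore subclass now exposes the similarity it was configured with
store = InMemoryDocumentStore(similarity="cosine")
print(store.similarity)  # "cosine"; retrievers read this to warn about mismatched settings
```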
6 changes: 5 additions & 1 deletion haystack/document_store/elasticsearch.py
@@ -117,6 +117,7 @@ def __init__(

self.update_existing_documents = update_existing_documents
self.refresh_type = refresh_type
self.similarity = similarity
if similarity == "cosine":
self.similarity_fn_name = "cosineSimilarity"
elif similarity == "dot_product":
@@ -596,7 +597,10 @@ def _convert_es_hit_to_document(
if score:
if adapt_score_for_embedding:
score -= 1000
probability = (score + 1) / 2 # scaling probability from cosine similarity
if self.similarity == "cosine":
probability = (score + 1) / 2 # scaling probability from cosine similarity
elif self.similarity == "dot_product":
probability = float(expit(np.asarray(score / 100))) # scaling probability from dot product
else:
probability = float(expit(np.asarray(score / 8))) # scaling probability from TFIDF/BM25
else:
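This hunk is the core of the PR: raw retrieval scores live on very different scales depending on the similarity function, so each branch maps them onto [0, 1] differently. A standalone sketch of the three mappings, with made-up score values purely for illustration:

```python
import numpy as np
from scipy.special import expit

# Illustrative raw scores (not taken from a real query)
cosine_score = 0.83        # cosine similarity is bounded in [-1, 1]
dot_product_score = 173.4  # dot products of 768-dim embeddings can reach the hundreds
bm25_score = 12.7          # BM25/TF-IDF scores are unbounded but usually much smaller

prob_cosine = (cosine_score + 1) / 2                          # linear rescale of [-1, 1] onto [0, 1]
prob_dot = float(expit(np.asarray(dot_product_score / 100)))  # sigmoid squashes large dot products into (0, 1)
prob_bm25 = float(expit(np.asarray(bm25_score / 8)))          # smaller divisor fits the smaller BM25 range

print(prob_cosine, prob_dot, prob_bm25)  # roughly 0.92, 0.85, 0.83
```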
18 changes: 16 additions & 2 deletions haystack/document_store/faiss.py
@@ -9,6 +9,8 @@
from haystack.document_store.sql import SQLDocumentStore
from haystack.retriever.base import BaseRetriever

from scipy.special import expit

if platform != 'win32' and platform != 'cygwin':
import faiss
else:
@@ -39,6 +41,7 @@ def __init__(
return_embedding: bool = False,
update_existing_documents: bool = False,
index: str = "document",
similarity: str = "dot_product",
**kwargs,
):
"""
@@ -82,6 +85,11 @@

self.index_buffer_size = index_buffer_size
self.return_embedding = return_embedding
if similarity == "dot_product":
self.similarity = similarity
else:
raise ValueError("The FAISS document store can currently only support dot_product similarity. "
"Please set similarity=\"dot_product\"")
super().__init__(
url=sql_url,
update_existing_documents=update_existing_documents,
@@ -116,7 +124,8 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O

# doc + metadata index
index = index or self.index
document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
field_map = self._create_document_field_map()
document_objects = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents]

add_vectors = False if document_objects[0].embedding is None else True

@@ -142,6 +151,11 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O

super(FAISSDocumentStore, self).write_documents(docs_to_write_in_sql, index=index)

def _create_document_field_map(self) -> Dict:
return {
self.index: "embedding",
}

def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None):
"""
Updates the embeddings in the document store using the encoding model specified in the retriever.
Expand Down Expand Up @@ -273,7 +287,7 @@ def query_by_embedding(self,
scores_for_vector_ids: Dict[str, float] = {str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])}
for doc in documents:
doc.score = scores_for_vector_ids[doc.meta["vector_id"]]
doc.probability = (doc.score + 1) / 2
doc.probability = float(expit(np.asarray(doc.score / 100)))
if return_embedding is True:
doc.embedding = self.faiss_index.reconstruct(int(doc.meta["vector_id"]))

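Besides switching the probability calculation in `query_by_embedding` to the same sigmoid used for Elasticsearch, the FAISS store now refuses any similarity other than dot_product. A hedged usage sketch of both sides of that guard (default SQLite backend assumed):

```python
from haystack.document_store.faiss import FAISSDocumentStore

# Accepted: the FAISS index here is built for inner-product (dot product) search
document_store = FAISSDocumentStore(similarity="dot_product")

# Rejected: any other value hits the ValueError raised in the constructor above
try:
    FAISSDocumentStore(similarity="cosine")
except ValueError as err:
    print(err)  # "The FAISS document store can currently only support dot_product similarity. ..."
```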
29 changes: 22 additions & 7 deletions haystack/document_store/memory.py
@@ -8,6 +8,8 @@
from haystack.preprocessor.utils import eval_data_from_file
from haystack.retriever.base import BaseRetriever

from scipy.spatial.distance import cosine

import logging
logger = logging.getLogger(__name__)

@@ -17,13 +19,14 @@ class InMemoryDocumentStore(BaseDocumentStore):
In-memory document store
"""

def __init__(self, embedding_field: Optional[str] = "embedding", return_embedding: bool = False):
def __init__(self, embedding_field: Optional[str] = "embedding", return_embedding: bool = False, similarity: str = "dot_product"):
self.indexes: Dict[str, Dict] = defaultdict(dict)
self.index: str = "document"
self.label_index: str = "label"
self.embedding_field: str = embedding_field if embedding_field is not None else "embedding"
self.embedding_dim: int = 768
self.return_embedding: bool = return_embedding
self.similarity: str = similarity

def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
"""
@@ -41,11 +44,17 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O
"""
index = index or self.index

documents_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
field_map = self._create_document_field_map()
documents_objects = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents]

for document in documents_objects:
self.indexes[index][document.id] = document

def _create_document_field_map(self):
return {
self.embedding_field: "embedding",
}

def write_labels(self, labels: Union[List[dict], List[Label]], index: Optional[str] = None):
"""Write annotation labels into document store."""
index = index or self.label_index
@@ -106,18 +115,24 @@ def query_by_embedding(self,

candidate_docs = []
for idx, doc in self.indexes[index].items():
curr_meta = deepcopy(doc.meta)
new_document = Document(
id=doc.id,
text=doc.text,
meta=deepcopy(doc.meta)
meta=curr_meta,
embedding=doc.embedding
)
new_document.embedding = doc.embedding if return_embedding is True else None
score = dot(query_emb, doc.embedding) / (
norm(query_emb) * norm(doc.embedding)
)

if self.similarity == "dot_product":
score = dot(query_emb, doc.embedding) / (
norm(query_emb) * norm(doc.embedding)
)
elif self.similarity == "cosine":
# cosine similarity score = 1 - cosine distance
score = 1 - cosine(query_emb, doc.embedding)
new_document.score = score
new_document.probability = (score + 1) / 2

candidate_docs.append(new_document)

return sorted(candidate_docs, key=lambda x: x.score if x.score is not None else 0.0, reverse=True)[0:top_k]
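The in-memory store now picks its scoring branch from the configured similarity. A toy reproduction of both branches on 3-dimensional vectors (stand-ins for real 768-dimensional embeddings); note that because the dot_product branch above still divides by the vector norms, both branches agree numerically on the same inputs:

```python
import numpy as np
from numpy import dot
from numpy.linalg import norm
from scipy.spatial.distance import cosine

query_emb = np.array([0.1, 0.3, 0.5])
doc_emb = np.array([0.2, 0.1, 0.4])

# "dot_product" branch: dot product divided by the vector norms
score_dot = dot(query_emb, doc_emb) / (norm(query_emb) * norm(doc_emb))

# "cosine" branch: scipy's cosine() returns a distance, so similarity = 1 - distance
score_cos = 1 - cosine(query_emb, doc_emb)

print(round(score_dot, 4), round(score_cos, 4))  # both ~0.9221
probability = (score_dot + 1) / 2                # ~0.9611, as computed in the hunk above
```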
2 changes: 2 additions & 0 deletions haystack/document_store/sql.py
@@ -87,6 +87,8 @@ def __init__(
self.index = index
self.label_index = label_index
self.update_existing_documents = update_existing_documents
if getattr(self, "similarity", None) is None:
self.similarity = None

def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
"""Fetch a document by specifying its text id string"""
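The `getattr` guard is there because subclasses such as FAISSDocumentStore set `self.similarity` before calling `super().__init__()`, while a plain SQLDocumentStore never sets it and so falls back to None. A small sketch of both cases (default connection URLs assumed):

```python
from haystack.document_store.sql import SQLDocumentStore
from haystack.document_store.faiss import FAISSDocumentStore

sql_store = SQLDocumentStore()      # plain SQL retrieval has no vector similarity
print(sql_store.similarity)         # None

faiss_store = FAISSDocumentStore()  # subclass sets similarity="dot_product" before super().__init__()
print(faiss_store.similarity)       # "dot_product"; the guard leaves it untouched
```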
23 changes: 23 additions & 0 deletions haystack/retriever/dense.py
@@ -6,6 +6,7 @@
from tqdm import tqdm

from haystack.document_store.base import BaseDocumentStore
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack import Document
from haystack.retriever.base import BaseRetriever

@@ -86,6 +87,11 @@ def __init__(self,
self.max_seq_len_passage = max_seq_len_passage
self.max_seq_len_query = max_seq_len_query

if document_store.similarity != "dot_product":
logger.warning(f"You are using a Dense Passage Retriever model with the {document_store.similarity} function. "
"We recommend you use dot_product instead. "
"This can be set when initializing the DocumentStore")

if use_gpu and torch.cuda.is_available():
self.device = torch.device("cuda")
else:
@@ -399,6 +405,18 @@
embedding_model, task_type="embeddings", extraction_strategy=self.pooling_strategy,
extraction_layer=self.emb_extraction_layer, gpu=use_gpu, batch_size=4, max_seq_len=512, num_processes=0
)
# Check that document_store has the right similarity function
similarity = document_store.similarity
# If we are using a sentence transformer model
if "sentence" in embedding_model.lower() and similarity != "cosine":
logger.warning(f"You seem to be using a Sentence Transformer with the {similarity} function. "
f"We recommend using cosine instead. "
f"This can be set when initializing the DocumentStore")
elif "dpr" in embedding_model.lower() and similarity != "dot_product":
logger.warning(f"You seem to be using a DPR model with the {similarity} function. "
f"We recommend using dot_product instead. "
f"This can be set when initializing the DocumentStore")


elif model_format == "sentence_transformers":
try:
@@ -414,6 +432,11 @@
else:
device = "cpu"
self.embedding_model = SentenceTransformer(embedding_model, device=device)
if document_store.similarity != "cosine":
logger.warning(
f"You are using a Sentence Transformer with the {document_store.similarity} function. "
f"We recommend using cosine instead. "
f"This can be set when initializing the DocumentStore")
else:
raise NotImplementedError

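These checks only log warnings and never fail. A hedged sketch of the two mismatched combinations that trigger them (the model names follow the Haystack docs examples, and downloading them from the model hub is assumed to work):

```python
from haystack.document_store.memory import InMemoryDocumentStore
from haystack.retriever.dense import DensePassageRetriever, EmbeddingRetriever

# DPR paired with a cosine store -> logs "We recommend you use dot_product instead."
dpr_store = InMemoryDocumentStore(similarity="cosine")
dpr_retriever = DensePassageRetriever(document_store=dpr_store)

# Sentence Transformer paired with a dot_product store -> logs "We recommend using cosine instead."
st_store = InMemoryDocumentStore(similarity="dot_product")
st_retriever = EmbeddingRetriever(document_store=st_store,
                                  embedding_model="deepset/sentence_bert")
```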
3 changes: 2 additions & 1 deletion tutorials/Tutorial4_FAQ_style_QA.py
@@ -44,7 +44,8 @@
index="document",
embedding_field="question_emb",
embedding_dim=768,
excluded_meta_data=["question_emb"])
excluded_meta_data=["question_emb"],
similarity="cosine")

### Create a Retriever using embeddings
# Instead of retrieving via Elasticsearch's plain BM25, we want to use vector similarity of the questions (user question vs. FAQ ones).