Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix return type of EmbeddingRetriever to numpy array #245

Merged
merged 1 commit into from
Jul 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions haystack/database/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List, Optional, Union, Dict, Any
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk, scan
import numpy as np

from haystack.database.base import BaseDocumentStore, Document

Expand Down Expand Up @@ -236,7 +237,7 @@ def query(
return documents

def query_by_embedding(self,
query_emb: List[float],
query_emb: np.array,
filters: Optional[dict] = None,
top_k: int = 10,
index: Optional[str] = None) -> List[Document]:
Expand All @@ -255,7 +256,7 @@ def query_by_embedding(self,
"script": {
"source": f"cosineSimilarity(params.query_vector,doc['{self.embedding_field}']) + 1.0",
"params": {
"query_vector": query_emb
"query_vector": query_emb.tolist()
}
}
}
Expand Down
11 changes: 6 additions & 5 deletions haystack/retriever/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str
top_k=top_k, index=index)
return documents

def embed(self, texts: Union[List[str], str]) -> List[List[float]]:
def embed(self, texts: Union[List[str], str]) -> List[np.array]:
"""
Create embeddings for each text in a list of texts using the retrievers model (`self.embedding_model`)
:param texts: texts to embed
Expand All @@ -259,13 +259,14 @@ def embed(self, texts: Union[List[str], str]) -> List[List[float]]:
assert type(texts) == list, "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"

if self.model_format == "farm" or self.model_format == "transformers":
res = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts]) # type: ignore
emb = [list(r["vec"]) for r in res] #cast from numpy
emb = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts]) # type: ignore
emb = [(r["vec"]) for r in emb]
elif self.model_format == "sentence_transformers":
# text is single string, sentence-transformers needs a list of strings
# get back list of numpy embedding vectors
res = self.embedding_model.encode(texts) # type: ignore
emb = [list(r.astype('float64')) for r in res] #cast from numpy
emb = self.embedding_model.encode(texts) # type: ignore
# cast to float64 as float32 can cause trouble when serializing for ES
emb = [(r.astype('float64')) for r in emb]
return emb

def embed_queries(self, texts: List[str]) -> List[np.array]:
Expand Down