Make Elasticsearch configuration more flexible (#29)
Showing 5 changed files with 137 additions and 94 deletions.
The `ElasticsearchDocumentStore` module:

@@ -1,85 +1,124 @@
 from elasticsearch import Elasticsearch
-from elasticsearch_dsl import Search, Document as ESDoc, Text, connections
-from haystack.database.base import BaseDocumentStore
-
-
-class Document(ESDoc):
-    name = Text()
-    text = Text()
-    tags = Text()
-
-    class Index:
-        name = "document"
+from elasticsearch.helpers import scan
+
+from haystack.database.base import BaseDocumentStore
 
 
 class ElasticsearchDocumentStore(BaseDocumentStore):
-    def __init__(self, host="localhost", username="", password="", index="document"):
+    def __init__(
+        self,
+        host="localhost",
+        username="",
+        password="",
+        index="document",
+        search_fields="text",
+        text_field="text",
+        name_field="name",
+        doc_id_field="document_id",
+        tag_fields=None,
+        custom_mapping=None,
+    ):
         self.client = Elasticsearch(hosts=[{"host": host}], http_auth=(username, password))
-        self.connections = connections.create_connection(hosts=[{"host": host}], http_auth=(username, password))
-        Document.init()  # create mapping if not exists.
+
+        # if no custom_mapping is supplied, use the default mapping
+        if not custom_mapping:
+            custom_mapping = {
+                "mappings": {
+                    "properties": {
+                        name_field: {"type": "text"},
+                        text_field: {"type": "text"},
+                        doc_id_field: {"type": "text"},
+                    }
+                }
+            }
+        # create an index if not exists
+        self.client.indices.create(index=index, ignore=400, body=custom_mapping)
         self.index = index
 
+        # configure mappings to ES fields that will be used for querying / displaying results
+        if type(search_fields) == str:
+            search_fields = [search_fields]
+        self.search_fields = search_fields
+        self.text_field = text_field
+        self.name_field = name_field
+        self.tag_fields = tag_fields
+        self.doc_id_field = doc_id_field
 
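For context, a minimal usage sketch of the more flexible constructor. The import path and all index/field names below are illustrative assumptions, not taken from this commit:

```python
# Sketch only: module path and field names are assumed, not from the commit.
from haystack.database.elasticsearch import ElasticsearchDocumentStore

document_store = ElasticsearchDocumentStore(
    host="localhost",
    index="wiki",                     # hypothetical index name
    search_fields=["title", "text"],  # full-text queries will hit both fields
    text_field="text",                # field returned as paragraph text
    name_field="title",
    doc_id_field="doc_id",
)
```

With `custom_mapping=None`, the constructor creates the index with a default text mapping for the name, text and doc-id fields; `ignore=400` means an already-existing index is silently reused rather than raising an error.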
     def get_document_by_id(self, id):
         query = {"filter": {"term": {"_id": id}}}
         result = self.client.search(index=self.index, body=query)["hits"]["hits"]
         if result:
-            document = {"id": result["_id"], "name": result["name"], "text": result["text"]}
+            document = {
+                "id": result[self.doc_id_field],
+                "name": result[self.name_field],
+                "text": result[self.text_field],
+            }
         else:
             document = None
         return document
 
-    def get_document_ids_by_tags(self, tags):
-        query = {
-            "query": {
-                "bool": {
-                    "should": [
-                        {
-                            "terms": {
-                                "tags": tags
-                            }
-                        }
-                    ]
-                }
-            }
-        }
+    def get_document_by_name(self, name):
+        query = {"filter": {"term": {self.name_field: name}}}
         result = self.client.search(index=self.index, body=query)["hits"]["hits"]
-        documents = []
-        for hit in result:
-            documents.append({"id": hit["_id"], "name": hit["name"], "text": hit["text"]})
-        return documents
+        if result:
+            document = {
+                "id": result[self.doc_id_field],
+                "name": result[self.name_field],
+                "text": result[self.text_field],
+            }
+        else:
+            document = None
+        return document
+
+    def get_document_ids_by_tags(self, tags):
+        term_queries = [{"terms": {key: value}} for key, value in tags.items()]
+        query = {"query": {"bool": {"must": term_queries}}}
+        result = self.client.search(index=self.index, body=query, size=10000)["hits"]["hits"]
+        doc_ids = []
+        for hit in result:
+            doc_ids.append(hit["_id"])
+        return doc_ids
 
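To make the reworked tag filtering concrete: `tags` is now expected to be a dict mapping a tag field to a list of acceptable values, and each entry becomes one `terms` clause that a matching document must satisfy. A small sketch with assumed tag names:

```python
# Hypothetical tag fields and values, for illustration only.
tags = {"category": ["news", "blog"], "year": ["2020"]}

term_queries = [{"terms": {key: value}} for key, value in tags.items()]
# term_queries == [{"terms": {"category": ["news", "blog"]}},
#                  {"terms": {"year": ["2020"]}}]
query = {"query": {"bool": {"must": term_queries}}}
```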
     def write_documents(self, documents):
-        for doc in documents:
-            d = Document(
-                name=doc["name"],
-                text=doc["text"],
-                document_id=doc.get("document_id", None),
-                tags=doc.get("tags", None),
-            )
-            d.save()
+        for d in documents:
+            self.client.index(index=self.index, body=d)
 
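Documents are now indexed as plain dicts instead of `elasticsearch_dsl` model objects, so the dict keys should line up with the configured field names. A sketch with assumed content:

```python
# Assumed example documents; keys must match text_field / name_field / doc_id_field.
docs = [
    {"name": "About Berlin", "text": "Berlin is the capital of Germany.", "document_id": "1"},
    {"name": "About Paris", "text": "Paris is the capital of France.", "document_id": "2"},
]
document_store.write_documents(docs)  # each dict is indexed as one ES document
```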
     def get_document_count(self):
-        s = Search(using=self.client, index=self.index)
-        return s.count()
+        result = self.client.count()
+        count = result["count"]
+        return count
 
     def get_all_documents(self):
-        search = Search(using=self.client, index=self.index).scan()
+        result = scan(self.client, query={"query": {"match_all": {}}}, index=self.index)
         documents = []
-        for hit in search:
+        for hit in result:
             documents.append(
                 {
-                    "id": hit.meta["id"],
-                    "name": hit["name"],
-                    "text": hit["text"],
+                    "id": hit["_source"][self.doc_id_field],
+                    "name": hit["_source"][self.name_field],
+                    "text": hit["_source"][self.text_field],
                 }
             )
         return documents
 
-    def query(self, query, top_k=10):
-        search = Search(using=self.client, index=self.index).query("match", text=query)[:top_k].execute()
+    def query(self, query, top_k=10, candidate_doc_ids=None):
+        # TODO:
+        # for now: we keep the current structure of candidate_doc_ids for compatibility with SQL documentstores
+        # midterm: get rid of it and do filtering with tags directly in this query
+
+        body = {
+            "size": top_k,
+            "query": {
+                "bool": {
+                    "must": [{"multi_match": {"query": query, "type": "most_fields", "fields": self.search_fields}}]
+                }
+            },
+        }
+        if candidate_doc_ids:
+            body["query"]["bool"]["filter"] = [{"terms": {"_id": candidate_doc_ids}}]
+        result = self.client.search(index=self.index, body=body)["hits"]["hits"]
         paragraphs = []
         meta_data = []
-        for hit in search:
-            paragraphs.append(hit["text"])
-            meta_data.append({"paragraph_id": hit.meta["id"], "document_id": hit["document_id"]})
+        for hit in result:
+            paragraphs.append(hit["_source"][self.text_field])
+            meta_data.append({"paragraph_id": hit["_id"], "document_id": hit["_source"][self.doc_id_field]})
         return paragraphs, meta_data
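For reference, this is the request body the new `query()` builds for an assumed call like `query("capital of Germany", top_k=3, candidate_doc_ids=["1", "7"])` with the default `search_fields=["text"]`:

```python
body = {
    "size": 3,
    "query": {
        "bool": {
            "must": [
                {
                    "multi_match": {
                        "query": "capital of Germany",
                        "type": "most_fields",
                        "fields": ["text"],  # self.search_fields
                    }
                }
            ],
            # added only because candidate_doc_ids was passed
            "filter": [{"terms": {"_id": ["1", "7"]}}],
        }
    },
}
```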
New file, a `Finder` class:

@@ -0,0 +1,43 @@
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class Finder:
+    """
+    Finder ties together instances of the Reader and Retriever class.
+    It provides an interface to predict top n answers for a given question.
+    """
+
+    def __init__(self, reader, retriever):
+        self.retriever = retriever
+        self.reader = reader
+
+    def get_answers(self, question, top_k_reader=1, top_k_retriever=10, filters=None):
+        """
+        Get top k answers for a given question.
+
+        :param question: the question string
+        :param top_k_reader: number of answers returned by the reader
+        :param top_k_retriever: number of text units to be retrieved
+        :param filters: limit scope to documents having the given tags and their corresponding values.
+            The format of the dict is {"tag-1": "value-1", "tag-2": "value-2" ...}
+        :return:
+        """
+
+        # 1) Optional: reduce the search space via document tags
+        if filters:
+            candidate_doc_ids = self.retriever.document_store.get_document_ids_by_tags(filters)
+        else:
+            candidate_doc_ids = None
+
+        # 2) Apply retriever to get fast candidate paragraphs
+        paragraphs, meta_data = self.retriever.retrieve(question, top_k=top_k_retriever, candidate_doc_ids=candidate_doc_ids)
+
+        # 3) Apply reader to get granular answer(s)
+        logger.info("Applying the reader now to look for the answer in detail ...")
+        results = self.reader.predict(question=question,
+                                      paragraphs=paragraphs,
+                                      meta_data_paragraphs=meta_data,
+                                      top_k=top_k_reader)
+
+        return results
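A sketch of how the new `Finder` is wired together. `MyRetriever` and `MyReader` are hypothetical stand-ins for the retriever and reader implementations; this commit only assumes they expose `retrieve()`, `predict()` and a `document_store` attribute:

```python
# Hypothetical wiring; substitute the repo's real retriever and reader classes.
document_store = ElasticsearchDocumentStore(host="localhost", index="document")
retriever = MyRetriever(document_store=document_store)  # must expose .retrieve() and .document_store
reader = MyReader()                                     # must expose .predict()

finder = Finder(reader=reader, retriever=retriever)
results = finder.get_answers(
    question="What is the capital of Germany?",
    top_k_retriever=10,             # paragraphs fetched by the retriever
    top_k_reader=3,                 # answers returned by the reader
    filters={"category": ["news"]}, # narrows candidates via get_document_ids_by_tags
)
```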
The requirements file, dropping the now-unused `elasticsearch_dsl` dependency:

@@ -6,5 +6,4 @@ flask_sqlalchemy
 pandas
 psycopg2-binary
 sklearn
 elasticsearch
-elasticsearch_dsl