Skip to content

Commit

Permalink
Add FAISS reset after fixture teardown
Browse files Browse the repository at this point in the history
  • Loading branch information
tanaysoni committed Dec 4, 2020
1 parent 67e6023 commit dadd0a5
Showing 1 changed file with 42 additions and 38 deletions.
80 changes: 42 additions & 38 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,18 +210,6 @@ def no_answer_prediction(no_answer_reader, test_docs_xs):
return prediction


@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql"])
def document_store_with_docs(request, test_docs_xs):
document_store = get_document_store(request.param)
document_store.write_documents(test_docs_xs)
return document_store


@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql"])
def document_store(request, test_docs_xs):
return get_document_store(request.param)


@pytest.fixture(params=["es_filter_only", "elasticsearch", "dpr", "embedding", "tfidf"])
def retriever(request, document_store):
return get_retriever(request.param, document_store)
Expand All @@ -232,6 +220,48 @@ def retriever_with_docs(request, document_store_with_docs):
return get_retriever(request.param, document_store_with_docs)


def get_retriever(retriever_type, document_store):

if retriever_type == "dpr":
retriever = DensePassageRetriever(document_store=document_store,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=False, embed_title=True)
elif retriever_type == "tfidf":
return TfidfRetriever(document_store=document_store)
elif retriever_type == "embedding":
retriever = EmbeddingRetriever(
document_store=document_store,
embedding_model="deepset/sentence_bert",
use_gpu=False
)
elif retriever_type == "elasticsearch":
retriever = ElasticsearchRetriever(document_store=document_store)
elif retriever_type == "es_filter_only":
retriever = ElasticsearchFilterOnlyRetriever(document_store=document_store)
else:
raise Exception(f"No retriever fixture for '{retriever_type}'")

return retriever


@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql"])
def document_store_with_docs(request, test_docs_xs):
document_store = get_document_store(request.param)
document_store.write_documents(test_docs_xs)
yield document_store
if request.param == "faiss":
document_store.faiss_index.reset()


@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql"])
def document_store(request, test_docs_xs):
document_store = get_document_store(request.param)
yield document_store
if request.param == "faiss":
document_store.faiss_index.reset()


def get_document_store(document_store_type):
if document_store_type == "sql":
if os.path.exists("haystack_test.db"):
Expand All @@ -251,34 +281,8 @@ def get_document_store(document_store_type):
sql_url="sqlite:///haystack_test_faiss.db",
return_embedding=False
)
document_store.faiss_index.reset()
return document_store
else:
raise Exception(f"No document store fixture for '{document_store_type}'")

return document_store


def get_retriever(retriever_type, document_store):

if retriever_type == "dpr":
retriever = DensePassageRetriever(document_store=document_store,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=False, embed_title=True)
elif retriever_type == "tfidf":
return TfidfRetriever(document_store=document_store)
elif retriever_type == "embedding":
retriever = EmbeddingRetriever(
document_store=document_store,
embedding_model="deepset/sentence_bert",
use_gpu=False
)
elif retriever_type == "elasticsearch":
retriever = ElasticsearchRetriever(document_store=document_store)
elif retriever_type == "es_filter_only":
retriever = ElasticsearchFilterOnlyRetriever(document_store=document_store)
else:
raise Exception(f"No retriever fixture for '{retriever_type}'")

return retriever

0 comments on commit dadd0a5

Please sign in to comment.