Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tanaysoni committed Dec 4, 2020
1 parent dadd0a5 commit 6880d9c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
4 changes: 2 additions & 2 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def get_document_store(document_store_type):
os.remove("haystack_test.db")
document_store = SQLDocumentStore(url="sqlite:///haystack_test.db")
elif document_store_type == "memory":
document_store = InMemoryDocumentStore(return_embedding=False)
document_store = InMemoryDocumentStore(return_embedding=True)
elif document_store_type == "elasticsearch":
# make sure we start from a fresh index
client = Elasticsearch()
Expand All @@ -279,7 +279,7 @@ def get_document_store(document_store_type):
os.remove("haystack_test_faiss.db")
document_store = FAISSDocumentStore(
sql_url="sqlite:///haystack_test_faiss.db",
return_embedding=False
return_embedding=True
)
return document_store
else:
Expand Down
5 changes: 3 additions & 2 deletions test/test_faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,15 @@ def test_faiss_finding(document_store, embedding_retriever):
assert len(prediction.get('answers', [])) == 1


def test_faiss_pipeline(faiss_document_store, embedding_retriever):
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
def test_faiss_pipeline(document_store, embedding_retriever):
documents = [
{"name": "name_1", "text": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
{"name": "name_2", "text": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
{"name": "name_3", "text": "text_3", "embedding": np.random.rand(768).astype(np.float64)},
{"name": "name_4", "text": "text_4", "embedding": np.random.rand(768).astype(np.float32)},
]
faiss_document_store.write_documents(documents)
document_store.write_documents(documents)
pipeline = Pipeline()
pipeline.add_node(component=embedding_retriever, name="FAISS", inputs=["Query"])
output = pipeline.run(query="How to test this?", top_k_retriever=3)
Expand Down

0 comments on commit 6880d9c

Please sign in to comment.