Skip to content

Commit

Permalink
Use ElasticsearchDocumentStore.get_all_documents in `ElasticsearchF…
Browse files Browse the repository at this point in the history
…ilterOnlyRetriever.retrieve` (#2151)

* use get_all_documents in ElasticsearchFilterOnlyRetriever.retrieve

* Update Documentation & Code Style

* add test case for es_filter_only retriever

* Update Documentation & Code Style

* fix test by adding empty string for query

* Update Documentation & Code Style

* add explicit name of argument "query"

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Julian Risch <[email protected]>
  • Loading branch information
3 people authored Apr 25, 2022
1 parent 25475a6 commit c401e86
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 9 deletions.
4 changes: 2 additions & 2 deletions docs/_src/api/api/retriever.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,9 @@ that are most relevant to the query.

**Arguments**:

- `query`: The query
- `query`: Has no effect, can pass in empty string
- `filters`: A dictionary where the keys specify a metadata field and the value is a list of accepted values for that field
- `top_k`: How many documents to return per query.
- `top_k`: Has no effect, pass in any int or None
- `index`: The name of the index in the DocumentStore from which to retrieve documents
- `headers`: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information.
Expand Down
10 changes: 3 additions & 7 deletions haystack/nodes/retriever/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,20 +157,16 @@ def retrieve(
Scan through documents in DocumentStore and return a small number documents
that are most relevant to the query.
:param query: The query
:param query: Has no effect, can pass in empty string
:param filters: A dictionary where the keys specify a metadata field and the value is a list of accepted values for that field
:param top_k: How many documents to return per query.
:param top_k: Has no effect, pass in any int or None
:param index: The name of the index in the DocumentStore from which to retrieve documents
:param headers: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information.
"""
if top_k is None:
top_k = self.top_k
if index is None:
index = self.document_store.index
documents = self.document_store.query(
query=None, filters=filters, top_k=top_k, custom_query=self.custom_query, index=index, headers=headers
)
documents = self.document_store.get_all_documents(filters=filters, index=index, headers=headers)
return documents


Expand Down
22 changes: 22 additions & 0 deletions test/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,3 +563,25 @@ def test_embeddings_encoder_of_embedding_retriever_should_warn_about_model_forma
"You may need to set 'model_format='sentence_transformers' to ensure correct loading of model."
in caplog.text
)


@pytest.mark.parametrize("retriever", ["es_filter_only"], indirect=True)
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
def test_es_filter_only(document_store, retriever):
docs = [
Document(content="Doc1", meta={"f1": "0"}),
Document(content="Doc2", meta={"f1": "0"}),
Document(content="Doc3", meta={"f1": "0"}),
Document(content="Doc4", meta={"f1": "0"}),
Document(content="Doc5", meta={"f1": "0"}),
Document(content="Doc6", meta={"f1": "0"}),
Document(content="Doc7", meta={"f1": "1"}),
Document(content="Doc8", meta={"f1": "0"}),
Document(content="Doc9", meta={"f1": "0"}),
Document(content="Doc10", meta={"f1": "0"}),
Document(content="Doc11", meta={"f1": "0"}),
Document(content="Doc12", meta={"f1": "0"}),
]
document_store.write_documents(docs)
retrieved_docs = retriever.retrieve(query="", filters={"f1": ["0"]})
assert len(retrieved_docs) == 11

0 comments on commit c401e86

Please sign in to comment.