Update LFQA with the latest LFQA seq2seq and retriever models #2210

Merged: 5 commits, Mar 8, 2022
22 changes: 12 additions & 10 deletions docs/_src/tutorials/tutorials/12.md
@@ -32,7 +32,7 @@ Make sure you enable the GPU runtime to experience decent speed in this tutorial

# Install the latest master of Haystack
!pip install --upgrade pip
!pip install git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]
!pip install -q git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]
```


@@ -80,22 +80,24 @@ document_store.write_documents(dicts)

#### Retriever

**Here:** We use a `RetribertRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`
We use a `DensePassageRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`




```python
from haystack.nodes import EmbeddingRetriever
from haystack.nodes import DensePassageRetriever

retriever = EmbeddingRetriever(
document_store=document_store, embedding_model="yjernite/retribert-base-uncased", model_format="retribert"
retriever = DensePassageRetriever(
document_store=document_store,
query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
)

document_store.update_embeddings(retriever)
```

Before we blindly use the `RetribertRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents.
Before we blindly use the `DensePassageRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents.


```python
Expand All @@ -111,13 +113,13 @@ print_documents(res, max_text_len=512)

Similar to previous tutorials, we now initialize our reader/generator.

Here we use a `Seq2SeqGenerator` with the *yjernite/bart_eli5* model (see: https://huggingface.co/yjernite/bart_eli5)
Here we use a `Seq2SeqGenerator` with the *vblagoje/bart_lfqa* model (see: https://huggingface.co/vblagoje/bart_lfqa)




```python
generator = Seq2SeqGenerator(model_name_or_path="yjernite/bart_eli5")
generator = Seq2SeqGenerator(model_name_or_path="vblagoje/bart_lfqa")
```

### Pipeline
Expand All @@ -139,13 +141,13 @@ pipe = GenerativeQAPipeline(generator, retriever)

```python
pipe.run(
query="Why did Arya Stark's character get portrayed in a television adaptation?", params={"Retriever": {"top_k": 1}}
query="How did Arya Stark's character get portrayed in a television adaptation?", params={"Retriever": {"top_k": 3}}
)
```


```python
pipe.run(query="What kind of character does Arya Stark play?", params={"Retriever": {"top_k": 1}})
pipe.run(query="Why is Arya Stark an unusual character?", params={"Retriever": {"top_k": 3}})
```

## About us
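The hunk above collapses the retriever sanity check that sits between the `DensePassageRetriever` setup and the generator. As a rough sketch of that step (not taken verbatim from this PR: the `p_retrieval` variable name, the example query, and `top_k=10` are illustrative assumptions; the `DocumentSearchPipeline` and `print_documents` imports appear in the hunks further down):

```python
from haystack.pipelines import DocumentSearchPipeline
from haystack.utils import print_documents

# Wrap the retriever in a plain document-search pipeline and inspect what it
# returns before wiring it into the generative QA pipeline.
p_retrieval = DocumentSearchPipeline(retriever)
res = p_retrieval.run(
    query="Tell me something about Arya Stark?",
    params={"Retriever": {"top_k": 10}},
)
print_documents(res, max_text_len=512)
```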
3 changes: 2 additions & 1 deletion haystack/nodes/answer_generator/transformers.py
@@ -386,7 +386,8 @@ def __init__(
def _register_converters(cls, model_name_or_path: str, custom_converter: Optional[Callable]):
# init if empty
if not cls._model_input_converters:
cls._model_input_converters["yjernite/bart_eli5"] = _BartEli5Converter()
for c in ["yjernite/bart_eli5", "vblagoje/bart_lfqa"]:
cls._model_input_converters[c] = _BartEli5Converter()

# register user provided custom converter
if model_name_or_path and custom_converter:
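The change above maps both *yjernite/bart_eli5* and *vblagoje/bart_lfqa* to the same `_BartEli5Converter`, so both models consume the same ELI5-style input built from the query and the retrieved documents. For readers wondering what such a converter does, here is a hypothetical sketch; the class name, call signature, and prompt template are illustrative assumptions, not taken from this PR or from `_BartEli5Converter` itself:

```python
from typing import List, Optional

from haystack.schema import Document


class IllustrativeLfqaConverter:
    """Toy converter: collapse a query and its retrieved documents into one
    seq2seq prompt and tokenize it. Haystack's real converter may differ."""

    def __call__(self, tokenizer, query: str, documents: List[Document], top_k: Optional[int] = None):
        context = " ".join(doc.content for doc in documents)
        prompt = f"question: {query} context: {context}"
        # Return the tensors the seq2seq model's generate() call expects.
        return tokenizer([prompt], truncation=True, padding=True, return_tensors="pt")
```

The registration pattern in the hunk (`cls._model_input_converters[model_name] = converter`) suggests a user-supplied `custom_converter` is stored the same way for models that need a different input format.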
4 changes: 2 additions & 2 deletions json-schemas/haystack-pipeline-1.2.0.schema.json
@@ -1307,12 +1307,12 @@
},
"preceding_context_len": {
"title": "Preceding Context Len",
"default": 3,
"default": 1,
"type": "integer"
},
"following_context_len": {
"title": "Following Context Len",
"default": 3,
"default": 1,
"type": "integer"
},
"remove_page_headers": {
12 changes: 10 additions & 2 deletions test/conftest.py
@@ -304,8 +304,8 @@ def question_generator():


@pytest.fixture(scope="function")
def eli5_generator():
return Seq2SeqGenerator(model_name_or_path="yjernite/bart_eli5", max_length=20)
def lfqa_generator(request):
return Seq2SeqGenerator(model_name_or_path=request.param, min_length=100, max_length=200)


@pytest.fixture(scope="function")
@@ -509,6 +509,14 @@ def get_retriever(retriever_type, document_store):
model_format="retribert",
use_gpu=False,
)
elif retriever_type == "dpr_lfqa":
retriever = DensePassageRetriever(
document_store=document_store,
query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
use_gpu=False,
embed_title=True,
)
elif retriever_type == "elasticsearch":
retriever = ElasticsearchRetriever(document_store=document_store)
elif retriever_type == "es_filter_only":
11 changes: 6 additions & 5 deletions test/test_generator.py
@@ -60,21 +60,22 @@ def test_generator_pipeline(document_store, retriever, rag_generator):
@pytest.mark.slow
@pytest.mark.generator
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["retribert"], indirect=True)
@pytest.mark.parametrize("retriever", ["retribert", "dpr_lfqa"], indirect=True)
@pytest.mark.parametrize("lfqa_generator", ["yjernite/bart_eli5", "vblagoje/bart_lfqa"], indirect=True)
@pytest.mark.embedding_dim(128)
def test_lfqa_pipeline(document_store, retriever, eli5_generator):
def test_lfqa_pipeline(document_store, retriever, lfqa_generator):
# reuse existing DOCS but regenerate embeddings with the parametrized retriever
docs: List[Document] = []
for idx, d in enumerate(DOCS_WITH_EMBEDDINGS):
docs.append(Document(d.content, str(idx)))
document_store.write_documents(docs)
document_store.update_embeddings(retriever)
query = "Tell me about Berlin?"
pipeline = GenerativeQAPipeline(retriever=retriever, generator=eli5_generator)
pipeline = GenerativeQAPipeline(generator=lfqa_generator, retriever=retriever)
output = pipeline.run(query=query, params={"top_k": 1})
answers = output["answers"]
assert len(answers) == 1
assert "Germany" in answers[0].answer
assert len(answers) == 1, answers
assert "Germany" in answers[0].answer, answers[0].answer


@pytest.mark.slow
50 changes: 22 additions & 28 deletions tutorials/Tutorial12_LFQA.ipynb
@@ -51,7 +51,7 @@
"\n",
"# Install the latest master of Haystack\n",
"!pip install --upgrade pip\n",
"!pip install git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]"
"!pip install -q git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]"
]
},
{
@@ -146,7 +146,7 @@
"\n",
"#### Retriever\n",
"\n",
"**Here:** We use a `RetribertRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`\n",
"We use a `DensePassageRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`\n",
"\n"
]
},
Expand All @@ -161,10 +161,12 @@
},
"outputs": [],
"source": [
"from haystack.nodes import EmbeddingRetriever\n",
"from haystack.nodes import DensePassageRetriever\n",
"\n",
"retriever = EmbeddingRetriever(\n",
" document_store=document_store, embedding_model=\"yjernite/retribert-base-uncased\", model_format=\"retribert\"\n",
"retriever = DensePassageRetriever(\n",
" document_store=document_store,\n",
" query_embedding_model=\"vblagoje/dpr-question_encoder-single-lfqa-wiki\",\n",
" passage_embedding_model=\"vblagoje/dpr-ctx_encoder-single-lfqa-wiki\",\n",
")\n",
"\n",
"document_store.update_embeddings(retriever)"
Expand All @@ -176,25 +178,16 @@
"id": "sMlVEnJ2NkZZ"
},
"source": [
"Before we blindly use the `RetribertRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents."
"Before we blindly use the `DensePassageRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents."
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"id": "qpu-t9rndgpe"
},
"outputs": [
{
"ename": "SyntaxError",
"evalue": "EOL while scanning string literal (<ipython-input-1-cc681f017dc5>, line 7)",
"output_type": "error",
"traceback": [
"\u001b[0;36m File \u001b[0;32m\"<ipython-input-1-cc681f017dc5>\"\u001b[0;36m, line \u001b[0;32m7\u001b[0m\n\u001b[0;31m params={\"top_k_retriever=5\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m EOL while scanning string literal\n"
]
}
],
"outputs": [],
"source": [
"from haystack.utils import print_documents\n",
"from haystack.pipelines import DocumentSearchPipeline\n",
Expand All @@ -214,7 +207,7 @@
"\n",
"Similar to previous Tutorials we now initalize our reader/generator.\n",
"\n",
"Here we use a `Seq2SeqGenerator` with the *yjernite/bart_eli5* model (see: https://huggingface.co/yjernite/bart_eli5)\n",
"Here we use a `Seq2SeqGenerator` with the *vblagoje/bart_lfqa* model (see: https://huggingface.co/vblagoje/bart_lfqa)\n",
"\n"
]
},
Expand All @@ -226,7 +219,7 @@
},
"outputs": [],
"source": [
"generator = Seq2SeqGenerator(model_name_or_path=\"yjernite/bart_eli5\")"
"generator = Seq2SeqGenerator(model_name_or_path=\"vblagoje/bart_lfqa\")"
]
},
{
@@ -274,25 +267,26 @@
"outputs": [],
"source": [
"pipe.run(\n",
" query=\"Why did Arya Stark's character get portrayed in a television adaptation?\", params={\"Retriever\": {\"top_k\": 1}}\n",
" query=\"How did Arya Stark's character get portrayed in a television adaptation?\", params={\"Retriever\": {\"top_k\": 3}}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"pipe.run(query=\"Why is Arya Stark an unusual character?\", params={\"Retriever\": {\"top_k\": 3}})"
],
"metadata": {
"id": "zvHb8SvMblw9"
"id": "IfTP9BfFGOo6"
},
"outputs": [],
"source": [
"pipe.run(query=\"What kind of character does Arya Stark play?\", params={\"Retriever\": {\"top_k\": 1}})"
]
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
"collapsed": false,
"id": "i88KdOc2wUXQ"
},
"source": [
"## About us\n",
@@ -340,5 +334,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 0
}
24 changes: 13 additions & 11 deletions tutorials/Tutorial12_LFQA.py
@@ -37,18 +37,20 @@ def tutorial12_lfqa():

"""
Initialize Retriever and Reader/Generator:
We use a `RetribertRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`
We use a `DensePassageRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`
"""

from haystack.nodes import EmbeddingRetriever
from haystack.nodes import DensePassageRetriever

retriever = EmbeddingRetriever(
document_store=document_store, embedding_model="yjernite/retribert-base-uncased", model_format="retribert"
retriever = DensePassageRetriever(
document_store=document_store,
query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
)

document_store.update_embeddings(retriever)

"""Before we blindly use the `RetribertRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents."""
"""Before we blindly use the `DensePassageRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents."""

from haystack.utils import print_documents
from haystack.pipelines import DocumentSearchPipeline
Expand All @@ -59,10 +61,10 @@ def tutorial12_lfqa():

"""
Similar to previous tutorials, we now initialize our reader/generator.
Here we use a `Seq2SeqGenerator` with the *yjernite/bart_eli5* model (see: https://huggingface.co/yjernite/bart_eli5)
Here we use a `Seq2SeqGenerator` with the *vblagoje/bart_lfqa* model (see: https://huggingface.co/vblagoje/bart_lfqa)
"""

generator = Seq2SeqGenerator(model_name_or_path="yjernite/bart_eli5")
generator = Seq2SeqGenerator(model_name_or_path="vblagoje/bart_lfqa")

"""
Pipeline:
Expand All @@ -78,14 +80,14 @@ def tutorial12_lfqa():

"""Voilà! Ask a question!"""

query_1 = "Why did Arya Stark's character get portrayed in a television adaptation?"
result_1 = pipe.run(query=query_1, params={"Retriever": {"top_k": 1}})
query_1 = "How did Arya Stark's character get portrayed in a television adaptation?"
result_1 = pipe.run(query=query_1, params={"Retriever": {"top_k": 3}})
print(f"Query: {query_1}")
print(f"Answer: {result_1['answers'][0]}")
print()

query_2 = "What kind of character does Arya Stark play?"
result_2 = pipe.run(query=query_2, params={"Retriever": {"top_k": 1}})
query_2 = "Why is Arya Stark an unusual character?"
result_2 = pipe.run(query=query_2, params={"Retriever": {"top_k": 3}})
print(f"Query: {query_2}")
print(f"Answer: {result_2['answers'][0]}")
print()
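The script prints the whole `result['answers'][0]` object; a minimal sketch of pulling out just the generated text, assuming the entries in `result["answers"]` are `Answer` objects exposing an `.answer` attribute as in the updated test above:

```python
result = pipe.run(
    query="Why is Arya Stark an unusual character?", params={"Retriever": {"top_k": 3}}
)

# Each entry in result["answers"] is an Answer object; .answer holds the generated text.
for idx, ans in enumerate(result["answers"]):
    print(f"Answer {idx + 1}: {ans.answer}")
```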