Skip to content

Commit

Permalink
Update LFQA with the latest LFQA seq2seq and retriever models (#2210)
Browse files Browse the repository at this point in the history
* Register BartEli5Converter for vblagoje/bart_lfqa model

* Update LFQA unit tests

* Update LFQA tutorials
  • Loading branch information
vblagoje authored Mar 8, 2022
1 parent 255226f commit 6c0094b
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 59 deletions.
22 changes: 12 additions & 10 deletions docs/_src/tutorials/tutorials/12.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Make sure you enable the GPU runtime to experience decent speed in this tutorial

# Install the latest master of Haystack
!pip install --upgrade pip
!pip install git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]
!pip install -q git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]
```


Expand Down Expand Up @@ -80,22 +80,24 @@ document_store.write_documents(dicts)

#### Retriever

**Here:** We use a `RetribertRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`
We use a `DensePassageRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`




```python
from haystack.nodes import EmbeddingRetriever
from haystack.nodes import DensePassageRetriever

retriever = EmbeddingRetriever(
document_store=document_store, embedding_model="yjernite/retribert-base-uncased", model_format="retribert"
retriever = DensePassageRetriever(
document_store=document_store,
query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
)

document_store.update_embeddings(retriever)
```

Before we blindly use the `RetribertRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents.
Before we blindly use the `DensePassageRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents.


```python
Expand All @@ -111,13 +113,13 @@ print_documents(res, max_text_len=512)

Similar to previous tutorials, we now initialize our reader/generator.

Here we use a `Seq2SeqGenerator` with the *yjernite/bart_eli5* model (see: https://huggingface.co/yjernite/bart_eli5)
Here we use a `Seq2SeqGenerator` with the *vblagoje/bart_lfqa* model (see: https://huggingface.co/vblagoje/bart_lfqa)




```python
generator = Seq2SeqGenerator(model_name_or_path="yjernite/bart_eli5")
generator = Seq2SeqGenerator(model_name_or_path="vblagoje/bart_lfqa")
```

### Pipeline
Expand All @@ -139,13 +141,13 @@ pipe = GenerativeQAPipeline(generator, retriever)

```python
pipe.run(
query="Why did Arya Stark's character get portrayed in a television adaptation?", params={"Retriever": {"top_k": 1}}
query="How did Arya Stark's character get portrayed in a television adaptation?", params={"Retriever": {"top_k": 3}}
)
```


```python
pipe.run(query="What kind of character does Arya Stark play?", params={"Retriever": {"top_k": 1}})
pipe.run(query="Why is Arya Stark an unusual character?", params={"Retriever": {"top_k": 3}})
```

## About us
Expand Down
3 changes: 2 additions & 1 deletion haystack/nodes/answer_generator/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,8 @@ def __init__(
def _register_converters(cls, model_name_or_path: str, custom_converter: Optional[Callable]):
# init if empty
if not cls._model_input_converters:
cls._model_input_converters["yjernite/bart_eli5"] = _BartEli5Converter()
for c in ["yjernite/bart_eli5", "vblagoje/bart_lfqa"]:
cls._model_input_converters[c] = _BartEli5Converter()

# register user provided custom converter
if model_name_or_path and custom_converter:
Expand Down
4 changes: 2 additions & 2 deletions json-schemas/haystack-pipeline-1.2.0.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -1307,12 +1307,12 @@
},
"preceding_context_len": {
"title": "Preceding Context Len",
"default": 3,
"default": 1,
"type": "integer"
},
"following_context_len": {
"title": "Following Context Len",
"default": 3,
"default": 1,
"type": "integer"
},
"remove_page_headers": {
Expand Down
12 changes: 10 additions & 2 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,8 @@ def question_generator():


@pytest.fixture(scope="function")
def eli5_generator():
return Seq2SeqGenerator(model_name_or_path="yjernite/bart_eli5", max_length=20)
def lfqa_generator(request):
return Seq2SeqGenerator(model_name_or_path=request.param, min_length=100, max_length=200)


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -509,6 +509,14 @@ def get_retriever(retriever_type, document_store):
model_format="retribert",
use_gpu=False,
)
elif retriever_type == "dpr_lfqa":
retriever = DensePassageRetriever(
document_store=document_store,
query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
use_gpu=False,
embed_title=True,
)
elif retriever_type == "elasticsearch":
retriever = ElasticsearchRetriever(document_store=document_store)
elif retriever_type == "es_filter_only":
Expand Down
11 changes: 6 additions & 5 deletions test/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,21 +60,22 @@ def test_generator_pipeline(document_store, retriever, rag_generator):
@pytest.mark.slow
@pytest.mark.generator
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["retribert"], indirect=True)
@pytest.mark.parametrize("retriever", ["retribert", "dpr_lfqa"], indirect=True)
@pytest.mark.parametrize("lfqa_generator", ["yjernite/bart_eli5", "vblagoje/bart_lfqa"], indirect=True)
@pytest.mark.embedding_dim(128)
def test_lfqa_pipeline(document_store, retriever, eli5_generator):
def test_lfqa_pipeline(document_store, retriever, lfqa_generator):
# reuse existing DOCS but regenerate embeddings with retribert
docs: List[Document] = []
for idx, d in enumerate(DOCS_WITH_EMBEDDINGS):
docs.append(Document(d.content, str(idx)))
document_store.write_documents(docs)
document_store.update_embeddings(retriever)
query = "Tell me about Berlin?"
pipeline = GenerativeQAPipeline(retriever=retriever, generator=eli5_generator)
pipeline = GenerativeQAPipeline(generator=lfqa_generator, retriever=retriever)
output = pipeline.run(query=query, params={"top_k": 1})
answers = output["answers"]
assert len(answers) == 1
assert "Germany" in answers[0].answer
assert len(answers) == 1, answers
assert "Germany" in answers[0].answer, answers[0].answer


@pytest.mark.slow
Expand Down
50 changes: 22 additions & 28 deletions tutorials/Tutorial12_LFQA.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
"\n",
"# Install the latest master of Haystack\n",
"!pip install --upgrade pip\n",
"!pip install git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]"
"!pip install -q git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]"
]
},
{
Expand Down Expand Up @@ -146,7 +146,7 @@
"\n",
"#### Retriever\n",
"\n",
"**Here:** We use a `RetribertRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`\n",
"We use a `DensePassageRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`\n",
"\n"
]
},
Expand All @@ -161,10 +161,12 @@
},
"outputs": [],
"source": [
"from haystack.nodes import EmbeddingRetriever\n",
"from haystack.nodes import DensePassageRetriever\n",
"\n",
"retriever = EmbeddingRetriever(\n",
" document_store=document_store, embedding_model=\"yjernite/retribert-base-uncased\", model_format=\"retribert\"\n",
"retriever = DensePassageRetriever(\n",
" document_store=document_store,\n",
" query_embedding_model=\"vblagoje/dpr-question_encoder-single-lfqa-wiki\",\n",
" passage_embedding_model=\"vblagoje/dpr-ctx_encoder-single-lfqa-wiki\",\n",
")\n",
"\n",
"document_store.update_embeddings(retriever)"
Expand All @@ -176,25 +178,16 @@
"id": "sMlVEnJ2NkZZ"
},
"source": [
"Before we blindly use the `RetribertRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents."
"Before we blindly use the `DensePassageRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents."
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"id": "qpu-t9rndgpe"
},
"outputs": [
{
"ename": "SyntaxError",
"evalue": "EOL while scanning string literal (<ipython-input-1-cc681f017dc5>, line 7)",
"output_type": "error",
"traceback": [
"\u001b[0;36m File \u001b[0;32m\"<ipython-input-1-cc681f017dc5>\"\u001b[0;36m, line \u001b[0;32m7\u001b[0m\n\u001b[0;31m params={\"top_k_retriever=5\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m EOL while scanning string literal\n"
]
}
],
"outputs": [],
"source": [
"from haystack.utils import print_documents\n",
"from haystack.pipelines import DocumentSearchPipeline\n",
Expand All @@ -214,7 +207,7 @@
"\n",
"Similar to previous tutorials, we now initialize our reader/generator.\n",
"\n",
"Here we use a `Seq2SeqGenerator` with the *yjernite/bart_eli5* model (see: https://huggingface.co/yjernite/bart_eli5)\n",
"Here we use a `Seq2SeqGenerator` with the *vblagoje/bart_lfqa* model (see: https://huggingface.co/vblagoje/bart_lfqa)\n",
"\n"
]
},
Expand All @@ -226,7 +219,7 @@
},
"outputs": [],
"source": [
"generator = Seq2SeqGenerator(model_name_or_path=\"yjernite/bart_eli5\")"
"generator = Seq2SeqGenerator(model_name_or_path=\"vblagoje/bart_lfqa\")"
]
},
{
Expand Down Expand Up @@ -274,25 +267,26 @@
"outputs": [],
"source": [
"pipe.run(\n",
" query=\"Why did Arya Stark's character get portrayed in a television adaptation?\", params={\"Retriever\": {\"top_k\": 1}}\n",
" query=\"How did Arya Stark's character get portrayed in a television adaptation?\", params={\"Retriever\": {\"top_k\": 3}}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"pipe.run(query=\"Why is Arya Stark an unusual character?\", params={\"Retriever\": {\"top_k\": 3}})"
],
"metadata": {
"id": "zvHb8SvMblw9"
"id": "IfTP9BfFGOo6"
},
"outputs": [],
"source": [
"pipe.run(query=\"What kind of character does Arya Stark play?\", params={\"Retriever\": {\"top_k\": 1}})"
]
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
"collapsed": false,
"id": "i88KdOc2wUXQ"
},
"source": [
"## About us\n",
Expand Down Expand Up @@ -340,5 +334,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 0
}
24 changes: 13 additions & 11 deletions tutorials/Tutorial12_LFQA.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,20 @@ def tutorial12_lfqa():

"""
Initalize Retriever and Reader/Generator:
We use a `RetribertRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`
We use a `DensePassageRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`
"""

from haystack.nodes import EmbeddingRetriever
from haystack.nodes import DensePassageRetriever

retriever = EmbeddingRetriever(
document_store=document_store, embedding_model="yjernite/retribert-base-uncased", model_format="retribert"
retriever = DensePassageRetriever(
document_store=document_store,
query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
)

document_store.update_embeddings(retriever)

"""Before we blindly use the `RetribertRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents."""
"""Before we blindly use the `DensePassageRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents."""

from haystack.utils import print_documents
from haystack.pipelines import DocumentSearchPipeline
Expand All @@ -59,10 +61,10 @@ def tutorial12_lfqa():

"""
Similar to previous tutorials, we now initialize our reader/generator.
Here we use a `Seq2SeqGenerator` with the *yjernite/bart_eli5* model (see: https://huggingface.co/yjernite/bart_eli5)
Here we use a `Seq2SeqGenerator` with the *vblagoje/bart_lfqa* model (see: https://huggingface.co/vblagoje/bart_lfqa)
"""

generator = Seq2SeqGenerator(model_name_or_path="yjernite/bart_eli5")
generator = Seq2SeqGenerator(model_name_or_path="vblagoje/bart_lfqa")

"""
Pipeline:
Expand All @@ -78,14 +80,14 @@ def tutorial12_lfqa():

"""Voilà! Ask a question!"""

query_1 = "Why did Arya Stark's character get portrayed in a television adaptation?"
result_1 = pipe.run(query=query_1, params={"Retriever": {"top_k": 1}})
query_1 = "How did Arya Stark's character get portrayed in a television adaptation?"
result_1 = pipe.run(query=query_1, params={"Retriever": {"top_k": 3}})
print(f"Query: {query_1}")
print(f"Answer: {result_1['answers'][0]}")
print()

query_2 = "What kind of character does Arya Stark play?"
result_2 = pipe.run(query=query_2, params={"Retriever": {"top_k": 1}})
query_2 = "Why is Arya Stark an unusual character?"
result_2 = pipe.run(query=query_2, params={"Retriever": {"top_k": 3}})
print(f"Query: {query_2}")
print(f"Answer: {result_2['answers'][0]}")
print()
Expand Down

0 comments on commit 6c0094b

Please sign in to comment.