diff --git a/docs/_src/tutorials/tutorials/12.md b/docs/_src/tutorials/tutorials/12.md index 605c8cfe7e..5c8ace1e25 100644 --- a/docs/_src/tutorials/tutorials/12.md +++ b/docs/_src/tutorials/tutorials/12.md @@ -32,7 +32,7 @@ Make sure you enable the GPU runtime to experience decent speed in this tutorial # Install the latest master of Haystack !pip install --upgrade pip -!pip install git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss] +!pip install -q git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss] ``` @@ -80,22 +80,24 @@ document_store.write_documents(dicts) #### Retriever -**Here:** We use a `RetribertRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore` +We use a `DensePassageRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore` ```python -from haystack.nodes import EmbeddingRetriever +from haystack.nodes import DensePassageRetriever -retriever = EmbeddingRetriever( - document_store=document_store, embedding_model="yjernite/retribert-base-uncased", model_format="retribert" +retriever = DensePassageRetriever( + document_store=document_store, + query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki", + passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki", ) document_store.update_embeddings(retriever) ``` -Before we blindly use the `RetribertRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents. +Before we blindly use the `DensePassageRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents. ```python @@ -111,13 +113,13 @@ print_documents(res, max_text_len=512) Similar to previous Tutorials we now initalize our reader/generator. 
-Here we use a `Seq2SeqGenerator` with the *yjernite/bart_eli5* model (see: https://huggingface.co/yjernite/bart_eli5) +Here we use a `Seq2SeqGenerator` with the *vblagoje/bart_lfqa* model (see: https://huggingface.co/vblagoje/bart_lfqa) ```python -generator = Seq2SeqGenerator(model_name_or_path="yjernite/bart_eli5") +generator = Seq2SeqGenerator(model_name_or_path="vblagoje/bart_lfqa") ``` ### Pipeline @@ -139,13 +141,13 @@ pipe = GenerativeQAPipeline(generator, retriever) ```python pipe.run( - query="Why did Arya Stark's character get portrayed in a television adaptation?", params={"Retriever": {"top_k": 1}} + query="How did Arya Stark's character get portrayed in a television adaptation?", params={"Retriever": {"top_k": 3}} ) ``` ```python -pipe.run(query="What kind of character does Arya Stark play?", params={"Retriever": {"top_k": 1}}) +pipe.run(query="Why is Arya Stark an unusual character?", params={"Retriever": {"top_k": 3}}) ``` ## About us diff --git a/haystack/nodes/answer_generator/transformers.py b/haystack/nodes/answer_generator/transformers.py index b2823a7fbb..874897f336 100644 --- a/haystack/nodes/answer_generator/transformers.py +++ b/haystack/nodes/answer_generator/transformers.py @@ -386,7 +386,8 @@ def __init__( def _register_converters(cls, model_name_or_path: str, custom_converter: Optional[Callable]): # init if empty if not cls._model_input_converters: - cls._model_input_converters["yjernite/bart_eli5"] = _BartEli5Converter() + for c in ["yjernite/bart_eli5", "vblagoje/bart_lfqa"]: + cls._model_input_converters[c] = _BartEli5Converter() # register user provided custom converter if model_name_or_path and custom_converter: diff --git a/json-schemas/haystack-pipeline-1.2.0.schema.json b/json-schemas/haystack-pipeline-1.2.0.schema.json index d425ee0986..eb36978cd7 100644 --- a/json-schemas/haystack-pipeline-1.2.0.schema.json +++ b/json-schemas/haystack-pipeline-1.2.0.schema.json @@ -1307,12 +1307,12 @@ }, "preceding_context_len": { "title": 
"Preceding Context Len", - "default": 3, + "default": 1, "type": "integer" }, "following_context_len": { "title": "Following Context Len", - "default": 3, + "default": 1, "type": "integer" }, "remove_page_headers": { diff --git a/test/conftest.py b/test/conftest.py index 26f6b8617b..2efb852c92 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -304,8 +304,8 @@ def question_generator(): @pytest.fixture(scope="function") -def eli5_generator(): - return Seq2SeqGenerator(model_name_or_path="yjernite/bart_eli5", max_length=20) +def lfqa_generator(request): + return Seq2SeqGenerator(model_name_or_path=request.param, min_length=100, max_length=200) @pytest.fixture(scope="function") @@ -509,6 +509,14 @@ def get_retriever(retriever_type, document_store): model_format="retribert", use_gpu=False, ) + elif retriever_type == "dpr_lfqa": + retriever = DensePassageRetriever( + document_store=document_store, + query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki", + passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki", + use_gpu=False, + embed_title=True, + ) elif retriever_type == "elasticsearch": retriever = ElasticsearchRetriever(document_store=document_store) elif retriever_type == "es_filter_only": diff --git a/test/test_generator.py b/test/test_generator.py index 5c75acf61d..f5a217b225 100644 --- a/test/test_generator.py +++ b/test/test_generator.py @@ -60,9 +60,10 @@ def test_generator_pipeline(document_store, retriever, rag_generator): @pytest.mark.slow @pytest.mark.generator @pytest.mark.parametrize("document_store", ["memory"], indirect=True) -@pytest.mark.parametrize("retriever", ["retribert"], indirect=True) +@pytest.mark.parametrize("retriever", ["retribert", "dpr_lfqa"], indirect=True) +@pytest.mark.parametrize("lfqa_generator", ["yjernite/bart_eli5", "vblagoje/bart_lfqa"], indirect=True) @pytest.mark.embedding_dim(128) -def test_lfqa_pipeline(document_store, retriever, eli5_generator): +def test_lfqa_pipeline(document_store, 
retriever, lfqa_generator): # reuse existing DOCS but regenerate embeddings with retribert docs: List[Document] = [] for idx, d in enumerate(DOCS_WITH_EMBEDDINGS): @@ -70,11 +71,11 @@ def test_lfqa_pipeline(document_store, retriever, eli5_generator): document_store.write_documents(docs) document_store.update_embeddings(retriever) query = "Tell me about Berlin?" - pipeline = GenerativeQAPipeline(retriever=retriever, generator=eli5_generator) + pipeline = GenerativeQAPipeline(generator=lfqa_generator, retriever=retriever) output = pipeline.run(query=query, params={"top_k": 1}) answers = output["answers"] - assert len(answers) == 1 - assert "Germany" in answers[0].answer + assert len(answers) == 1, answers + assert "Germany" in answers[0].answer, answers[0].answer @pytest.mark.slow diff --git a/tutorials/Tutorial12_LFQA.ipynb b/tutorials/Tutorial12_LFQA.ipynb index 87ee1b32eb..30cb134bb6 100644 --- a/tutorials/Tutorial12_LFQA.ipynb +++ b/tutorials/Tutorial12_LFQA.ipynb @@ -51,7 +51,7 @@ "\n", "# Install the latest master of Haystack\n", "!pip install --upgrade pip\n", - "!pip install git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]" + "!pip install -q git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]" ] }, { @@ -146,7 +146,7 @@ "\n", "#### Retriever\n", "\n", - "**Here:** We use a `RetribertRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`\n", + "We use a `DensePassageRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`\n", "\n" ] }, @@ -161,10 +161,12 @@ }, "outputs": [], "source": [ - "from haystack.nodes import EmbeddingRetriever\n", + "from haystack.nodes import DensePassageRetriever\n", "\n", - "retriever = EmbeddingRetriever(\n", - " document_store=document_store, embedding_model=\"yjernite/retribert-base-uncased\", model_format=\"retribert\"\n", + "retriever = 
DensePassageRetriever(\n", + " document_store=document_store,\n", + " query_embedding_model=\"vblagoje/dpr-question_encoder-single-lfqa-wiki\",\n", + " passage_embedding_model=\"vblagoje/dpr-ctx_encoder-single-lfqa-wiki\",\n", ")\n", "\n", "document_store.update_embeddings(retriever)" @@ -176,25 +178,16 @@ "id": "sMlVEnJ2NkZZ" }, "source": [ - "Before we blindly use the `RetribertRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents." + "Before we blindly use the `DensePassageRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents." ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "id": "qpu-t9rndgpe" }, - "outputs": [ - { - "ename": "SyntaxError", - "evalue": "EOL while scanning string literal (, line 7)", - "output_type": "error", - "traceback": [ - "\u001b[0;36m File \u001b[0;32m\"\"\u001b[0;36m, line \u001b[0;32m7\u001b[0m\n\u001b[0;31m params={\"top_k_retriever=5\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m EOL while scanning string literal\n" - ] - } - ], + "outputs": [], "source": [ "from haystack.utils import print_documents\n", "from haystack.pipelines import DocumentSearchPipeline\n", @@ -214,7 +207,7 @@ "\n", "Similar to previous Tutorials we now initalize our reader/generator.\n", "\n", - "Here we use a `Seq2SeqGenerator` with the *yjernite/bart_eli5* model (see: https://huggingface.co/yjernite/bart_eli5)\n", + "Here we use a `Seq2SeqGenerator` with the *vblagoje/bart_lfqa* model (see: https://huggingface.co/vblagoje/bart_lfqa)\n", "\n" ] }, @@ -226,7 +219,7 @@ }, "outputs": [], "source": [ - "generator = Seq2SeqGenerator(model_name_or_path=\"yjernite/bart_eli5\")" + "generator = Seq2SeqGenerator(model_name_or_path=\"vblagoje/bart_lfqa\")" ] }, { @@ -274,25 +267,26 @@ "outputs": [], "source": [ "pipe.run(\n", - " query=\"Why did Arya Stark's character get portrayed in a 
television adaptation?\", params={\"Retriever\": {\"top_k\": 1}}\n", + " query=\"How did Arya Stark's character get portrayed in a television adaptation?\", params={\"Retriever\": {\"top_k\": 3}}\n", ")" ] }, { "cell_type": "code", - "execution_count": null, + "source": [ + "pipe.run(query=\"Why is Arya Stark an unusual character?\", params={\"Retriever\": {\"top_k\": 3}})" + ], "metadata": { - "id": "zvHb8SvMblw9" + "id": "IfTP9BfFGOo6" }, - "outputs": [], - "source": [ - "pipe.run(query=\"What kind of character does Arya Stark play?\", params={\"Retriever\": {\"top_k\": 1}})" - ] + "execution_count": null, + "outputs": [] }, { "cell_type": "markdown", "metadata": { - "collapsed": false + "collapsed": false, + "id": "i88KdOc2wUXQ" }, "source": [ "## About us\n", @@ -340,5 +334,5 @@ } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 0 } diff --git a/tutorials/Tutorial12_LFQA.py b/tutorials/Tutorial12_LFQA.py index 89b4d1457d..5197d33d47 100644 --- a/tutorials/Tutorial12_LFQA.py +++ b/tutorials/Tutorial12_LFQA.py @@ -37,18 +37,20 @@ def tutorial12_lfqa(): """ Initalize Retriever and Reader/Generator: - We use a `RetribertRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore` + We use a `DensePassageRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore` """ - from haystack.nodes import EmbeddingRetriever + from haystack.nodes import DensePassageRetriever - retriever = EmbeddingRetriever( - document_store=document_store, embedding_model="yjernite/retribert-base-uncased", model_format="retribert" + retriever = DensePassageRetriever( + document_store=document_store, + query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki", + passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki", ) document_store.update_embeddings(retriever) - """Before we blindly use the `RetribertRetriever` let's empirically test it to make sure a simple 
search indeed finds the relevant documents.""" + """Before we blindly use the `DensePassageRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents.""" from haystack.utils import print_documents from haystack.pipelines import DocumentSearchPipeline @@ -59,10 +61,10 @@ def tutorial12_lfqa(): """ Similar to previous Tutorials we now initalize our reader/generator. - Here we use a `Seq2SeqGenerator` with the *yjernite/bart_eli5* model (see: https://huggingface.co/yjernite/bart_eli5) + Here we use a `Seq2SeqGenerator` with the *vblagoje/bart_lfqa* model (see: https://huggingface.co/vblagoje/bart_lfqa) """ - generator = Seq2SeqGenerator(model_name_or_path="yjernite/bart_eli5") + generator = Seq2SeqGenerator(model_name_or_path="vblagoje/bart_lfqa") """ Pipeline: @@ -78,14 +80,14 @@ def tutorial12_lfqa(): """Voilà! Ask a question!""" - query_1 = "Why did Arya Stark's character get portrayed in a television adaptation?" - result_1 = pipe.run(query=query_1, params={"Retriever": {"top_k": 1}}) + query_1 = "How did Arya Stark's character get portrayed in a television adaptation?" + result_1 = pipe.run(query=query_1, params={"Retriever": {"top_k": 3}}) print(f"Query: {query_1}") print(f"Answer: {result_1['answers'][0]}") print() - query_2 = "What kind of character does Arya Stark play?" - result_2 = pipe.run(query=query_2, params={"Retriever": {"top_k": 1}}) + query_2 = "Why is Arya Stark an unusual character?" + result_2 = pipe.run(query=query_2, params={"Retriever": {"top_k": 3}}) print(f"Query: {query_2}") print(f"Answer: {result_2['answers'][0]}") print()