Skip to content

Commit

Permalink
fix: Move check for default PromptTemplates in PromptTemplate its…
Browse files Browse the repository at this point in the history
…elf (#5018)

* make prompttemplate load the defaults instead of promptnode

* add test

* fix tenacity decorator

* fix tests

* fix error handling

* mypy

---------

Co-authored-by: Silvano Cerza <[email protected]>
  • Loading branch information
ZanSara and silvanocerza authored May 27, 2023
1 parent b8ff105 commit 7e5fa0d
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 150 deletions.
109 changes: 0 additions & 109 deletions haystack/nodes/prompt/legacy_default_templates.py

This file was deleted.

10 changes: 0 additions & 10 deletions haystack/nodes/prompt/prompt_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import copy
import logging
from typing import Dict, List, Optional, Tuple, Union, Any
import warnings

import torch

Expand All @@ -11,7 +10,6 @@
from haystack.telemetry import send_event
from haystack.nodes.prompt.prompt_model import PromptModel
from haystack.nodes.prompt.prompt_template import PromptTemplate
from haystack.nodes.prompt.legacy_default_templates import LEGACY_DEFAULT_TEMPLATES

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -216,14 +214,6 @@ def get_prompt_template(self, prompt_template: Union[str, PromptTemplate, None]
if isinstance(prompt_template, PromptTemplate):
return prompt_template

if prompt_template in LEGACY_DEFAULT_TEMPLATES:
warnings.warn(
f"You're using a legacy prompt template '{prompt_template}', "
"we strongly suggest you use prompts from the official Haystack PromptHub: "
"https://prompthub.deepset.ai/"
)
return LEGACY_DEFAULT_TEMPLATES[prompt_template]

# If it's the name of a template that was used already
if prompt_template in self._prompt_templates_cache:
return self._prompt_templates_cache[prompt_template]
Expand Down
164 changes: 137 additions & 27 deletions haystack/nodes/prompt/prompt_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import ast
import json
import warnings
from pathlib import Path
from abc import ABC
from uuid import uuid4
Expand Down Expand Up @@ -46,6 +47,112 @@
PROMPTHUB_MAX_RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))


#############################################################################
# This templates were hardcoded in the prompt_template module. When adding
# support for PromptHub integration we decided to remove them with the PR
# that added the integration: https://github.com/deepset-ai/haystack/pull/4879/
#
# That PR also changed the PromptNode API forcing the user to change how
# they use the node.
#
# After some discussion we deemed the change to be too breaking for existing
# use cases and which steps would have been necessary to migrate to the
# new API in case someone was using an harcoded template we decided to
# bring them back.
#
# So for the time being this must live here, no new template must be added
# to this dictionary.
#############################################################################


LEGACY_DEFAULT_TEMPLATES: Dict[str, Dict[str, Any]] = {
# DO NOT ADD ANY NEW TEMPLATE IN HERE!
"question-answering": {
"prompt": "Given the context please answer the question. Context: {join(documents)}; Question: "
"{query}; Answer:",
"output_parser": AnswerParser(),
},
"question-answering-per-document": {
"prompt": "Given the context please answer the question. Context: {documents}; Question: " "{query}; Answer:",
"output_parser": AnswerParser(),
},
"question-answering-with-references": {
"prompt": "Create a concise and informative answer (no more than 50 words) for a given question "
"based solely on the given documents. You must only use information from the given documents. "
"Use an unbiased and journalistic tone. Do not repeat text. Cite the documents using Document[number] notation. "
"If multiple documents contain the answer, cite those documents like ‘as stated in Document[number], Document[number], etc.’. "
"If the documents do not contain the answer to the question, say that ‘answering is not possible given the available information.’\n"
"{join(documents, delimiter=new_line, pattern=new_line+'Document[$idx]: $content', str_replace={new_line: ' ', '[': '(', ']': ')'})} \n Question: {query}; Answer: ",
"output_parser": AnswerParser(reference_pattern=r"Document\[(\d+)\]"),
},
"question-answering-with-document-scores": {
"prompt": "Answer the following question using the paragraphs below as sources. "
"An answer should be short, a few words at most.\n"
"Paragraphs:\n{documents}\n"
"Question: {query}\n\n"
"Instructions: Consider all the paragraphs above and their corresponding scores to generate "
"the answer. While a single paragraph may have a high score, it's important to consider all "
"paragraphs for the same answer candidate to answer accurately.\n\n"
"After having considered all possibilities, the final answer is:\n"
},
"question-generation": {"prompt": "Given the context please generate a question. Context: {documents}; Question:"},
"conditioned-question-generation": {
"prompt": "Please come up with a question for the given context and the answer. "
"Context: {documents}; Answer: {answers}; Question:"
},
"summarization": {"prompt": "Summarize this document: {documents} Summary:"},
"question-answering-check": {
"prompt": "Does the following context contain the answer to the question? "
"Context: {documents}; Question: {query}; Please answer yes or no! Answer:",
"output_parser": AnswerParser(),
},
"sentiment-analysis": {
"prompt": "Please give a sentiment for this context. Answer with positive, "
"negative or neutral. Context: {documents}; Answer:"
},
"multiple-choice-question-answering": {
"prompt": "Question:{query} ; Choose the most suitable option to answer the above question. "
"Options: {options}; Answer:",
"output_parser": AnswerParser(),
},
"topic-classification": {"prompt": "Categories: {options}; What category best describes: {documents}; Answer:"},
"language-detection": {
"prompt": "Detect the language in the following context and answer with the "
"name of the language. Context: {documents}; Answer:"
},
"translation": {
"prompt": "Translate the following context to {target_language}. Context: {documents}; Translation:"
},
"zero-shot-react": {
"prompt": "You are a helpful and knowledgeable agent. To achieve your goal of answering complex questions "
"correctly, you have access to the following tools:\n\n"
"{tool_names_with_descriptions}\n\n"
"To answer questions, you'll need to go through multiple steps involving step-by-step thinking and "
"selecting appropriate tools and their inputs; tools will respond with observations. When you are ready "
"for a final answer, respond with the `Final Answer:`\n\n"
"Use the following format:\n\n"
"Question: the question to be answered\n"
"Thought: Reason if you have the final answer. If yes, answer the question. If not, find out the missing information needed to answer it.\n"
"Tool: pick one of {tool_names} \n"
"Tool Input: the input for the tool\n"
"Observation: the tool will respond with the result\n"
"...\n"
"Final Answer: the final answer to the question, make it short (1-5 words)\n\n"
"Thought, Tool, Tool Input, and Observation steps can be repeated multiple times, but sometimes we can find an answer in the first pass\n"
"---\n\n"
"Question: {query}\n"
"Thought: Let's think step-by-step, I first need to {transcript}"
},
"conversational-agent": {
"prompt": "The following is a conversation between a human and an AI.\n{history}\nHuman: {query}\nAI:"
},
"conversational-summary": {
"prompt": "Condense the following chat transcript by shortening and summarizing the content without losing important information:\n{chat_transcript}\nCondensed Transcript:"
},
# DO NOT ADD ANY NEW TEMPLATE IN HERE!
}


class PromptNotFoundError(Exception):
...

Expand Down Expand Up @@ -217,32 +324,37 @@ def __init__(self, prompt: str, output_parser: Optional[Union[BaseOutputParser,
super().__init__()
name, prompt_text = "", ""

try:
# if it looks like a prompt template name
if re.fullmatch(r"[-a-zA-Z0-9_/]+", prompt):
name = prompt
prompt_text = self._fetch_from_prompthub(prompt)

# if it's a path to a YAML file
elif Path(prompt).exists():
with open(prompt, "r", encoding="utf-8") as yaml_file:
prompt_template_parsed = yaml.safe_load(yaml_file.read())
if not isinstance(prompt_template_parsed, dict):
raise ValueError("The prompt loaded is not a prompt YAML file.")
name = prompt_template_parsed["name"]
prompt_text = prompt_template_parsed["prompt_text"]

# Otherwise it's a on-the-fly prompt text
else:
prompt_text = prompt
name = "custom-at-query-time"

except OSError as exc:
logger.info(
"There was an error checking whether this prompt is a file (%s). Haystack will assume it's not.",
str(exc),
if prompt in LEGACY_DEFAULT_TEMPLATES:
warnings.warn(
f"You're using a legacy prompt template '{prompt}', "
"we strongly suggest you use prompts from the official Haystack PromptHub: "
"https://prompthub.deepset.ai/"
)
# In case of errors, let's directly assume this is a text prompt
name = prompt
prompt_text = LEGACY_DEFAULT_TEMPLATES[prompt]["prompt"]
output_parser = LEGACY_DEFAULT_TEMPLATES[prompt].get("output_parser")

# if it looks like a prompt template name
elif re.fullmatch(r"[-a-zA-Z0-9_/]+", prompt):
name = prompt
try:
prompt_text = self._fetch_from_prompthub(prompt)
except HTTPError as http_error:
if http_error.response.status_code != 404:
raise http_error
raise PromptNotFoundError(f"Prompt template named '{name}' not available in the Prompt Hub.")

# if it's a path to a YAML file
elif len(prompt) < 255 and Path(prompt).exists():
with open(prompt, "r", encoding="utf-8") as yaml_file:
prompt_template_parsed = yaml.safe_load(yaml_file.read())
if not isinstance(prompt_template_parsed, dict):
raise ValueError("The prompt loaded is not a prompt YAML file.")
name = prompt_template_parsed["name"]
prompt_text = prompt_template_parsed["prompt_text"]

# Otherwise it's a on-the-fly prompt text
else:
prompt_text = prompt
name = "custom-at-query-time"

Expand Down Expand Up @@ -296,8 +408,6 @@ def output_variable(self) -> Optional[str]:
def _fetch_from_prompthub(self, name) -> str:
"""
Looks for the given prompt in the PromptHub if the prompt is not in the local cache.
Raises PromptNotFoundError if the prompt is not present in the hub.
"""
try:
prompt_data: prompthub.Prompt = prompthub.fetch(name, timeout=PROMPTHUB_TIMEOUT)
Expand Down
6 changes: 3 additions & 3 deletions test/prompt/test_prompt_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from haystack import Document, Pipeline, BaseComponent, MultiLabel
from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
from haystack.nodes.prompt.prompt_template import LEGACY_DEFAULT_TEMPLATES
from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer, DefaultTokenStreamingHandler
from haystack.nodes.prompt.legacy_default_templates import LEGACY_DEFAULT_TEMPLATES


@pytest.fixture
Expand Down Expand Up @@ -103,8 +103,8 @@ def test_get_prompt_template_no_default_template(mock_model):
def test_get_prompt_template_from_legacy_default_template(mock_model):
node = PromptNode()
template = node.get_prompt_template("question-answering")
assert template.name == "custom-at-query-time"
assert template.prompt_text == LEGACY_DEFAULT_TEMPLATES["question-answering"].prompt_text
assert template.name == "question-answering"
assert template.prompt_text == LEGACY_DEFAULT_TEMPLATES["question-answering"]["prompt"]


@pytest.mark.unit
Expand Down
11 changes: 10 additions & 1 deletion test/prompt/test_prompt_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@

from haystack.nodes.prompt import PromptTemplate
from haystack.nodes.prompt.prompt_node import PromptNode
from haystack.nodes.prompt.prompt_template import PromptTemplateValidationError
from haystack.nodes.prompt.prompt_template import PromptTemplateValidationError, LEGACY_DEFAULT_TEMPLATES
from haystack.nodes.prompt.shapers import AnswerParser
from haystack.pipelines.base import Pipeline
from haystack.schema import Answer, Document


@pytest.fixture
def mock_prompthub():
with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
mock_prompthub.side_effect = [
Expand All @@ -28,6 +29,14 @@ def test_prompt_templates_from_hub():
mock_prompthub.fetch.assert_called_with("deepset/question-answering", timeout=30)


@pytest.mark.unit
def test_prompt_templates_from_legacy_set(mock_prompthub):
p = PromptTemplate("question-answering")
assert p.name == "question-answering"
assert p.prompt_text == LEGACY_DEFAULT_TEMPLATES["question-answering"]["prompt"]
mock_prompthub.assert_not_called()


@pytest.mark.unit
def test_prompt_templates_from_file(tmp_path):
path = tmp_path / "test-prompt.yml"
Expand Down

0 comments on commit 7e5fa0d

Please sign in to comment.