fix: Move check for default PromptTemplates in PromptTemplate itself #5018

Merged · 9 commits · May 27, 2023
109 changes: 0 additions & 109 deletions haystack/nodes/prompt/legacy_default_templates.py

This file was deleted.

10 changes: 0 additions & 10 deletions haystack/nodes/prompt/prompt_node.py
@@ -2,7 +2,6 @@
import copy
import logging
from typing import Dict, List, Optional, Tuple, Union, Any
import warnings

import torch

@@ -11,7 +10,6 @@
from haystack.telemetry import send_event
from haystack.nodes.prompt.prompt_model import PromptModel
from haystack.nodes.prompt.prompt_template import PromptTemplate
from haystack.nodes.prompt.legacy_default_templates import LEGACY_DEFAULT_TEMPLATES

logger = logging.getLogger(__name__)

@@ -216,14 +214,6 @@ def get_prompt_template(self, prompt_template: Union[str, PromptTemplate, None]
if isinstance(prompt_template, PromptTemplate):
return prompt_template

if prompt_template in LEGACY_DEFAULT_TEMPLATES:
warnings.warn(
f"You're using a legacy prompt template '{prompt_template}', "
"we strongly suggest you use prompts from the official Haystack PromptHub: "
"https://prompthub.deepset.ai/"
)
return LEGACY_DEFAULT_TEMPLATES[prompt_template]

# If it's the name of a template that was used already
if prompt_template in self._prompt_templates_cache:
return self._prompt_templates_cache[prompt_template]
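With the legacy lookup removed here, `PromptNode.get_prompt_template` simply forwards legacy names to `PromptTemplate`, which now owns the check. A minimal sketch of the caller-facing behavior, which this PR leaves unchanged; it assumes the default local model can be loaded, as in the tests further below:

```python
from haystack.nodes.prompt import PromptNode

node = PromptNode()  # loads the default local model; may download it on first use
template = node.get_prompt_template("question-answering")
# The legacy name is now resolved inside PromptTemplate, which emits a
# UserWarning pointing users to the PromptHub.
assert template.name == "question-answering"
```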
164 changes: 137 additions & 27 deletions haystack/nodes/prompt/prompt_template.py
@@ -4,6 +4,7 @@
import os
import ast
import json
import warnings
from pathlib import Path
from abc import ABC
from uuid import uuid4
@@ -46,6 +47,112 @@
PROMPTHUB_MAX_RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))


#############################################################################
# These templates were hardcoded in the prompt_template module. When adding
# support for the PromptHub integration, we decided to remove them in the PR
# that added the integration: https://github.com/deepset-ai/haystack/pull/4879/
#
# That PR also changed the PromptNode API, forcing users to change how
# they use the node.
#
# After some discussion, we deemed the change too breaking for existing
# use cases, given the steps that would have been necessary to migrate to
# the new API if someone was using a hardcoded template, so we decided
# to bring the templates back.
#
# So for the time being these templates must live here; no new templates
# must be added to this dictionary.
#############################################################################


LEGACY_DEFAULT_TEMPLATES: Dict[str, Dict[str, Any]] = {
# DO NOT ADD ANY NEW TEMPLATE IN HERE!
"question-answering": {
"prompt": "Given the context please answer the question. Context: {join(documents)}; Question: "
"{query}; Answer:",
"output_parser": AnswerParser(),
},
"question-answering-per-document": {
"prompt": "Given the context please answer the question. Context: {documents}; Question: " "{query}; Answer:",
"output_parser": AnswerParser(),
},
"question-answering-with-references": {
"prompt": "Create a concise and informative answer (no more than 50 words) for a given question "
"based solely on the given documents. You must only use information from the given documents. "
"Use an unbiased and journalistic tone. Do not repeat text. Cite the documents using Document[number] notation. "
"If multiple documents contain the answer, cite those documents like ‘as stated in Document[number], Document[number], etc.’. "
"If the documents do not contain the answer to the question, say that ‘answering is not possible given the available information.’\n"
"{join(documents, delimiter=new_line, pattern=new_line+'Document[$idx]: $content', str_replace={new_line: ' ', '[': '(', ']': ')'})} \n Question: {query}; Answer: ",
"output_parser": AnswerParser(reference_pattern=r"Document\[(\d+)\]"),
},
"question-answering-with-document-scores": {
"prompt": "Answer the following question using the paragraphs below as sources. "
"An answer should be short, a few words at most.\n"
"Paragraphs:\n{documents}\n"
"Question: {query}\n\n"
"Instructions: Consider all the paragraphs above and their corresponding scores to generate "
"the answer. While a single paragraph may have a high score, it's important to consider all "
"paragraphs for the same answer candidate to answer accurately.\n\n"
"After having considered all possibilities, the final answer is:\n"
},
"question-generation": {"prompt": "Given the context please generate a question. Context: {documents}; Question:"},
"conditioned-question-generation": {
"prompt": "Please come up with a question for the given context and the answer. "
"Context: {documents}; Answer: {answers}; Question:"
},
"summarization": {"prompt": "Summarize this document: {documents} Summary:"},
"question-answering-check": {
"prompt": "Does the following context contain the answer to the question? "
"Context: {documents}; Question: {query}; Please answer yes or no! Answer:",
"output_parser": AnswerParser(),
},
"sentiment-analysis": {
"prompt": "Please give a sentiment for this context. Answer with positive, "
"negative or neutral. Context: {documents}; Answer:"
},
"multiple-choice-question-answering": {
"prompt": "Question:{query} ; Choose the most suitable option to answer the above question. "
"Options: {options}; Answer:",
"output_parser": AnswerParser(),
},
"topic-classification": {"prompt": "Categories: {options}; What category best describes: {documents}; Answer:"},
"language-detection": {
"prompt": "Detect the language in the following context and answer with the "
"name of the language. Context: {documents}; Answer:"
},
"translation": {
"prompt": "Translate the following context to {target_language}. Context: {documents}; Translation:"
},
"zero-shot-react": {
"prompt": "You are a helpful and knowledgeable agent. To achieve your goal of answering complex questions "
"correctly, you have access to the following tools:\n\n"
"{tool_names_with_descriptions}\n\n"
"To answer questions, you'll need to go through multiple steps involving step-by-step thinking and "
"selecting appropriate tools and their inputs; tools will respond with observations. When you are ready "
"for a final answer, respond with the `Final Answer:`\n\n"
"Use the following format:\n\n"
"Question: the question to be answered\n"
"Thought: Reason if you have the final answer. If yes, answer the question. If not, find out the missing information needed to answer it.\n"
"Tool: pick one of {tool_names} \n"
"Tool Input: the input for the tool\n"
"Observation: the tool will respond with the result\n"
"...\n"
"Final Answer: the final answer to the question, make it short (1-5 words)\n\n"
"Thought, Tool, Tool Input, and Observation steps can be repeated multiple times, but sometimes we can find an answer in the first pass\n"
"---\n\n"
"Question: {query}\n"
"Thought: Let's think step-by-step, I first need to {transcript}"
},
"conversational-agent": {
"prompt": "The following is a conversation between a human and an AI.\n{history}\nHuman: {query}\nAI:"
},
"conversational-summary": {
"prompt": "Condense the following chat transcript by shortening and summarizing the content without losing important information:\n{chat_transcript}\nCondensed Transcript:"
},
# DO NOT ADD ANY NEW TEMPLATE IN HERE!
}


class PromptNotFoundError(Exception):
...

@@ -217,32 +324,37 @@ def __init__(self, prompt: str, output_parser: Optional[Union[BaseOutputParser,
super().__init__()
name, prompt_text = "", ""

try:
# if it looks like a prompt template name
if re.fullmatch(r"[-a-zA-Z0-9_/]+", prompt):
name = prompt
prompt_text = self._fetch_from_prompthub(prompt)

# if it's a path to a YAML file
elif Path(prompt).exists():
with open(prompt, "r", encoding="utf-8") as yaml_file:
prompt_template_parsed = yaml.safe_load(yaml_file.read())
if not isinstance(prompt_template_parsed, dict):
raise ValueError("The prompt loaded is not a prompt YAML file.")
name = prompt_template_parsed["name"]
prompt_text = prompt_template_parsed["prompt_text"]

# Otherwise it's an on-the-fly prompt text
else:
prompt_text = prompt
name = "custom-at-query-time"

except OSError as exc:
logger.info(
"There was an error checking whether this prompt is a file (%s). Haystack will assume it's not.",
str(exc),
if prompt in LEGACY_DEFAULT_TEMPLATES:
warnings.warn(
f"You're using a legacy prompt template '{prompt}', "
"we strongly suggest you use prompts from the official Haystack PromptHub: "
"https://prompthub.deepset.ai/"
)
# In case of errors, let's directly assume this is a text prompt
name = prompt
prompt_text = LEGACY_DEFAULT_TEMPLATES[prompt]["prompt"]
output_parser = LEGACY_DEFAULT_TEMPLATES[prompt].get("output_parser")

# if it looks like a prompt template name
elif re.fullmatch(r"[-a-zA-Z0-9_/]+", prompt):
name = prompt
try:
prompt_text = self._fetch_from_prompthub(prompt)
except HTTPError as http_error:
if http_error.response.status_code != 404:
raise http_error
raise PromptNotFoundError(f"Prompt template named '{name}' not available in the Prompt Hub.")

# if it's a path to a YAML file
elif len(prompt) < 255 and Path(prompt).exists():
with open(prompt, "r", encoding="utf-8") as yaml_file:
prompt_template_parsed = yaml.safe_load(yaml_file.read())
if not isinstance(prompt_template_parsed, dict):
raise ValueError("The prompt loaded is not a prompt YAML file.")
name = prompt_template_parsed["name"]
prompt_text = prompt_template_parsed["prompt_text"]

# Otherwise it's an on-the-fly prompt text
else:
prompt_text = prompt
name = "custom-at-query-time"

@@ -296,8 +408,6 @@ def output_variable(self) -> Optional[str]:
def _fetch_from_prompthub(self, name) -> str:
"""
Looks for the given prompt in the PromptHub if the prompt is not in the local cache.

Raises PromptNotFoundError if the prompt is not present in the hub.
"""
try:
prompt_data: prompthub.Prompt = prompthub.fetch(name, timeout=PROMPTHUB_TIMEOUT)
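One behavioral detail worth noting: with the `HTTPError` handling added to `__init__`, a name that looks like a hub prompt but does not exist now surfaces as `PromptNotFoundError` rather than a raw HTTP error. A hedged sketch, assuming network access; `deepset/no-such-prompt` is a hypothetical, nonexistent hub name:

```python
from haystack.nodes.prompt.prompt_template import PromptNotFoundError, PromptTemplate

try:
    PromptTemplate("deepset/no-such-prompt")
except PromptNotFoundError as err:
    # "Prompt template named 'deepset/no-such-prompt' not available in the Prompt Hub."
    print(err)
```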
6 changes: 3 additions & 3 deletions test/prompt/test_prompt_node.py
@@ -8,8 +8,8 @@

from haystack import Document, Pipeline, BaseComponent, MultiLabel
from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
from haystack.nodes.prompt.prompt_template import LEGACY_DEFAULT_TEMPLATES
from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer, DefaultTokenStreamingHandler
from haystack.nodes.prompt.legacy_default_templates import LEGACY_DEFAULT_TEMPLATES


@pytest.fixture
@@ -103,8 +103,8 @@ def test_get_prompt_template_no_default_template(mock_model):
def test_get_prompt_template_from_legacy_default_template(mock_model):
node = PromptNode()
template = node.get_prompt_template("question-answering")
assert template.name == "custom-at-query-time"
assert template.prompt_text == LEGACY_DEFAULT_TEMPLATES["question-answering"].prompt_text
assert template.name == "question-answering"
assert template.prompt_text == LEGACY_DEFAULT_TEMPLATES["question-answering"]["prompt"]


@pytest.mark.unit
11 changes: 10 additions & 1 deletion test/prompt/test_prompt_template.py
@@ -7,12 +7,13 @@

from haystack.nodes.prompt import PromptTemplate
from haystack.nodes.prompt.prompt_node import PromptNode
from haystack.nodes.prompt.prompt_template import PromptTemplateValidationError
from haystack.nodes.prompt.prompt_template import PromptTemplateValidationError, LEGACY_DEFAULT_TEMPLATES
from haystack.nodes.prompt.shapers import AnswerParser
from haystack.pipelines.base import Pipeline
from haystack.schema import Answer, Document


@pytest.fixture
def mock_prompthub():
with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
mock_prompthub.side_effect = [
@@ -28,6 +29,14 @@ def test_prompt_templates_from_hub():
mock_prompthub.fetch.assert_called_with("deepset/question-answering", timeout=30)


@pytest.mark.unit
def test_prompt_templates_from_legacy_set(mock_prompthub):
p = PromptTemplate("question-answering")
assert p.name == "question-answering"
assert p.prompt_text == LEGACY_DEFAULT_TEMPLATES["question-answering"]["prompt"]
mock_prompthub.assert_not_called()


@pytest.mark.unit
def test_prompt_templates_from_file(tmp_path):
path = tmp_path / "test-prompt.yml"
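For unit tests that exercise the hub path without network access, the fixture above patches `PromptTemplate._fetch_from_prompthub`. The same pattern in isolation, as a sketch; the hub name and prompt text are placeholders:

```python
from unittest.mock import patch

from haystack.nodes.prompt import PromptTemplate

with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_fetch:
    mock_fetch.return_value = "Answer the question: {query}"
    template = PromptTemplate("deepset/some-hub-prompt")  # hypothetical hub name
    assert template.prompt_text == "Answer the question: {query}"
    mock_fetch.assert_called_once_with("deepset/some-hub-prompt")
```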