Skip to content

Commit

Permalink
Backport PR #581: Expose templates for customisation in providers (#602)
Browse files Browse the repository at this point in the history
  • Loading branch information
krassowski authored Jan 30, 2024
1 parent 0246114 commit 1e28736
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 45 deletions.
51 changes: 50 additions & 1 deletion packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
from langchain.chat_models.base import BaseChatModel
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.llms.utils import enforce_stop_tokens
from langchain.prompts import PromptTemplate
from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
PromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
from langchain.schema import LLMResult
from langchain.utils import get_from_dict_or_env
Expand Down Expand Up @@ -42,6 +48,23 @@
from pydantic.main import ModelMetaclass


CHAT_SYSTEM_PROMPT = """
You are Jupyternaut, a conversational assistant living in JupyterLab to help users.
You are not a language model, but rather an application built on a foundation model from {provider_name} called {local_model_id}.
You are talkative and you provide lots of specific details from the foundation model's context.
You may use Markdown to format your response.
Code blocks must be formatted in Markdown.
Math should be rendered with inline TeX markup, surrounded by $.
If you do not know the answer to a question, answer truthfully by responding that you do not know.
The following is a friendly conversation between you and a human.
""".strip()

CHAT_DEFAULT_TEMPLATE = """Current conversation:
{history}
Human: {input}
AI:"""


class EnvAuthStrategy(BaseModel):
"""Require one auth token via an environment variable."""

Expand Down Expand Up @@ -265,6 +288,32 @@ def get_prompt_template(self, format) -> PromptTemplate:
else:
return self.prompt_templates["text"] # Default to plain format

def get_chat_prompt_template(self) -> PromptTemplate:
"""
Produce a prompt template optimised for chat conversation.
The template should take two variables: history and input.
"""
name = self.__class__.name
if self.is_chat_provider:
return ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(
CHAT_SYSTEM_PROMPT
).format(provider_name=name, local_model_id=self.model_id),
MessagesPlaceholder(variable_name="history"),
HumanMessagePromptTemplate.from_template("{input}"),
]
)
else:
return PromptTemplate(
input_variables=["history", "input"],
template=CHAT_SYSTEM_PROMPT.format(
provider_name=name, local_model_id=self.model_id
)
+ "\n\n"
+ CHAT_DEFAULT_TEMPLATE,
)

@property
def is_chat_provider(self):
return isinstance(self, BaseChatModel)
Expand Down
48 changes: 4 additions & 44 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,9 @@
from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
PromptTemplate,
SystemMessagePromptTemplate,
)

from .base import BaseChatHandler, SlashCommandRoutingType

SYSTEM_PROMPT = """
You are Jupyternaut, a conversational assistant living in JupyterLab to help users.
You are not a language model, but rather an application built on a foundation model from {provider_name} called {local_model_id}.
You are talkative and you provide lots of specific details from the foundation model's context.
You may use Markdown to format your response.
Code blocks must be formatted in Markdown.
Math should be rendered with inline TeX markup, surrounded by $.
If you do not know the answer to a question, answer truthfully by responding that you do not know.
The following is a friendly conversation between you and a human.
""".strip()

DEFAULT_TEMPLATE = """Current conversation:
{history}
Human: {input}
AI:"""


class DefaultChatHandler(BaseChatHandler):
id = "default"
Expand All @@ -49,27 +26,10 @@ def create_llm_chain(
model_parameters = self.get_model_parameters(provider, provider_params)
llm = provider(**provider_params, **model_parameters)

if llm.is_chat_provider:
prompt_template = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format(
provider_name=provider.name, local_model_id=llm.model_id
),
MessagesPlaceholder(variable_name="history"),
HumanMessagePromptTemplate.from_template("{input}"),
]
)
self.memory = ConversationBufferWindowMemory(return_messages=True, k=2)
else:
prompt_template = PromptTemplate(
input_variables=["history", "input"],
template=SYSTEM_PROMPT.format(
provider_name=provider.name, local_model_id=llm.model_id
)
+ "\n\n"
+ DEFAULT_TEMPLATE,
)
self.memory = ConversationBufferWindowMemory(k=2)
prompt_template = llm.get_chat_prompt_template()
self.memory = ConversationBufferWindowMemory(
return_messages=llm.is_chat_provider, k=2
)

self.llm = llm
self.llm_chain = ConversationChain(
Expand Down

0 comments on commit 1e28736

Please sign in to comment.