Skip to content

Commit

Permalink
feat: added new 'hint' wrappers that inject hints into the pre-prefix (#707)
Browse files Browse the repository at this point in the history

* added new 'hint' wrappers that inject hints into the pre-prefix

* modified basic search functions with extra input sanitization

* updated first message prefix
  • Loading branch information
cpacker authored Dec 25, 2023
1 parent 88630c1 commit 399fa39
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 4 deletions.
18 changes: 18 additions & 0 deletions memgpt/functions/function_sets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ def conversation_search(self, query: str, page: Optional[int] = 0) -> Optional[s
Returns:
str: Query result string
"""
if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
page = 0
try:
page = int(page)
except:
raise ValueError(f"'page' argument must be an integer")
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
results, total = self.persistence_manager.recall_memory.text_search(query, count=count, start=page * count)
num_pages = math.ceil(total / count) - 1 # 0 index
Expand All @@ -119,6 +125,12 @@ def conversation_search_date(self, start_date: str, end_date: str, page: Optiona
Returns:
str: Query result string
"""
if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
page = 0
try:
page = int(page)
except:
raise ValueError(f"'page' argument must be an integer")
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
results, total = self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page * count)
num_pages = math.ceil(total / count) - 1 # 0 index
Expand Down Expand Up @@ -156,6 +168,12 @@ def archival_memory_search(self, query: str, page: Optional[int] = 0) -> Optiona
Returns:
str: Query result string
"""
if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
page = 0
try:
page = int(page)
except:
raise ValueError(f"'page' argument must be an integer")
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
results, total = self.persistence_manager.archival_memory.search(query, count=count, start=page * count)
num_pages = math.ceil(total / count) - 1 # 0 index
Expand Down
3 changes: 2 additions & 1 deletion memgpt/local_llm/chat_completion_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def get_chat_completion(

# First step: turn the message sequence into a prompt that the model expects
try:
if hasattr(llm_wrapper, "supports_first_message") and llm_wrapper.supports_first_message:
# if hasattr(llm_wrapper, "supports_first_message") and llm_wrapper.supports_first_message:
if hasattr(llm_wrapper, "supports_first_message"):
prompt = llm_wrapper.chat_completion_to_prompt(messages, functions, first_message=first_message)
else:
prompt = llm_wrapper.chat_completion_to_prompt(messages, functions)
Expand Down
31 changes: 29 additions & 2 deletions memgpt/local_llm/llm_chat_completion_wrappers/chatml.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,26 @@
from ...errors import LLMJSONParsingError


# Reminder text appended to the assistant pre-prefix for every non-first
# message when the wrapper is constructed with assistant_prefix_hint=True
# (selected in chat_completion_to_prompt: FIRST_PREFIX_HINT on the first
# message, this constant otherwise). Nudges the model to use its core and
# archival memory functions.
PREFIX_HINT = """# Reminders:
# Important information about yourself and the user is stored in (limited) core memory
# You can modify core memory with core_memory_replace
# You can add to core memory with core_memory_append
# Less important information is stored in (unlimited) archival memory
# You can add to archival memory with archival_memory_insert
# You can search archival memory with archival_memory_search
# You will always see the statistics of archival memory, so you know if there is content inside it
# If you receive new important information about the user (or yourself), you immediately update your memory with core_memory_replace, core_memory_append, or archival_memory_insert"""

# Reminder text appended to the assistant pre-prefix only for the very first
# message of a conversation (when assistant_prefix_hint=True and
# first_message is set in chat_completion_to_prompt). Steers the model to
# introduce itself and to reply via send_message rather than inner thoughts.
FIRST_PREFIX_HINT = """# Reminders:
# This is your first interaction with the user!
# Initial information about them is provided in the core memory user block
# Make sure to introduce yourself to them
# Your inner thoughts should be private, interesting, and creative
# Do NOT use inner thoughts to communicate with the user
# Use send_message to communicate with the user"""
# NOTE(review): the line below is a leftover commented-out hint; it is NOT
# part of the FIRST_PREFIX_HINT string above (the triple quote closes on the
# preceding line) — consider deleting it or folding it back into the hint.
# Don't forget to use send_message, otherwise the user won't see your message"""


class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
"""ChatML-style prompt formatter, tested for use with https://huggingface.co/ehartford/dolphin-2.5-mixtral-8x7b#training"""

Expand All @@ -24,12 +44,15 @@ def __init__(
allow_function_role=False, # use function role for function replies?
no_function_role_role="assistant", # if no function role, which role to use?
no_function_role_prefix="FUNCTION RETURN:\n", # if no function role, what prefix to use?
# add a guiding hint
assistant_prefix_hint=False,
):
self.simplify_json_content = simplify_json_content
self.clean_func_args = clean_function_args
self.include_assistant_prefix = include_assistant_prefix
self.assistant_prefix_extra = assistant_prefix_extra
self.assistant_prefix_extra_first_message = assistant_prefix_extra_first_message
self.assistant_prefix_hint = assistant_prefix_hint

# role-based
self.allow_custom_roles = allow_custom_roles
Expand Down Expand Up @@ -202,7 +225,9 @@ def chat_completion_to_prompt(self, messages, functions, first_message=False):

if self.include_assistant_prefix:
prompt += f"\n<|im_start|>assistant"
if first_message:
if self.assistant_prefix_hint:
prompt += f"\n{FIRST_PREFIX_HINT if first_message else PREFIX_HINT}"
if self.supports_first_message and first_message:
if self.assistant_prefix_extra_first_message:
prompt += self.assistant_prefix_extra_first_message
else:
Expand Down Expand Up @@ -355,7 +380,9 @@ def output_to_chat_completion_response(self, raw_llm_output, first_message=False
"""

# If we used a prefex to guide generation, we need to add it to the output as a preefix
assistant_prefix = self.assistant_prefix_extra_first_message if first_message else self.assistant_prefix_extra
assistant_prefix = (
self.assistant_prefix_extra_first_message if (self.supports_first_message and first_message) else self.assistant_prefix_extra
)
if assistant_prefix and raw_llm_output[: len(assistant_prefix)] != assistant_prefix:
raw_llm_output = assistant_prefix + raw_llm_output

Expand Down
2 changes: 1 addition & 1 deletion memgpt/local_llm/settings/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# '\n' +
# '</s>',
# '<|',
# '\n#',
"\n#",
# "\n\n\n",
# prevent chaining function calls / multi json objects / run-on generations
# NOTE: this requires the ability to patch the extra '}}' back into the prompt
Expand Down
3 changes: 3 additions & 0 deletions memgpt/local_llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def get_available_wrappers() -> dict:
# New chatml-based wrappers
"chatml": chatml.ChatMLInnerMonologueWrapper(),
"chatml-noforce": chatml.ChatMLOuterInnerMonologueWrapper(),
# With extra hints
"chatml-hints": chatml.ChatMLInnerMonologueWrapper(assistant_prefix_hint=True),
"chatml-noforce-hints": chatml.ChatMLOuterInnerMonologueWrapper(assistant_prefix_hint=True),
# Legacy wrappers
"airoboros-l2-70b-2.1": airoboros.Airoboros21InnerMonologueWrapper(),
"airoboros-l2-70b-2.1-grammar": airoboros.Airoboros21InnerMonologueWrapper(assistant_prefix_extra=None),
Expand Down

0 comments on commit 399fa39

Please sign in to comment.