From b9b84554fce98fb967d8080dae89c135450b853d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?=
Date: Fri, 6 Sep 2024 15:58:16 +0200
Subject: [PATCH] fix(anthropic): fixed vision and other issues with preparing messages

---
 gptme/llm_anthropic.py | 16 ++++++++++------
 gptme/message.py       | 10 +++++-----
 2 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/gptme/llm_anthropic.py b/gptme/llm_anthropic.py
index 63512375..859423b9 100644
--- a/gptme/llm_anthropic.py
+++ b/gptme/llm_anthropic.py
@@ -36,9 +36,10 @@ class MessagePart(TypedDict, total=False):
 def chat(messages: list[Message], model: str) -> str:
     assert anthropic, "LLM not initialized"
     messages, system_messages = _transform_system_messages(messages)
+    messages_dicts = msgs2dicts(messages, anthropic=True)
     response = anthropic.beta.prompt_caching.messages.create(
         model=model,
-        messages=msgs2dicts(messages, anthropic=True),  # type: ignore
+        messages=messages_dicts,  # type: ignore
         system=system_messages,  # type: ignore
         temperature=TEMPERATURE,
         top_p=TOP_P,
@@ -51,11 +52,12 @@ def chat(messages: list[Message], model: str) -> str:


 def stream(messages: list[Message], model: str) -> Generator[str, None, None]:
-    messages, system_messages = _transform_system_messages(messages)
     assert anthropic, "LLM not initialized"
+    messages, system_messages = _transform_system_messages(messages)
+    messages_dicts = msgs2dicts(messages, anthropic=True)
     with anthropic.beta.prompt_caching.messages.stream(
         model=model,
-        messages=msgs2dicts(messages, anthropic=True),  # type: ignore
+        messages=messages_dicts,  # type: ignore
         system=system_messages,  # type: ignore
         temperature=TEMPERATURE,
         top_p=TOP_P,
@@ -79,16 +81,18 @@ def _transform_system_messages(
             messages[i] = Message(
                 "user",
                 content=f"{message.content}",
+                files=message.files,  # type: ignore
             )

-    # find consecutive user role messages and merge them into a single message
+    # find consecutive user role messages and merge them together
     messages_new: list[Message] = []
     while messages:
         message = messages.pop(0)
-        if messages_new and messages_new[-1].role == "user":
+        if messages_new and messages_new[-1].role == "user" and message.role == "user":
             messages_new[-1] = Message(
                 "user",
-                content=f"{messages_new[-1].content}\n{message.content}",
+                content=f"{messages_new[-1].content}\n\n{message.content}",
+                files=messages_new[-1].files + message.files,  # type: ignore
             )
         else:
             messages_new.append(message)
diff --git a/gptme/message.py b/gptme/message.py
index 07b5b0a4..0a20be76 100644
--- a/gptme/message.py
+++ b/gptme/message.py
@@ -51,7 +51,7 @@ def __init__(
         # This is not persisted to the log file.
         self.quiet = quiet
         # Files attached to the message, could e.g. be images for vision.
-        self.files = (
+        self.files: list[Path] = (
             [Path(f) if isinstance(f, str) else f for f in files] if files else []
         )

@@ -123,12 +123,12 @@ def _content_files_list(
     def to_dict(self, keys=None, openai=False, anthropic=False) -> dict:
         """Return a dict representation of the message, serializable to JSON."""
         content: str | list[dict[str, Any]]
-        if not anthropic and not openai:
-            # storage/wire format should keep the content as a string
-            content = self.content
-        else:
+        if anthropic or openai:
             # OpenAI format or Anthropic format should include files in the content
             content = self._content_files_list(openai=openai, anthropic=anthropic)
+        else:
+            # storage/wire format should keep the content as a string
+            content = self.content

         d = {
             "role": self.role,
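
Note (not part of the patch): the behavioral core of the change is the merge loop in _transform_system_messages. Consecutive user messages are now only merged with each other (previously any message following a user message was folded in), and their attached files are carried along so vision inputs survive the merge. Below is a minimal standalone sketch of that behavior; the simplified Message dataclass and the name merge_consecutive_user_messages are illustrative stand-ins, not gptme's actual API.

    # Sketch of the merge step the patch fixes; the simplified Message below
    # is an illustrative stand-in for gptme's real dataclass, not its API.
    from dataclasses import dataclass, field
    from pathlib import Path


    @dataclass
    class Message:
        role: str
        content: str
        files: list[Path] = field(default_factory=list)


    def merge_consecutive_user_messages(messages: list[Message]) -> list[Message]:
        """Merge runs of consecutive user messages into one, keeping their files."""
        merged: list[Message] = []
        for message in messages:
            # Only merge when BOTH the previous and the current message are from
            # the user; checking message.role is what the patch adds, so an
            # assistant reply is no longer folded into a preceding user message.
            if merged and merged[-1].role == "user" and message.role == "user":
                merged[-1] = Message(
                    "user",
                    content=f"{merged[-1].content}\n\n{message.content}",
                    files=merged[-1].files + message.files,
                )
            else:
                merged.append(message)
        return merged


    if __name__ == "__main__":
        msgs = [
            Message("user", "What is in this image?", files=[Path("photo.png")]),
            Message("user", "Please be brief."),
            Message("assistant", "A cat on a sofa."),
        ]
        out = merge_consecutive_user_messages(msgs)
        assert len(out) == 2
        assert out[0].files == [Path("photo.png")]  # attachment survives the merge
        print(out[0].content)

Running the sketch merges the two user turns into a single message with photo.png still attached, while the assistant reply stays a separate message.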