-
Notifications
You must be signed in to change notification settings - Fork 150
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
861 additions
and
205 deletions.
There are no files selected for viewing
93 changes: 93 additions & 0 deletions
93
libs/vertexai/langchain_google_vertexai/_anthropic_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import re | ||
from typing import Dict, List, Optional, Tuple, Union | ||
|
||
from langchain_core.messages import BaseMessage | ||
|
||
_message_type_lookups = {"human": "user", "ai": "assistant"} | ||
|
||
|
||
def _format_image(image_url: str) -> Dict: | ||
"""Formats a message image to a dict for anthropic api.""" | ||
regex = r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$" | ||
match = re.match(regex, image_url) | ||
if match is None: | ||
raise ValueError( | ||
"Anthropic only supports base64-encoded images currently." | ||
" Example: data:image/png;base64,'/9j/4AAQSk'..." | ||
) | ||
return { | ||
"type": "base64", | ||
"media_type": match.group("media_type"), | ||
"data": match.group("data"), | ||
} | ||
|
||
|
||
def _format_messages_anthropic( | ||
messages: List[BaseMessage], | ||
) -> Tuple[Optional[str], List[Dict]]: | ||
"""Formats messages for anthropic.""" | ||
system_message: Optional[str] = None | ||
formatted_messages: List[Dict] = [] | ||
|
||
for i, message in enumerate(messages): | ||
if message.type == "system": | ||
if i != 0: | ||
raise ValueError("System message must be at beginning of message list.") | ||
if not isinstance(message.content, str): | ||
raise ValueError( | ||
"System message must be a string, " | ||
f"instead was: {type(message.content)}" | ||
) | ||
system_message = message.content | ||
continue | ||
|
||
role = _message_type_lookups[message.type] | ||
content: Union[str, List[Dict]] | ||
|
||
if not isinstance(message.content, str): | ||
# parse as dict | ||
assert isinstance( | ||
message.content, list | ||
), "Anthropic message content must be str or list of dicts" | ||
|
||
# populate content | ||
content = [] | ||
for item in message.content: | ||
if isinstance(item, str): | ||
content.append( | ||
{ | ||
"type": "text", | ||
"text": item, | ||
} | ||
) | ||
elif isinstance(item, dict): | ||
if "type" not in item: | ||
raise ValueError("Dict content item must have a type key") | ||
elif item["type"] == "image_url": | ||
# convert format | ||
source = _format_image(item["image_url"]["url"]) | ||
content.append( | ||
{ | ||
"type": "image", | ||
"source": source, | ||
} | ||
) | ||
elif item["type"] == "tool_use": | ||
item.pop("text", None) | ||
content.append(item) | ||
else: | ||
content.append(item) | ||
else: | ||
raise ValueError( | ||
f"Content items must be str or dict, instead was: {type(item)}" | ||
) | ||
else: | ||
content = message.content | ||
|
||
formatted_messages.append( | ||
{ | ||
"role": role, | ||
"content": content, | ||
} | ||
) | ||
return system_message, formatted_messages |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.