Skip to content

Commit

Permalink
vertexai: refactor: simplify content processing in anthropic formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
jfypk authored Nov 20, 2024
1 parent 9f520cd commit 0b5d16a
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 18 deletions.
19 changes: 13 additions & 6 deletions libs/vertexai/langchain_google_vertexai/_anthropic_parsers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Optional, Type
from typing import Any, List, Optional, Type, Union

from langchain_core.messages import AIMessage, ToolCall
from langchain_core.messages.tool import tool_call
Expand Down Expand Up @@ -55,11 +55,18 @@ def _pydantic_parse(self, tool_call: dict) -> BaseModel:
return cls_(**tool_call["args"])


def _extract_tool_calls(content: List[dict]) -> List[ToolCall]:
tool_calls = []
for block in content:
if block["type"] == "tool_use":
def _extract_tool_calls(content: Union[str, List[Union[str, dict]]]) -> List[ToolCall]:
"""Extract tool calls from a list of content blocks."""
if isinstance(content, list):
tool_calls = []
for block in content:
if isinstance(block, str):
continue
if block["type"] != "tool_use":
continue
tool_calls.append(
tool_call(name=block["name"], args=block["input"], id=block["id"])
)
return tool_calls
return tool_calls
else:
return []
23 changes: 16 additions & 7 deletions libs/vertexai/langchain_google_vertexai/model_garden.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,20 @@ def validate_environment(self) -> Self:
AsyncAnthropicVertex,
)

if self.project is None:
raise ValueError("project is required for ChatAnthropicVertex")

project_id: str = self.project

self.client = AnthropicVertex(
project_id=self.project,
project_id=project_id,
region=self.location,
max_retries=self.max_retries,
access_token=self.access_token,
credentials=self.credentials,
)
self.async_client = AsyncAnthropicVertex(
project_id=self.project,
project_id=project_id,
region=self.location,
max_retries=self.max_retries,
access_token=self.access_token,
Expand Down Expand Up @@ -205,14 +210,18 @@ def _format_params(

def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
data_dict = data.model_dump()
content = [c for c in data_dict["content"] if c["type"] != "tool_use"]
content = content[0]["text"] if len(content) == 1 else content
content = data_dict["content"]
llm_output = {
k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
}
tool_calls = _extract_tool_calls(data_dict["content"])
if tool_calls:
msg = AIMessage(content=content, tool_calls=tool_calls)
if len(content) == 1 and content[0]["type"] == "text":
msg = AIMessage(content=content[0]["text"])
elif any(block["type"] == "tool_use" for block in content):
tool_calls = _extract_tool_calls(content)
msg = AIMessage(
content=content,
tool_calls=tool_calls,
)
else:
msg = AIMessage(content=content)
# Collect token usage
Expand Down
55 changes: 51 additions & 4 deletions libs/vertexai/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions libs/vertexai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ numpy = [
google-api-python-client = "^2.117.0"
langchain = "^0.3.7"
langchain-tests = "0.3.1"
anthropic = { extras = ["vertexai"], version = ">=0.35.0,<1" }


[tool.codespell]
Expand Down
2 changes: 1 addition & 1 deletion libs/vertexai/tests/integration_tests/test_model_garden.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ async def test_anthropic_async() -> None:
def _check_tool_calls(response: BaseMessage, expected_name: str) -> None:
"""Check tool calls are as expected."""
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert isinstance(response.content, list)
tool_calls = response.tool_calls
assert len(tool_calls) == 1
tool_call = tool_calls[0]
Expand Down
47 changes: 47 additions & 0 deletions libs/vertexai/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
_parse_examples,
_parse_response_candidate,
)
from langchain_google_vertexai.model_garden import ChatAnthropicVertex


def test_init() -> None:
Expand Down Expand Up @@ -1067,3 +1068,49 @@ def test_init_client_with_custom_api() -> None:
transport = mock_prediction_service.call_args.kwargs["transport"]
assert client_options.api_endpoint == "https://example.com"
assert transport == "rest"


def test_anthropic_format_output() -> None:
    """Check ``_format_output`` on a response holding a single tool-use block.

    A stand-in response object returns a fixed ``model_dump()`` payload whose
    only content entry is a ``tool_use`` block. The generated message must
    keep that content list as-is, expose the block as exactly one tool call,
    and carry the expected usage metadata.
    """

    @dataclass
    class Usage:
        input_tokens: int
        output_tokens: int

    @dataclass
    class Message:
        usage: Usage

        def model_dump(self):
            # The payload is fixed; it does not read ``self.usage``.
            tool_use_block = {
                "type": "tool_use",
                "id": "123",
                "name": "calculator",
                "input": {"number": 42},
            }
            return {
                "content": [tool_use_block],
                "model": "baz",
                "role": "assistant",
                "usage": Usage(input_tokens=2, output_tokens=1),
                "type": "message",
            }

    fake_response = Message(usage=Usage(input_tokens=2, output_tokens=1))
    chat = ChatAnthropicVertex(project="test-project", location="test-location")

    generated = chat._format_output(fake_response).generations[0].message

    assert isinstance(generated, AIMessage)
    # Content is compared against a freshly dumped payload.
    assert generated.content == fake_response.model_dump()["content"]

    parsed_calls = generated.tool_calls
    assert len(parsed_calls) == 1
    only_call = parsed_calls[0]
    assert only_call["name"] == "calculator"
    assert only_call["args"] == {"number": 42}

    expected_usage = {
        "input_tokens": 2,
        "output_tokens": 1,
        "total_tokens": 3,
    }
    assert generated.usage_metadata == expected_usage

0 comments on commit 0b5d16a

Please sign in to comment.