Skip to content

Commit

Permalink
feat: Add tool support to AnthropicMultiModal (#17302)
Browse files Browse the repository at this point in the history
  • Loading branch information
rushilsheth authored Dec 17, 2024
1 parent 5ac1a41 commit d1fc12e
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-llms-bedrock"
readme = "README.md"
version = "0.3.1"
version = "0.3.2"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def _get_response_token_counts(self, raw_response: Any) -> dict:
def _complete(
self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any
) -> CompletionResponse:
"""Complete the prompt with image support and optional tool calls."""
all_kwargs = self._get_model_kwargs(**kwargs)
message_dict = self._get_multi_modal_chat_messages(
prompt=prompt, role=MessageRole.USER, image_documents=image_documents
Expand All @@ -206,8 +207,17 @@ def _complete(
**all_kwargs,
)

# Handle both tool and text responses
content = response.content[0]
if hasattr(content, "input"):
# Tool response - convert to string for compatibility
text = str(content.input)
else:
# Standard text response
text = content.text

return CompletionResponse(
text=response.content[0].text,
text=text,
raw=response,
additional_kwargs=self._get_response_token_counts(response),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-multi-modal-llms-anthropic"
readme = "README.md"
version = "0.3.0"
version = "0.3.1"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from unittest.mock import Mock, patch
from llama_index.core.multi_modal_llms.base import MultiModalLLM
from llama_index.multi_modal_llms.anthropic import AnthropicMultiModal
from llama_index.core.base.llms.types import CompletionResponse


def test_embedding_class():
Expand All @@ -10,3 +12,45 @@ def test_embedding_class():
def test_init():
    """Constructor should persist the max_tokens setting on the instance."""
    model = AnthropicMultiModal(max_tokens=400)
    assert model.max_tokens == 400


def test_tool_response():
    """Verify that a tool-use content block is surfaced as completion text.

    The Anthropic client is patched so no network call is made; the mocked
    message carries a single content block with an ``input`` attribute,
    which the multi-modal LLM should stringify into ``response.text``.
    """
    llm = AnthropicMultiModal(max_tokens=400)

    # Payload the fake tool call returns.
    tool_payload = {
        "booking_number": "123",
        "carrier": "Test Carrier",
        "total_amount": 1000.0,
    }

    # Fake Anthropic message: one content block exposing `.input`.
    mock_content = Mock()
    mock_content.input = tool_payload
    mock_response = Mock()
    mock_response.content = [mock_content]

    # Tool definition passed through to the (patched) messages.create call.
    tool_spec = {
        "name": "tms_order_payload",
        "description": "Test tool",
        "input_schema": {
            "type": "object",
            "properties": {
                "booking_number": {"type": "string"},
                "carrier": {"type": "string"},
                "total_amount": {"type": "number"},
            },
        },
    }

    with patch.object(llm._client.messages, "create", return_value=mock_response):
        response = llm.complete(
            prompt="test prompt",
            image_documents=[],
            tools=[tool_spec],
            tool_choice={"type": "tool", "name": "tms_order_payload"},
        )

    assert isinstance(response, CompletionResponse)
    assert isinstance(response.text, str)
    assert "booking_number" in response.text
    assert "123" in response.text
    assert response.raw == mock_response

0 comments on commit d1fc12e

Please sign in to comment.