Skip to content

Commit

Permalink
refactor: simplify content processing in anthropic formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeff Park committed Nov 19, 2024
1 parent 8bc67d8 commit e3906b8
Show file tree
Hide file tree
Showing 14 changed files with 224 additions and 326 deletions.
7 changes: 0 additions & 7 deletions libs/genai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,6 @@ This package contains the LangChain integrations for Gemini through their genera
pip install -U langchain-google-genai
```

### Image utilities
To use image utility methods, like loading images from GCS urls, install with extras group 'images':

```bash
pip install -e "langchain-google-genai[images]"
```

## Chat Models

Expand Down Expand Up @@ -61,7 +55,6 @@ The value of `image_url` can be any of the following:
- A public image URL
- An accessible gcs file (e.g., "gcs://path/to/file.png")
- A base64 encoded image (e.g., `data:image/png;base64,abcd124`)
- A PIL image



Expand Down
65 changes: 11 additions & 54 deletions libs/genai/langchain_google_genai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import logging
import uuid
import warnings
from io import BytesIO
from operator import itemgetter
from typing import (
Any,
Expand All @@ -22,7 +21,6 @@
Union,
cast,
)
from urllib.parse import urlparse

import google.api_core

Expand Down Expand Up @@ -114,16 +112,6 @@

from . import _genai_extension as genaix

IMAGE_TYPES: Tuple = ()
try:
import PIL
from PIL.Image import Image

IMAGE_TYPES = IMAGE_TYPES + (Image,)
except ImportError:
PIL = None # type: ignore
Image = None # type: ignore

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -245,46 +233,6 @@ def _is_openai_parts_format(part: dict) -> bool:
return "type" in part


def _is_vision_model(model: str) -> bool:
return "vision" in model


def _is_url(s: str) -> bool:
try:
result = urlparse(s)
return all([result.scheme, result.netloc])
except Exception as e:
logger.debug(f"Unable to parse URL: {e}")
return False


def _is_b64(s: str) -> bool:
return s.startswith("data:image")


def _load_image_from_gcs(path: str, project: Optional[str] = None) -> Image:
    """Download a single image from Google Cloud Storage and open it with PIL.

    Args:
        path: GCS location of the image. After splitting on ``/`` the third
            piece is used as the bucket name and the remaining pieces as the
            blob prefix (e.g. ``gcs://bucket/path/to/file.png``).
        project: Optional project forwarded to ``storage.Client``.

    Returns:
        The image produced by ``PIL.Image.open`` on the downloaded bytes.

    Raises:
        ImportError: If ``google-cloud-storage`` or ``pillow`` is unavailable.
        ValueError: If the prefix matches no blob, or more than one blob.
    """
    try:
        from google.cloud import storage  # type: ignore[attr-defined]
    except ImportError as e:
        # Chain the original error so the underlying import failure is visible.
        raise ImportError(
            "google-cloud-storage is required to load images from GCS."
            " Install it with `pip install google-cloud-storage`"
        ) from e
    if PIL is None:
        raise ImportError(
            "PIL is required to load images. Please install it "
            "with `pip install pillow`"
        )

    gcs_client = storage.Client(project=project)
    pieces = path.split("/")
    blobs = list(gcs_client.list_blobs(pieces[2], prefix="/".join(pieces[3:])))
    if not blobs:
        # Previously this fell through to blobs[0] and raised a bare IndexError.
        raise ValueError(f"Found no candidate for {path}!")
    if len(blobs) > 1:
        raise ValueError(f"Found more than one candidate for {path}!")
    img_bytes = blobs[0].download_as_bytes()
    return PIL.Image.open(BytesIO(img_bytes))


def _convert_to_parts(
raw_content: Union[str, Sequence[Union[str, dict]]],
) -> List[Part]:
Expand Down Expand Up @@ -368,8 +316,17 @@ def _parse_chat_history(
continue
elif isinstance(message, AIMessage):
role = "model"
raw_function_call = message.additional_kwargs.get("function_call")
if raw_function_call:
if message.tool_calls:
parts = []
for tool_call in message.tool_calls:
function_call = FunctionCall(
{
"name": tool_call["name"],
"args": tool_call["args"],
}
)
parts.append(Part(function_call=function_call))
elif raw_function_call := message.additional_kwargs.get("function_call"):
function_call = FunctionCall(
{
"name": raw_function_call["name"],
Expand Down
148 changes: 16 additions & 132 deletions libs/genai/poetry.lock

Large diffs are not rendered by default.

26 changes: 3 additions & 23 deletions libs/genai/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-google-genai"
version = "2.0.4"
version = "2.0.5"
description = "An integration package connecting Google's genai package and LangChain"
authors = []
readme = "README.md"
Expand All @@ -14,12 +14,8 @@ license = "MIT"
python = ">=3.9,<4.0"
langchain-core = ">=0.3.15,<0.4"
google-generativeai = "^0.8.0"
pillow = { version = "^10.1.0", optional = true }
pydantic = ">=2,<3"

[tool.poetry.extras]
images = ["pillow"]

[tool.poetry.group.test]
optional = true

Expand All @@ -31,31 +27,24 @@ syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
numpy = "^1.26.2"
langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" }
langchain-standard-tests = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/standard-tests" }
langchain-tests = "0.3.1"

[tool.codespell]
ignore-words-list = "rouge"




[tool.poetry.group.codespell]
optional = true

[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"




[tool.poetry.group.test_integration]
optional = true

[tool.poetry.group.test_integration.dependencies]
pillow = "^10.1.0"


pytest = "^7.3.0"


[tool.poetry.group.lint]
Expand All @@ -65,29 +54,20 @@ optional = true
ruff = "^0.1.5"




[tool.poetry.group.typing.dependencies]
mypy = "^1.10"
types-requests = "^2.28.11.5"
types-google-cloud-ndb = "^2.2.0.1"
types-pillow = "^10.1.0.2"
types-protobuf = "^4.24.0.20240302"
langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" }
numpy = "^1.26.2"




[tool.poetry.group.dev]
optional = true

[tool.poetry.group.dev.dependencies]
pillow = "^10.1.0"
types-requests = "^2.31.0.10"
types-pillow = "^10.1.0.2"
types-google-cloud-ndb = "^2.2.0.1"
langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" }

[tool.ruff.lint]
select = [
Expand Down
19 changes: 13 additions & 6 deletions libs/vertexai/langchain_google_vertexai/_anthropic_parsers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Optional, Type
from typing import Any, List, Optional, Type, Union

from langchain_core.messages import AIMessage, ToolCall
from langchain_core.messages.tool import tool_call
Expand Down Expand Up @@ -55,11 +55,18 @@ def _pydantic_parse(self, tool_call: dict) -> BaseModel:
return cls_(**tool_call["args"])


def _extract_tool_calls(content: List[dict]) -> List[ToolCall]:
tool_calls = []
for block in content:
if block["type"] == "tool_use":
def _extract_tool_calls(content: Union[str, List[Union[str, dict]]]) -> List[ToolCall]:
"""Extract tool calls from a list of content blocks."""
if isinstance(content, list):
tool_calls = []
for block in content:
if isinstance(block, str):
continue
if block["type"] != "tool_use":
continue
tool_calls.append(
tool_call(name=block["name"], args=block["input"], id=block["id"])
)
return tool_calls
return tool_calls
else:
return []
23 changes: 16 additions & 7 deletions libs/vertexai/langchain_google_vertexai/model_garden.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,20 @@ def validate_environment(self) -> Self:
AsyncAnthropicVertex,
)

if self.project is None:
raise ValueError("project is required for ChatAnthropicVertex")

project_id: str = self.project

self.client = AnthropicVertex(
project_id=self.project,
project_id=project_id,
region=self.location,
max_retries=self.max_retries,
access_token=self.access_token,
credentials=self.credentials,
)
self.async_client = AsyncAnthropicVertex(
project_id=self.project,
project_id=project_id,
region=self.location,
max_retries=self.max_retries,
access_token=self.access_token,
Expand Down Expand Up @@ -205,14 +210,18 @@ def _format_params(

def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
data_dict = data.model_dump()
content = [c for c in data_dict["content"] if c["type"] != "tool_use"]
content = content[0]["text"] if len(content) == 1 else content
content = data_dict["content"]
llm_output = {
k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
}
tool_calls = _extract_tool_calls(data_dict["content"])
if tool_calls:
msg = AIMessage(content=content, tool_calls=tool_calls)
if len(content) == 1 and content[0]["type"] == "text":
msg = AIMessage(content=content[0]["text"])
elif any(block["type"] == "tool_use" for block in content):
tool_calls = _extract_tool_calls(content)
msg = AIMessage(
content=content,
tool_calls=tool_calls,
)
else:
msg = AIMessage(content=content)
# Collect token usage
Expand Down
2 changes: 1 addition & 1 deletion libs/vertexai/langchain_google_vertexai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def create_context_cache(
tool_config = _format_tool_config(tool_config)

if tools is not None:
tools = _format_to_gapic_tool(tools)
tools = [_format_to_gapic_tool(tools)]

cached_content = caching.CachedContent.create(
model_name=model.full_model_name,
Expand Down
Loading

0 comments on commit e3906b8

Please sign in to comment.