Skip to content

Commit

Permalink
Fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
AkshataDM committed Oct 31, 2024
1 parent 7659404 commit dfd0c5c
Showing 1 changed file with 95 additions and 39 deletions.
134 changes: 95 additions & 39 deletions libs/community/langchain_community/chat_models/cloudflare_workersai.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,67 @@
from langchain_core.runnables.base import RunnableMap
import logging
from operator import itemgetter
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Type,
Union,
cast,
)
from uuid import uuid4

import requests
from langchain_core.tools import BaseTool
from langchain_core.runnables import Runnable
from langchain.schema import AIMessage, ChatGeneration, ChatResult, HumanMessage
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.messages import BaseMessage, AIMessageChunk, ToolCall, SystemMessage, ToolMessage
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessageChunk,
BaseMessage,
SystemMessage,
ToolCall,
ToolMessage,
)
from langchain_core.messages.tool import tool_call
from uuid import uuid4
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_core.output_parsers import (
JsonOutputParser,
PydanticOutputParser,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.runnables.base import RunnableMap
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import BaseModel, Field

# Initialize logging.
# NOTE(review): calling logging.basicConfig at import time is a library
# anti-pattern — it mutates the root logger for the whole embedding
# application; consider leaving configuration to the application instead.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
# Module-level logger, conventionally named after the module.
_logger = logging.getLogger(__name__)


def _is_pydantic_class(obj: Any) -> bool:
    """Return True when *obj* is a class object deriving from pydantic BaseModel."""
    if not isinstance(obj, type):
        # Instances (or non-class values) never qualify.
        return False
    return is_basemodel_subclass(obj)


def _convert_messages_to_cloudflare_messages(messages: List[BaseMessage]) -> List[Dict[str, Any]]:
def _convert_messages_to_cloudflare_messages(
messages: List[BaseMessage],
) -> List[Dict[str, Any]]:
"""Convert LangChain messages to Cloudflare Workers AI format."""
cloudflare_messages = []

msg: Dict[str, Any]
for message in messages:
# Base structure for each message
msg = {
Expand All @@ -46,37 +76,37 @@ def _convert_messages_to_cloudflare_messages(messages: List[BaseMessage]) -> Lis
msg["role"] = "assistant"
# If the AIMessage includes tool calls, format them as needed
if message.tool_calls:
msg["tool_calls"] = [
{
"name": tool_call["name"],
"arguments": tool_call["args"]
}
tool_calls = [
{"name": tool_call["name"], "arguments": tool_call["args"]}
for tool_call in message.tool_calls
]
msg["tool_calls"] = tool_calls
elif isinstance(message, SystemMessage):
msg["role"] = "system"
elif isinstance(message, ToolMessage):
msg["role"] = "tool"
msg["tool_call_id"] = message.tool_call_id # Use tool_call_id if it's a ToolMessage
msg["tool_call_id"] = (
message.tool_call_id
) # Use tool_call_id if it's a ToolMessage

# Add the formatted message to the list
cloudflare_messages.append(msg)

return cloudflare_messages


def _get_tool_calls_from_response(response: requests.Response) -> List[ToolCall]:
    """Extract tool calls from a Cloudflare Workers AI HTTP response.

    Args:
        response: The raw ``requests`` response; its JSON body is expected to
            contain a ``result`` object that may carry a ``tool_calls`` list.

    Returns:
        A list of ``ToolCall`` objects, one per entry in
        ``result["tool_calls"]`` (empty when the key is absent). Each call is
        assigned a fresh UUID because the API does not supply ids.
    """
    tool_calls: List[ToolCall] = []
    # Parse the body once instead of calling response.json() per access.
    result = response.json()["result"]
    for tc in result.get("tool_calls", []):
        tool_calls.append(
            tool_call(
                id=str(uuid4()),
                name=tc["name"],
                args=tc["arguments"],
            )
        )
    return tool_calls


Expand All @@ -91,50 +121,76 @@ class CloudflareWorkersAIChatModel(BaseChatModel):
base_url: str = "https://api.cloudflare.com/client/v4/accounts"
gateway_url: str = "https://gateway.ai.cloudflare.com/v1"

def __init__(self, **kwargs: Any) -> None:
    """Initialize with necessary credentials and compute the request URL.

    When ``ai_gateway`` is set, requests are routed through the Cloudflare
    AI Gateway endpoint; otherwise the direct Workers AI account endpoint
    is used.
    """
    super().__init__(**kwargs)
    if self.ai_gateway:
        self.url = (
            f"{self.gateway_url}/{self.account_id}/"
            f"{self.ai_gateway}/workers-ai/run/{self.model}"
        )
    else:
        self.url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}"

def _generate(
    self,
    messages: List[BaseMessage],
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> ChatResult:
    """Generate a chat response from Cloudflare Workers AI.

    Args:
        messages: Conversation history to send to the model.
        stop: Stop sequences — accepted for interface compatibility but
            not forwarded to the API by this implementation.
        run_manager: Callback manager — currently unused here.
        **kwargs: May contain ``tools``, a tool spec (or list of specs)
            forwarded in the request body.

    Returns:
        A ``ChatResult`` with a single ``AIMessage`` whose content is the
        raw JSON response body rendered as a string, plus any parsed tool
        calls.
    """
    formatted_messages = _convert_messages_to_cloudflare_messages(messages)

    headers = {"Authorization": f"Bearer {self.api_token}"}
    # Flatten the structured messages into a single text prompt.
    prompt = "\n".join(
        f"role: {msg['role']}, content: {msg['content']}"
        + (f", tools: {msg['tool_calls']}" if "tool_calls" in msg else "")
        + (
            f", tool_call_id: {msg['tool_call_id']}"
            if "tool_call_id" in msg
            else ""
        )
        for msg in formatted_messages
    )

    # `tools` may arrive as a single spec or a list; normalize to a list.
    data: Dict[str, Any] = {
        "prompt": prompt,
        "tools": kwargs.get("tools"),
    }
    if data["tools"] is not None and not isinstance(data["tools"], list):
        data["tools"] = [data["tools"]]

    # Lazy %-formatting avoids building the string when INFO is disabled.
    _logger.info("Sending prompt to Cloudflare Workers AI: %s", data)

    response = requests.post(self.url, headers=headers, json=data)
    tool_calls = _get_tool_calls_from_response(response)
    # Fix: `tool_calls` is already List[ToolCall]; the previous
    # cast(AIMessageChunk, ...) was a typing error with no runtime effect.
    ai_message = AIMessage(
        content=str(response.json()), tool_calls=tool_calls
    )
    chat_generation = ChatGeneration(message=ai_message)
    return ChatResult(generations=[chat_generation])

def bind_tools(
    self,
    tools: Sequence[Union[Dict[str, Any], Type, Callable[..., Any], BaseTool]],
    **kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
    """Bind tools for use in model generation.

    Each tool (dict, pydantic class, callable, or BaseTool) is converted
    to the OpenAI tool schema before being bound onto the runnable.
    """
    formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
    return super().bind(tools=formatted_tools, **kwargs)

def with_structured_output(self,
def with_structured_output(
self,
schema: Union[Dict, Type[BaseModel]],
*,
include_raw: bool = False,
method: Optional[Literal["json_mode"]] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema."""

if kwargs:
Expand Down Expand Up @@ -185,4 +241,4 @@ def with_structured_output(self,
@property
def _llm_type(self) -> str:
    """Return the type of the LLM (for LangChain compatibility)."""
    return "cloudflare-workers-ai"

0 comments on commit dfd0c5c

Please sign in to comment.