From ff661ea3c682a5813bb99e306d846affcc14b9d4 Mon Sep 17 00:00:00 2001 From: Leo He <1995leohe@gmail.com> Date: Wed, 18 Dec 2024 19:00:55 -0800 Subject: [PATCH] Add OCI Generative AI tool calling support (#16888) --- docs/docs/examples/llm/oci_genai.ipynb | 59 ++- .../llama_index/llms/oci_genai/base.py | 274 ++++++++---- .../llama_index/llms/oci_genai/utils.py | 395 ++++++++++++++++-- .../llama-index-llms-oci-genai/pyproject.toml | 6 +- .../tests/test_oci_genai.py | 196 ++++++++- 5 files changed, 789 insertions(+), 141 deletions(-) diff --git a/docs/docs/examples/llm/oci_genai.ipynb b/docs/docs/examples/llm/oci_genai.ipynb index 3caeb6ca6d1d0..52129cb5f8b73 100644 --- a/docs/docs/examples/llm/oci_genai.ipynb +++ b/docs/docs/examples/llm/oci_genai.ipynb @@ -1,14 +1,5 @@ { "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "id": "6d1ca9ac", - "metadata": {}, - "source": [ - "\"Open" - ] - }, { "cell_type": "markdown", "id": "9e3a8796-edc8-43f2-94ad-fe4fb20d70ed", @@ -360,6 +351,56 @@ "resp = llm.chat(messages)\n", "print(resp)" ] + }, + { + "cell_type": "markdown", + "id": "acd73b3d", + "metadata": {}, + "source": [ + "## Basic tool calling in llamaindex \n", + "\n", + "Only Cohere supports tool calling for now" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5546c661", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.llms.oci_genai import OCIGenAI\n", + "from llama_index.core.tools import FunctionTool\n", + "\n", + "llm = OCIGenAI(\n", + " model=\"MY_MODEL\",\n", + " service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n", + " compartment_id=\"MY_OCID\",\n", + ")\n", + "\n", + "\n", + "def multiply(a: int, b: int) -> int:\n", + " \"\"\"Multiple two integers and returns the result integer\"\"\"\n", + " return a * b\n", + "\n", + "\n", + "def add(a: int, b: int) -> int:\n", + " \"\"\"Addition function on two integers.\"\"\"\n", + " return a + b\n", + "\n", + "\n", + "add_tool = FunctionTool.from_defaults(fn=add)\n", + "multiply_tool = FunctionTool.from_defaults(fn=multiply)\n", + "\n", + "response = llm.chat_with_tools(\n", + " tools=[add_tool, multiply_tool],\n", + " user_msg=\"What is 3 * 12? 
Also, what is 11 + 49?\",\n", + ")\n", + "\n", + "print(response)\n", + "tool_calls = response.message.additional_kwargs.get(\"tool_calls\", [])\n", + "print(tool_calls)" + ] } ], "metadata": { diff --git a/llama-index-integrations/llms/llama-index-llms-oci-genai/llama_index/llms/oci_genai/base.py b/llama-index-integrations/llms/llama-index-llms-oci-genai/llama_index/llms/oci_genai/base.py index f0df254601e8e..b2413b21517dd 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-genai/llama_index/llms/oci_genai/base.py +++ b/llama-index-integrations/llms/llama-index-llms-oci-genai/llama_index/llms/oci_genai/base.py @@ -1,5 +1,5 @@ import json -from typing import Any, Callable, Dict, Optional, Sequence +from typing import Any, Callable, Dict, Optional, Sequence, List, Union, TYPE_CHECKING from llama_index.core.base.llms.types import ( ChatMessage, @@ -12,6 +12,10 @@ LLMMetadata, MessageRole, ) +from llama_index.core.base.llms.generic_utils import ( + chat_to_completion_decorator, + stream_chat_to_completion_decorator, +) from llama_index.core.bridge.pydantic import Field, PrivateAttr from llama_index.core.callbacks import CallbackManager @@ -22,8 +26,7 @@ llm_chat_callback, llm_completion_callback, ) - -from llama_index.core.llms.llm import LLM +from llama_index.core.llms.function_calling import FunctionCallingLLM, ToolSelection from llama_index.core.types import BaseOutputParser, PydanticProgramMode from llama_index.llms.oci_genai.utils import ( @@ -34,40 +37,17 @@ get_completion_generator, get_chat_generator, get_context_size, + _format_oci_tool_calls, + force_single_tool_call, + validate_tool_call, ) +if TYPE_CHECKING: + from llama_index.core.tools.types import BaseTool -# TODO: -# (1) placeholder for future LLMs in utils.py e.g., llama3, command R+ -class OCIGenAI(LLM): - """OCI large language models. - - To authenticate, the OCI client uses the methods described in - https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm - - The authentifcation method is passed through auth_type and should be one of: - API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL - Make sure you have the required policies (profile/roles) to - access the OCI Generative AI service. - If a specific config profile is used, you must pass - the name of the profile (from ~/.oci/config) through auth_profile. - - To use, you must provide the compartment id - along with the endpoint url, and model id - as named parameters to the constructor. - - Example: - .. 
code-block:: python - - from llama_index.llms.oci_genai import OCIGenAI - - llm = OCIGenAI( - model="MY_MODEL_ID", - service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", - compartment_id="MY_OCID" - ) - """ +class OCIGenAI(FunctionCallingLLM): + """OCI large language models with function calling support.""" model: str = Field(description="Id of the OCI Generative AI model to use.") temperature: float = Field(description="The temperature to use for sampling.") @@ -190,7 +170,6 @@ def __init__( @classmethod def class_name(cls) -> str: - """Get class name.""" return "OCIGenAI_LLM" @property @@ -223,57 +202,27 @@ def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]: def complete( self, prompt: str, formatted: bool = False, **kwargs: Any ) -> CompletionResponse: - inference_params = self._get_all_kwargs(**kwargs) - inference_params["is_stream"] = False - inference_params["prompt"] = prompt - - request = self._completion_generator( - compartment_id=self.compartment_id, - serving_mode=self._serving_mode, - inference_request=self._provider.oci_completion_request(**inference_params), - ) - - response = self._client.generate_text(request) - return CompletionResponse( - text=self._provider.completion_response_to_text(response), - raw=response.__dict__, - ) + complete_fn = chat_to_completion_decorator(self.chat) + return complete_fn(prompt, **kwargs) @llm_completion_callback() def stream_complete( self, prompt: str, formatted: bool = False, **kwargs: Any ) -> CompletionResponseGen: - inference_params = self._get_all_kwargs(**kwargs) - inference_params["is_stream"] = True - inference_params["prompt"] = prompt - - request = self._completion_generator( - compartment_id=self.compartment_id, - serving_mode=self._serving_mode, - inference_request=self._provider.oci_completion_request(**inference_params), - ) - - response = self._client.generate_text(request) - - def gen() -> CompletionResponseGen: - content = "" - for event in response.data.events(): - content_delta = self._provider.completion_stream_to_text( - json.loads(event.data) - ) - content += content_delta - yield CompletionResponse( - text=content, delta=content_delta, raw=event.__dict__ - ) - - return gen() + stream_complete_fn = stream_chat_to_completion_decorator(self.stream_chat) + return stream_complete_fn(prompt, **kwargs) @llm_chat_callback() def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: oci_params = self._provider.messages_to_oci_params(messages) oci_params["is_stream"] = False + tools = kwargs.pop("tools", None) all_kwargs = self._get_all_kwargs(**kwargs) chat_params = {**all_kwargs, **oci_params} + if tools: + chat_params["tools"] = [ + self._provider.convert_to_oci_tool(tool) for tool in tools + ] request = self._chat_generator( compartment_id=self.compartment_id, @@ -283,10 +232,20 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: response = self._client.chat(request) + generation_info = self._provider.chat_generation_info(response) + + llm_output = { + "model_id": response.data.model_id, + "model_version": response.data.model_version, + "request_id": response.request_id, + "content-length": response.headers["content-length"], + } + return ChatResponse( message=ChatMessage( role=MessageRole.ASSISTANT, content=self._provider.chat_response_to_text(response), + additional_kwargs=generation_info, ), raw=response.__dict__, ) @@ -296,8 +255,13 @@ def stream_chat( ) -> ChatResponseGen: oci_params = 
self._provider.messages_to_oci_params(messages) oci_params["is_stream"] = True + tools = kwargs.pop("tools", None) all_kwargs = self._get_all_kwargs(**kwargs) chat_params = {**all_kwargs, **oci_params} + if tools: + chat_params["tools"] = [ + self._provider.convert_to_oci_tool(tool) for tool in tools + ] request = self._chat_generator( compartment_id=self.compartment_id, @@ -309,39 +273,169 @@ def stream_chat( def gen() -> ChatResponseGen: content = "" + tool_calls_accumulated = [] + for event in response.data.events(): content_delta = self._provider.chat_stream_to_text( json.loads(event.data) ) content += content_delta - yield ChatResponse( - message=ChatMessage(role=MessageRole.ASSISTANT, content=content), - delta=content_delta, - raw=event.__dict__, - ) - return gen() + try: + event_data = json.loads(event.data) + + tool_calls_data = None + for key in ["toolCalls", "tool_calls", "functionCalls"]: + if key in event_data: + tool_calls_data = event_data[key] + break + + if tool_calls_data: + new_tool_calls = _format_oci_tool_calls(tool_calls_data) + for tool_call in new_tool_calls: + existing = next( + ( + t + for t in tool_calls_accumulated + if t["name"] == tool_call["name"] + ), + None, + ) + if existing: + existing.update(tool_call) + else: + tool_calls_accumulated.append(tool_call) + + generation_info = self._provider.chat_stream_generation_info( + event_data + ) + if tool_calls_accumulated: + generation_info["tool_calls"] = tool_calls_accumulated + + yield ChatResponse( + message=ChatMessage( + role=MessageRole.ASSISTANT, + content=content, + additional_kwargs=generation_info, + ), + delta=content_delta, + raw=event.__dict__, + ) + + except json.JSONDecodeError: + yield ChatResponse( + message=ChatMessage( + role=MessageRole.ASSISTANT, content=content + ), + delta=content_delta, + raw=event.__dict__, + ) + + except Exception as e: + print(f"Error processing stream chunk: {e}") + yield ChatResponse( + message=ChatMessage( + role=MessageRole.ASSISTANT, content=content + ), + delta=content_delta, + raw=event.__dict__, + ) - async def acomplete( - self, prompt: str, formatted: bool = False, **kwargs: Any - ) -> CompletionResponse: - # do synchronous complete for now - return self.complete(prompt, formatted=formatted, **kwargs) + return gen() async def achat( self, messages: Sequence[ChatMessage], **kwargs: Any ) -> ChatResponse: - # do synchronous chat for now - return self.chat(messages, **kwargs) + raise NotImplementedError("Async chat is not implemented yet.") + + async def acomplete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: + raise NotImplementedError("Async complete is not implemented yet.") async def astream_chat( self, messages: Sequence[ChatMessage], **kwargs: Any ) -> ChatResponseAsyncGen: - # do synchronous stream chat for now - return self.stream_chat(messages, **kwargs) + raise NotImplementedError("Async stream chat is not implemented yet.") async def astream_complete( self, prompt: str, formatted: bool = False, **kwargs: Any ) -> CompletionResponseAsyncGen: - # do synchronous stream complete for now - return self.stream_complete(prompt, formatted, **kwargs) + raise NotImplementedError("Async stream complete is not implemented yet.") + + # Function tooling integration methods + def _prepare_chat_with_tools( + self, + tools: Sequence["BaseTool"], + user_msg: Optional[Union[str, ChatMessage]] = None, + chat_history: Optional[List[ChatMessage]] = None, + verbose: bool = False, + allow_parallel_tool_calls: bool = False, + **kwargs: 
Any, + ) -> Dict[str, Any]: + tool_specs = [self._provider.convert_to_oci_tool(tool) for tool in tools] + + if isinstance(user_msg, str): + user_msg = ChatMessage(role=MessageRole.USER, content=user_msg) + + messages = chat_history or [] + if user_msg: + messages.append(user_msg) + + oci_params = self._provider.messages_to_oci_params(messages) + chat_params = self._get_all_kwargs(**kwargs) + + return { + "messages": messages, + "tools": tool_specs, + **oci_params, + **chat_params, + } + + def _validate_chat_with_tools_response( + self, + response: ChatResponse, + tools: List["BaseTool"], + allow_parallel_tool_calls: bool = False, + **kwargs: Any, + ) -> ChatResponse: + """Validate the response from chat_with_tools.""" + if not allow_parallel_tool_calls: + force_single_tool_call(response) + return response + + def get_tool_calls_from_response( + self, + response: "ChatResponse", + error_on_no_tool_call: bool = True, + **kwargs: Any, + ) -> List[ToolSelection]: + """Predict and call the tool.""" + tool_calls = response.message.additional_kwargs.get("tool_calls", []) + + if len(tool_calls) < 1: + if error_on_no_tool_call: + raise ValueError( + f"Expected at least one tool call, but got {len(tool_calls)} tool calls." + ) + else: + return [] + + tool_selections = [] + for tool_call in tool_calls: + validate_tool_call(tool_call) + argument_dict = ( + json.loads(tool_call["input"]) + if isinstance(tool_call["input"], str) + else tool_call["input"] + ) + + tool_selections.append( + ToolSelection( + tool_id=tool_call["toolUseId"], + tool_name=tool_call["name"], + tool_kwargs=argument_dict, + ) + ) + + return tool_selections diff --git a/llama-index-integrations/llms/llama-index-llms-oci-genai/llama_index/llms/oci_genai/utils.py b/llama-index-integrations/llms/llama-index-llms-oci-genai/llama_index/llms/oci_genai/utils.py index 4cf47c150c586..ee440c3b8b915 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-genai/llama_index/llms/oci_genai/utils.py +++ b/llama-index-integrations/llms/llama-index-llms-oci-genai/llama_index/llms/oci_genai/utils.py @@ -1,7 +1,13 @@ +import inspect +import json +import uuid from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Sequence, Dict -from llama_index.core.base.llms.types import ChatMessage +from typing import Any, Sequence, Dict, Union, Type, Callable, Optional, List +from llama_index.core.base.llms.types import ChatMessage, MessageRole, ChatResponse +from llama_index.core.bridge.pydantic import BaseModel +from llama_index.core.tools import BaseTool +from oci.generative_ai_inference.models import CohereTool class OCIAuthType(Enum): @@ -15,30 +21,69 @@ class OCIAuthType(Enum): CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" -COMPLETION_MODELS = { - "cohere.command": 4096, - "cohere.command-light": 4096, - "meta.llama-2-70b-chat": 4096, -} +COMPLETION_MODELS = {} # completion endpoint has been deprecated CHAT_MODELS = { "cohere.command-r-16k": 16000, - "cohere.command-r-plus": 128000, # placeholder for future support + "cohere.command-r-plus": 128000, + "cohere.command-r-08-2024": 128000, + "cohere.command-r-plus-08-2024": 128000, "meta.llama-3-70b-instruct": 8192, + "meta.llama-3.1-70b-instruct": 128000, + "meta.llama-3.1-405b-instruct": 128000, + "meta.llama-3.2-90b-vision-instruct": 128000, } OCIGENAI_LLMS = {**COMPLETION_MODELS, **CHAT_MODELS} -STREAMING_MODELS = { - "cohere.command", - "cohere.command-light", - "meta.llama-2-70b-chat", - "cohere.command-r-16k", - "cohere.command-r-plus", - 
"meta.llama-3-70b-instruct", +JSON_TO_PYTHON_TYPES = { + "string": "str", + "number": "float", + "boolean": "bool", + "integer": "int", + "array": "List", + "object": "Dict", } +def _format_oci_tool_calls( + tool_calls: Optional[List[Any]] = None, +) -> List[Dict]: + """ + Formats an OCI GenAI API response into the tool call format used in LlamaIndex. + Handles both dictionary and object formats. + """ + if not tool_calls: + return [] + + formatted_tool_calls = [] + for tool_call in tool_calls: + # Handle both object and dict formats + if isinstance(tool_call, dict): + name = tool_call.get("name", tool_call.get("functionName")) + parameters = tool_call.get( + "parameters", tool_call.get("functionParameters") + ) + else: + name = getattr(tool_call, "name", getattr(tool_call, "functionName", None)) + parameters = getattr( + tool_call, "parameters", getattr(tool_call, "functionParameters", None) + ) + + if name and parameters: + formatted_tool_calls.append( + { + "toolUseId": uuid.uuid4().hex[:], + "name": name, + "input": json.dumps(parameters) + if isinstance(parameters, dict) + else parameters, + } + ) + + return formatted_tool_calls + + def create_client(auth_type, auth_profile, service_endpoint): """OCI Gen AI client. @@ -166,10 +211,25 @@ def chat_response_to_text(self, response: Any) -> str: def chat_stream_to_text(self, event_data: Dict) -> str: ... + @abstractmethod + def chat_generation_info(self, response: Any) -> Dict[str, Any]: + ... + + @abstractmethod + def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]: + ... + @abstractmethod def messages_to_oci_params(self, messages: Sequence[ChatMessage]) -> Dict[str, Any]: ... + @abstractmethod + def convert_to_oci_tool( + self, + tool: Union[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + ) -> Dict[str, Any]: + ... + class CohereProvider(Provider): def __init__(self) -> None: @@ -184,10 +244,14 @@ def __init__(self) -> None: self.oci_completion_request = models.CohereLlmInferenceRequest self.oci_chat_request = models.CohereChatRequest + self.oci_tool = models.CohereTool + self.oci_tool_param = models.CohereParameterDefinition + self.oci_tool_result = models.CohereToolResult + self.oci_tool_call = models.CohereToolCall self.oci_chat_message = { "USER": models.CohereUserMessage, - "SYSTEM": models.CohereSystemMessage, "CHATBOT": models.CohereChatBotMessage, + "SYSTEM": models.CohereSystemMessage, "TOOL": models.CohereToolMessage, } self.chat_api_format = models.BaseChatRequest.API_FORMAT_COHERE @@ -202,31 +266,274 @@ def chat_response_to_text(self, response: Any) -> str: return response.data.chat_response.text def chat_stream_to_text(self, event_data: Dict) -> str: - if "text" in event_data and "finishReason" not in event_data: - return event_data["text"] - else: - return "" + if "text" in event_data: + if "finishedReason" in event_data or "toolCalls" in event_data: + return "" + else: + return event_data["text"] + return "" + + def chat_generation_info(self, response: Any) -> Dict[str, Any]: + generation_info: Dict[str, Any] = { + "finish_reason": response.data.chat_response.finish_reason, + "documents": response.data.chat_response.documents, + "citations": response.data.chat_response.citations, + "search_queries": response.data.chat_response.search_queries, + "is_search_required": response.data.chat_response.is_search_required, + } + if response.data.chat_response.tool_calls: + # Only populate tool_calls when 1) present on the response and + # 2) has one or more calls. 
+ generation_info["tool_calls"] = _format_oci_tool_calls( + response.data.chat_response.tool_calls + ) + + return generation_info + + def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]: + """Extract generation info from a streaming chat response.""" + generation_info: Dict[str, Any] = { + "finish_reason": event_data.get("finishReason"), + "documents": event_data.get("documents", []), + "citations": event_data.get("citations", []), + "search_queries": event_data.get("searchQueries", []), + "is_search_required": event_data.get("isSearchRequired", False), + } + + # Handle tool calls if present + if "toolCalls" in event_data: + generation_info["tool_calls"] = _format_oci_tool_calls( + event_data["toolCalls"] + ) + + return {k: v for k, v in generation_info.items() if v is not None} def messages_to_oci_params(self, messages: Sequence[ChatMessage]) -> Dict[str, Any]: role_map = { "user": "USER", "system": "SYSTEM", - "chatbot": "CHATBOT", "assistant": "CHATBOT", "tool": "TOOL", + "function": "TOOL", + "chatbot": "CHATBOT", + "model": "CHATBOT", } - oci_chat_history = [ - self.oci_chat_message[role_map[msg.role]](message=msg.content) - for msg in messages[:-1] - ] + oci_chat_history = [] + + for msg in messages[:-1]: + role = role_map[msg.role.value] + + # Handle tool calls for AI/Assistant messages + if role == "CHATBOT" and "tool_calls" in msg.additional_kwargs: + tool_calls = [] + for tool_call in msg.additional_kwargs.get("tool_calls", []): + validate_tool_call(tool_call) + tool_calls.append( + self.oci_tool_call( + name=tool_call.get("name"), + parameters=json.loads(tool_call["input"]) + if isinstance(tool_call["input"], str) + else tool_call["input"], + ) + ) + + oci_chat_history.append( + self.oci_chat_message[role]( + message=msg.content if msg.content else " ", + tool_calls=tool_calls if tool_calls else None, + ) + ) + elif role == "TOOL": + # tool message only has tool results field and no message field + continue + else: + oci_chat_history.append( + self.oci_chat_message[role](message=msg.content or " ") + ) - return { - "message": messages[-1].content, + # Handling the current chat turn, especially the latest message + current_chat_turn_messages = [] + for message in messages[::-1]: + current_chat_turn_messages.append(message) + if message.role == MessageRole.USER: + break + current_chat_turn_messages = current_chat_turn_messages[::-1] + + oci_tool_results = [] + for message in current_chat_turn_messages: + if message.role == MessageRole.TOOL: + tool_message = message + previous_ai_msgs = [ + message + for message in current_chat_turn_messages + if message.role == MessageRole.ASSISTANT + and "tool_calls" in message.additional_kwargs + ] + if previous_ai_msgs: + previous_ai_msg = previous_ai_msgs[-1] + for li_tool_call in previous_ai_msg.additional_kwargs.get( + "tool_calls", [] + ): + validate_tool_call(li_tool_call) + if li_tool_call[ + "toolUseId" + ] == tool_message.additional_kwargs.get("tool_call_id"): + tool_result = self.oci_tool_result() + tool_result.call = self.oci_tool_call( + name=li_tool_call.get("name"), + parameters=json.loads(li_tool_call["input"]) + if isinstance(li_tool_call["input"], str) + else li_tool_call["input"], + ) + tool_result.outputs = [{"output": tool_message.content}] + oci_tool_results.append(tool_result) + + if not oci_tool_results: + oci_tool_results = None + + message_str = "" if oci_tool_results or not messages else messages[-1].content + + oci_params = { + "message": message_str, "chat_history": oci_chat_history, + 
"tool_results": oci_tool_results, "api_format": self.chat_api_format, } + return {k: v for k, v in oci_params.items() if v is not None} + + def convert_to_oci_tool( + self, + tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool], + ) -> CohereTool: + """ + Convert a Pydantic class, JSON schema dict, callable, or BaseTool to a CohereTool format for OCI. + + Args: + tool: The tool to convert, which can be a Pydantic class, a callable, or a JSON schema dictionary. + + Returns: + A CohereTool representing the tool in the OCI API format. + """ + if isinstance(tool, BaseTool): + # Extract tool name and description for BaseTool + tool_name, tool_description = getattr(tool, "name", None), getattr( + tool, "description", None + ) + if not tool_name or not tool_description: + tool_name = getattr(tool.metadata, "name", None) + if tool_fn := getattr(tool, "fn", None): + tool_description = tool_fn.__doc__ + if not tool_name: + tool_name = tool_fn.__name__ + else: + tool_description = getattr(tool.metadata, "description", None) + if not tool_name or not tool_description: + raise ValueError( + f"Tool {tool} does not have a name or description." + ) + + return self.oci_tool( + name=tool_name, + description=tool_description, + parameter_definitions={ + p_name: self.oci_tool_param( + type=JSON_TO_PYTHON_TYPES.get( + p_def.get("type"), p_def.get("type") + ), + description=p_def.get("description", ""), + is_required=p_name + in tool.metadata.get_parameters_dict().get("required", []), + ) + for p_name, p_def in tool.metadata.get_parameters_dict() + .get("properties", {}) + .items() + }, + ) + + elif isinstance(tool, dict): + # Ensure dict-based tools follow a standard schema format + if not all(k in tool for k in ("title", "description", "properties")): + raise ValueError( + "Unsupported dict type. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type." 
+ ) + return self.oci_tool( + name=tool.get("title"), + description=tool.get("description"), + parameter_definitions={ + p_name: self.oci_tool_param( + type=JSON_TO_PYTHON_TYPES.get( + p_def.get("type"), p_def.get("type") + ), + description=p_def.get("description", ""), + is_required=p_name in tool.get("required", []), + ) + for p_name, p_def in tool.get("properties", {}).items() + }, + ) + + elif isinstance(tool, type) and issubclass(tool, BaseModel): + # Handle Pydantic BaseModel tools + schema = tool.model_json_schema() + properties = schema.get("properties", {}) + return self.oci_tool( + name=schema.get("title", tool.__name__), + description=schema.get("description", tool.__name__), + parameter_definitions={ + p_name: self.oci_tool_param( + type=JSON_TO_PYTHON_TYPES.get( + p_def.get("type"), p_def.get("type") + ), + description=p_def.get("description", ""), + is_required=p_name in schema.get("required", []), + ) + for p_name, p_def in properties.items() + }, + ) + + elif callable(tool): + # Use inspect to extract callable signature and arguments + signature = inspect.signature(tool) + parameters = {} + for param_name, param in signature.parameters.items(): + param_type = ( + param.annotation if param.annotation != inspect._empty else "string" + ) + param_default = ( + param.default if param.default != inspect._empty else None + ) + + # Convert type to JSON schema type (or leave as default) + json_type = JSON_TO_PYTHON_TYPES.get( + param_type, + param_type.__name__ if isinstance(param_type, type) else "string", + ) + + parameters[param_name] = { + "type": json_type, + "description": f"Parameter: {param_name}", + "is_required": param_default is None, + } + + return self.oci_tool( + name=tool.__name__, + description=tool.__doc__ or f"Callable function: {tool.__name__}", + parameter_definitions={ + param_name: self.oci_tool_param( + type=param_data["type"], + description=param_data["description"], + is_required=param_data["is_required"], + ) + for param_name, param_data in parameters.items() + }, + ) + + else: + raise ValueError( + f"Unsupported tool type {type(tool)}. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type." 
+ ) + class MetaProvider(Provider): def __init__(self) -> None: @@ -264,6 +571,17 @@ def chat_stream_to_text(self, event_data: Dict) -> str: else: return "" + def chat_generation_info(self, response: Any) -> Dict[str, Any]: + return { + "finish_reason": response.data.chat_response.choices[0].finish_reason, + "time_created": str(response.data.chat_response.time_created), + } + + def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]: + return { + "finish_reason": event_data["finishReason"], + } + def messages_to_oci_params(self, messages: Sequence[ChatMessage]) -> Dict[str, Any]: role_map = { "user": "USER", @@ -285,6 +603,14 @@ def messages_to_oci_params(self, messages: Sequence[ChatMessage]) -> Dict[str, A "top_k": -1, } + def convert_to_oci_tool( + self, + tool: Union[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + ) -> Dict[str, Any]: + raise NotImplementedError( + "Tools not supported for OCI Generative AI Meta models" + ) + PROVIDERS = { "cohere": CohereProvider(), @@ -327,3 +653,18 @@ def get_context_size(model: str, context_size: int = None) -> int: ) from e else: return context_size + + +def validate_tool_call(tool_call: Dict[str, Any]): + if ( + "input" not in tool_call + or "toolUseId" not in tool_call + or "name" not in tool_call + ): + raise ValueError("Invalid tool call.") + + +def force_single_tool_call(response: ChatResponse) -> None: + tool_calls = response.message.additional_kwargs.get("tool_calls", []) + if len(tool_calls) > 1: + response.message.additional_kwargs["tool_calls"] = [tool_calls[0]] diff --git a/llama-index-integrations/llms/llama-index-llms-oci-genai/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-oci-genai/pyproject.toml index 339543872b5bc..d17ae35716acc 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-genai/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-oci-genai/pyproject.toml @@ -31,11 +31,11 @@ license = "MIT" name = "llama-index-llms-oci-genai" packages = [{include = "llama_index/"}] readme = "README.md" -version = "0.3.0" +version = "0.4.0" [tool.poetry.dependencies] -python = ">=3.9,<4.0" -oci = "^2.128.0" +python = ">=3.8.1,<4.0" +oci = "^2.134.0" llama-index-core = "^0.12.0" [tool.poetry.group.dev.dependencies] diff --git a/llama-index-integrations/llms/llama-index-llms-oci-genai/tests/test_oci_genai.py b/llama-index-integrations/llms/llama-index-llms-oci-genai/tests/test_oci_genai.py index 32c9cc1cbba32..43b841e6403d3 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-genai/tests/test_oci_genai.py +++ b/llama-index-integrations/llms/llama-index-llms-oci-genai/tests/test_oci_genai.py @@ -1,4 +1,5 @@ -###Test OCI Generative AI LLM service +# Test OCI Generative AI LLM service + from unittest.mock import MagicMock from typing import Any @@ -7,6 +8,8 @@ from llama_index.llms.oci_genai import OCIGenAI from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole +from llama_index.core.tools import FunctionTool +import json class MockResponseDict(dict): @@ -14,9 +17,7 @@ def __getattr__(self, val) -> Any: # type: ignore[no-untyped-def] return self[val] -@pytest.mark.parametrize( - "test_model_id", ["cohere.command", "cohere.command-light", "meta.llama-2-70b-chat"] -) +@pytest.mark.parametrize("test_model_id", []) def test_llm_complete(monkeypatch: MonkeyPatch, test_model_id: str) -> None: """Test valid completion call to OCI Generative AI LLM service.""" oci_gen_ai_client = MagicMock() @@ -79,7 +80,13 @@ def mocked_response(*args): # 
type: ignore[no-untyped-def] @pytest.mark.parametrize( - "test_model_id", ["cohere.command-r-16k", "meta.llama-3-70b-instruct"] + "test_model_id", + [ + "cohere.command-r-16k", + "cohere.command-r-plus", + "meta.llama-3-70b-instruct", + "meta.llama-3.1-70b-instruct", + ], ) def test_llm_chat(monkeypatch: MonkeyPatch, test_model_id: str) -> None: """Test valid chat call to OCI Generative AI LLM service.""" @@ -100,10 +107,20 @@ def mocked_response(*args): # type: ignore[no-untyped-def] "chat_response": MockResponseDict( { "text": response_text, + "finish_reason": "stop", + "documents": [], + "citations": [], + "search_queries": [], + "is_search_required": False, + "tool_calls": None, } - ) + ), + "model_id": "cohere.command-r-16k", + "model_version": "1.0", } ), + "request_id": "req-1234567890", + "headers": {"content-length": "1234"}, } ) elif provider == "MetaProvider": @@ -127,14 +144,20 @@ def mocked_response(*args): # type: ignore[no-untyped-def] ) ] } - ) + ), + "finish_reason": "stop", } ) - ] + ], + "time_created": "2024-11-03T12:00:00Z", } - ) + ), + "model_id": "meta.llama-3-70b-instruct", + "model_version": "1.0", } ), + "request_id": "req-0987654321", + "headers": {"content-length": "1234"}, } ) return response @@ -145,12 +168,161 @@ def mocked_response(*args): # type: ignore[no-untyped-def] ChatMessage(role="user", content="User message"), ] + # For Meta provider, we expect fewer fields in additional_kwargs + if provider == "MetaProvider": + additional_kwargs = { + "finish_reason": "stop", + "time_created": "2024-11-03T12:00:00Z", + } + else: + additional_kwargs = { + "finish_reason": "stop", + "documents": [], + "citations": [], + "search_queries": [], + "is_search_required": False, + } + expected = ChatResponse( message=ChatMessage( - role=MessageRole.ASSISTANT, content="Assistant chat reply." + role=MessageRole.ASSISTANT, + content="Assistant chat reply.", + additional_kwargs=additional_kwargs, ), - raw=llm._client.chat.__dict__, + raw={}, # Mocked raw data + additional_kwargs={ + "model_id": test_model_id, + "model_version": "1.0", + "request_id": "req-1234567890" + if test_model_id == "cohere.command-r-16k" + else "req-0987654321", + "content-length": "1234", + }, ) actual = llm.chat(messages, temperature=0.2) - assert actual == expected + assert actual.message.content == expected.message.content + + +@pytest.mark.parametrize( + "test_model_id", ["cohere.command-r-16k", "cohere.command-r-plus"] +) +def test_llm_chat_with_tools(monkeypatch: MonkeyPatch, test_model_id: str) -> None: + """Test chat_with_tools call to OCI Generative AI LLM service with tool calling.""" + oci_gen_ai_client = MagicMock() + llm = OCIGenAI(model=test_model_id, client=oci_gen_ai_client) + + provider = llm._provider.__class__.__name__ + + def mock_tool_function(param1: str) -> str: + """Mock tool function that takes a string parameter.""" + return f"Mock tool function called with {param1}" + + # Create proper FunctionTool + mock_tool = FunctionTool.from_defaults(fn=mock_tool_function) + tools = [mock_tool] + + messages = [ + ChatMessage(role="user", content="User message"), + ] + + # Mock the client response + def mocked_response(*args, **kwargs): + response_text = "Assistant chat reply." 
+ tool_calls = [ + MockResponseDict( + { + "name": "mock_tool_function", + "parameters": {"param1": "test"}, + } + ) + ] + response = None + if provider == "CohereProvider": + response = MockResponseDict( + { + "status": 200, + "data": MockResponseDict( + { + "chat_response": MockResponseDict( + { + "text": response_text, + "finish_reason": "stop", + "documents": [], + "citations": [], + "search_queries": [], + "is_search_required": False, + "tool_calls": tool_calls, + } + ), + "model_id": test_model_id, + "model_version": "1.0", + } + ), + "request_id": "req-1234567890", + "headers": {"content-length": "1234"}, + } + ) + else: + # MetaProvider does not support tools + raise NotImplementedError("Tools not supported for this provider.") + return response + + monkeypatch.setattr(llm._client, "chat", mocked_response) + + actual_response = llm.chat( + messages=messages, + tools=tools, + ) + + # Expected response structure + expected_tool_calls = [ + { + "name": "mock_tool_function", + "toolUseId": actual_response.message.additional_kwargs["tool_calls"][0][ + "toolUseId" + ], + "input": json.dumps({"param1": "test"}), + } + ] + + expected_response = ChatResponse( + message=ChatMessage( + role=MessageRole.ASSISTANT, + content="Assistant chat reply.", + additional_kwargs={ + "finish_reason": "stop", + "documents": [], + "citations": [], + "search_queries": [], + "is_search_required": False, + "tool_calls": expected_tool_calls, + }, + ), + raw={}, + ) + + # Compare everything except the toolUseId which is randomly generated + assert actual_response.message.role == expected_response.message.role + assert actual_response.message.content == expected_response.message.content + + actual_kwargs = actual_response.message.additional_kwargs + expected_kwargs = expected_response.message.additional_kwargs + + # Check all non-tool_calls fields + for key in [k for k in expected_kwargs if k != "tool_calls"]: + assert actual_kwargs[key] == expected_kwargs[key] + + # Check tool calls separately + actual_tool_calls = actual_kwargs["tool_calls"] + assert len(actual_tool_calls) == len(expected_tool_calls) + + for actual_tc, expected_tc in zip(actual_tool_calls, expected_tool_calls): + assert actual_tc["name"] == expected_tc["name"] + assert actual_tc["input"] == expected_tc["input"] + assert "toolUseId" in actual_tc + assert isinstance(actual_tc["toolUseId"], str) + assert len(actual_tc["toolUseId"]) > 0 + + # Check additional_kwargs + assert actual_response.additional_kwargs == expected_response.additional_kwargs
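
Usage sketch (not part of the patch): the new `messages_to_oci_params` wiring lets tool results flow back to the model — a `MessageRole.TOOL` message whose `additional_kwargs["tool_call_id"]` matches the `toolUseId` of an earlier assistant tool call is converted into a `CohereToolResult`. Below is a minimal round-trip example built only on APIs added in this patch (`chat(..., tools=...)` and `get_tool_calls_from_response`); the model ID, endpoint, and compartment OCID are placeholders, and the actual replies depend on the deployed Cohere model.

from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.tools import FunctionTool
from llama_index.llms.oci_genai import OCIGenAI


def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


add_tool = FunctionTool.from_defaults(fn=add)

# Placeholder endpoint / OCIDs -- substitute your own values.
llm = OCIGenAI(
    model="cohere.command-r-plus",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="MY_OCID",
)

user_msg = ChatMessage(role=MessageRole.USER, content="What is 11 + 49?")
response = llm.chat([user_msg], tools=[add_tool])

# ToolSelection objects carry tool_id (the generated toolUseId), tool_name, and tool_kwargs.
tool_calls = llm.get_tool_calls_from_response(response, error_on_no_tool_call=False)

history = [user_msg, response.message]
for tool_call in tool_calls:
    # Run the requested function locally.
    result = add_tool.call(**tool_call.tool_kwargs)
    # Return the result as a TOOL message; tool_call_id must match the toolUseId
    # stored on the assistant message so messages_to_oci_params can pair them
    # into a CohereToolResult.
    history.append(
        ChatMessage(
            role=MessageRole.TOOL,
            content=str(result),
            additional_kwargs={"tool_call_id": tool_call.tool_id},
        )
    )

final = llm.chat(history, tools=[add_tool])
print(final.message.content)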
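
A second, lower-level sketch: `convert_to_oci_tool` on the Cohere provider accepts more than `BaseTool` instances — JSON-schema dicts, Pydantic `BaseModel` classes, and plain callables are also converted into `CohereTool` objects (the Meta provider raises `NotImplementedError`). The `GetWeather` model and `get_time` function below are made-up illustrations, and the snippet assumes the `oci` SDK is installed.

from llama_index.core.bridge.pydantic import BaseModel, Field
from llama_index.llms.oci_genai.utils import CohereProvider


class GetWeather(BaseModel):
    """Look up the current weather for a city."""

    city: str = Field(description="City name to look up")
    unit: str = Field(default="celsius", description="Temperature unit")


def get_time(timezone: str) -> str:
    """Return the current time in the given IANA timezone."""
    ...  # body is irrelevant here; only the signature is inspected


provider = CohereProvider()

# Pydantic class -> CohereTool via model_json_schema().
weather_tool = provider.convert_to_oci_tool(GetWeather)
print(weather_tool.name, sorted(weather_tool.parameter_definitions))

# Plain callable -> CohereTool via inspect.signature().
time_tool = provider.convert_to_oci_tool(get_time)
print(time_tool.name, sorted(time_tool.parameter_definitions))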