From 4e76768bc9b79b0e53f1fd24c8bbcddddc1a1cbd Mon Sep 17 00:00:00 2001 From: Reinier van der Leer Date: Mon, 3 Jun 2024 01:29:24 +0200 Subject: [PATCH] refactor(forge/llm): Create `BaseOpenAIProvider` -> deduplicate `GroqProvider` & `OpenAIProvider` implementation (#7178) - Add `_BaseOpenAIProvider`, `BaseOpenAIChatProvider`, and `BaseOpenAIEmbeddingProvider`, which implement the shared functionality of OpenAI-like providers, e.g. `GroqProvider` and `OpenAIProvider` - (Re)move as much code as possible from `GroqProvider` and `OpenAIProvider` by rebasing them on `BaseOpenAI(Chat|Embedding)Provider` Also: - Rename `get_available_models()` to `get_available_chat_models()` on `BaseChatModelProvider` - Add `get_available_models()` to `BaseModelProvider` - Add `get_available_embedding_models()` to `BaseEmbeddingModelProvider` - Move common `fix_failed_parse_tries` config attribute into base `ModelProviderConfiguration` --- autogpt/autogpt/app/configurator.py | 2 +- autogpt/tests/unit/test_config.py | 4 +- forge/forge/llm/providers/_openai_base.py | 514 ++++++++++++++++++++++ forge/forge/llm/providers/anthropic.py | 25 +- forge/forge/llm/providers/groq.py | 329 +------------- forge/forge/llm/providers/multi.py | 10 +- forge/forge/llm/providers/openai.py | 457 ++++--------------- forge/forge/llm/providers/schema.py | 31 +- 8 files changed, 640 insertions(+), 732 deletions(-) create mode 100644 forge/forge/llm/providers/_openai_base.py diff --git a/autogpt/autogpt/app/configurator.py b/autogpt/autogpt/app/configurator.py index 1b54405157cf..8e2f2f3eb88b 100644 --- a/autogpt/autogpt/app/configurator.py +++ b/autogpt/autogpt/app/configurator.py @@ -103,7 +103,7 @@ async def check_model( ) -> ModelName: """Check if model is available for use. If not, return gpt-3.5-turbo.""" multi_provider = MultiProvider() - models = await multi_provider.get_available_models() + models = await multi_provider.get_available_chat_models() if any(model_name == m.name for m in models): return model_name diff --git a/autogpt/tests/unit/test_config.py b/autogpt/tests/unit/test_config.py index 73c2537b7aa9..aa63b97664de 100644 --- a/autogpt/tests/unit/test_config.py +++ b/autogpt/tests/unit/test_config.py @@ -144,7 +144,7 @@ def test_azure_config(config_with_azure: Config) -> None: @pytest.mark.asyncio async def test_create_config_gpt4only(config: Config) -> None: with mock.patch( - "forge.llm.providers.multi.MultiProvider.get_available_models" + "forge.llm.providers.multi.MultiProvider.get_available_chat_models" ) as mock_get_models: mock_get_models.return_value = [ ChatModelInfo( @@ -164,7 +164,7 @@ async def test_create_config_gpt4only(config: Config) -> None: @pytest.mark.asyncio async def test_create_config_gpt3only(config: Config) -> None: with mock.patch( - "forge.llm.providers.multi.MultiProvider.get_available_models" + "forge.llm.providers.multi.MultiProvider.get_available_chat_models" ) as mock_get_models: mock_get_models.return_value = [ ChatModelInfo( diff --git a/forge/forge/llm/providers/_openai_base.py b/forge/forge/llm/providers/_openai_base.py new file mode 100644 index 000000000000..852876286b3a --- /dev/null +++ b/forge/forge/llm/providers/_openai_base.py @@ -0,0 +1,514 @@ +import logging +from typing import ( + Any, + Awaitable, + Callable, + ClassVar, + Mapping, + Optional, + ParamSpec, + Sequence, + TypeVar, + cast, +) + +import sentry_sdk +import tenacity +from openai._exceptions import APIConnectionError, APIStatusError +from openai.types import CreateEmbeddingResponse, EmbeddingCreateParams +from 
openai.types.chat import ( + ChatCompletion, + ChatCompletionAssistantMessageParam, + ChatCompletionMessage, + ChatCompletionMessageParam, + CompletionCreateParams, +) +from openai.types.shared_params import FunctionDefinition + +from forge.json.parsing import json_loads + +from .schema import ( + AssistantChatMessage, + AssistantFunctionCall, + AssistantToolCall, + BaseChatModelProvider, + BaseEmbeddingModelProvider, + BaseModelProvider, + ChatMessage, + ChatModelInfo, + ChatModelResponse, + CompletionModelFunction, + Embedding, + EmbeddingModelInfo, + EmbeddingModelResponse, + ModelProviderService, + _ModelName, + _ModelProviderSettings, +) +from .utils import validate_tool_calls + +_T = TypeVar("_T") +_P = ParamSpec("_P") + + +class _BaseOpenAIProvider(BaseModelProvider[_ModelName, _ModelProviderSettings]): + """Base class for LLM providers with OpenAI-like APIs""" + + MODELS: ClassVar[ + Mapping[_ModelName, ChatModelInfo[_ModelName] | EmbeddingModelInfo[_ModelName]] # type: ignore # noqa + ] + + def __init__( + self, + settings: Optional[_ModelProviderSettings] = None, + logger: Optional[logging.Logger] = None, + ): + if not getattr(self, "MODELS", None): + raise ValueError(f"{self.__class__.__name__}.MODELS is not set") + + if not settings: + settings = self.default_settings.copy(deep=True) + if not settings.credentials: + settings.credentials = self.default_settings.__fields__[ + "credentials" + ].type_.from_env() + + super(_BaseOpenAIProvider, self).__init__(settings=settings, logger=logger) + + if not getattr(self, "_client", None): + from openai import AsyncOpenAI + + self._client = AsyncOpenAI( + **self._credentials.get_api_access_kwargs() # type: ignore + ) + + async def get_available_models( + self, + ) -> Sequence[ChatModelInfo[_ModelName] | EmbeddingModelInfo[_ModelName]]: + _models = (await self._client.models.list()).data + return [ + self.MODELS[cast(_ModelName, m.id)] for m in _models if m.id in self.MODELS + ] + + def get_token_limit(self, model_name: _ModelName) -> int: + """Get the maximum number of input tokens for a given model""" + return self.MODELS[model_name].max_tokens + + def count_tokens(self, text: str, model_name: _ModelName) -> int: + return len(self.get_tokenizer(model_name).encode(text)) + + def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]: + return tenacity.retry( + retry=( + tenacity.retry_if_exception_type(APIConnectionError) + | tenacity.retry_if_exception( + lambda e: isinstance(e, APIStatusError) and e.status_code >= 500 + ) + ), + wait=tenacity.wait_exponential(), + stop=tenacity.stop_after_attempt(self._configuration.retries_per_request), + after=tenacity.after_log(self._logger, logging.DEBUG), + )(func) + + def __repr__(self): + return f"{self.__class__.__name__}()" + + +class BaseOpenAIChatProvider( + _BaseOpenAIProvider[_ModelName, _ModelProviderSettings], + BaseChatModelProvider[_ModelName, _ModelProviderSettings], +): + CHAT_MODELS: ClassVar[dict[_ModelName, ChatModelInfo[_ModelName]]] # type: ignore + + def __init__( + self, + settings: Optional[_ModelProviderSettings] = None, + logger: Optional[logging.Logger] = None, + ): + if not getattr(self, "CHAT_MODELS", None): + raise ValueError(f"{self.__class__.__name__}.CHAT_MODELS is not set") + + super(BaseOpenAIChatProvider, self).__init__(settings=settings, logger=logger) + + async def get_available_chat_models(self) -> Sequence[ChatModelInfo[_ModelName]]: + all_available_models = await self.get_available_models() + return [ + model + for model in all_available_models + 
if model.service == ModelProviderService.CHAT + ] + + def count_message_tokens( + self, + messages: ChatMessage | list[ChatMessage], + model_name: _ModelName, + ) -> int: + if isinstance(messages, ChatMessage): + messages = [messages] + return self.count_tokens( + "\n\n".join(f"{m.role.upper()}: {m.content}" for m in messages), model_name + ) + + async def create_chat_completion( + self, + model_prompt: list[ChatMessage], + model_name: _ModelName, + completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None, + functions: Optional[list[CompletionModelFunction]] = None, + max_output_tokens: Optional[int] = None, + prefill_response: str = "", + **kwargs, + ) -> ChatModelResponse[_T]: + """Create a chat completion using the API.""" + + ( + openai_messages, + completion_kwargs, + parse_kwargs, + ) = self._get_chat_completion_args( + prompt_messages=model_prompt, + model=model_name, + functions=functions, + max_output_tokens=max_output_tokens, + **kwargs, + ) + + total_cost = 0.0 + attempts = 0 + while True: + completion_kwargs["messages"] = openai_messages + _response, _cost, t_input, t_output = await self._create_chat_completion( + model=model_name, + completion_kwargs=completion_kwargs, + ) + total_cost += _cost + + # If parsing the response fails, append the error to the prompt, and let the + # LLM fix its mistake(s). + attempts += 1 + parse_errors: list[Exception] = [] + + _assistant_msg = _response.choices[0].message + + tool_calls, _errors = self._parse_assistant_tool_calls( + _assistant_msg, **parse_kwargs + ) + parse_errors += _errors + + # Validate tool calls + if not parse_errors and tool_calls and functions: + parse_errors += validate_tool_calls(tool_calls, functions) + + assistant_msg = AssistantChatMessage( + content=_assistant_msg.content or "", + tool_calls=tool_calls or None, + ) + + parsed_result: _T = None # type: ignore + if not parse_errors: + try: + parsed_result = completion_parser(assistant_msg) + except Exception as e: + parse_errors.append(e) + + if not parse_errors: + if attempts > 1: + self._logger.debug( + f"Total cost for {attempts} attempts: ${round(total_cost, 5)}" + ) + + return ChatModelResponse( + response=AssistantChatMessage( + content=_assistant_msg.content or "", + tool_calls=tool_calls or None, + ), + parsed_result=parsed_result, + model_info=self.CHAT_MODELS[model_name], + prompt_tokens_used=t_input, + completion_tokens_used=t_output, + ) + + else: + self._logger.debug( + f"Parsing failed on response: '''{_assistant_msg}'''" + ) + parse_errors_fmt = "\n\n".join( + f"{e.__class__.__name__}: {e}" for e in parse_errors + ) + self._logger.warning( + f"Parsing attempt #{attempts} failed: {parse_errors_fmt}" + ) + for e in parse_errors: + sentry_sdk.capture_exception( + error=e, + extras={"assistant_msg": _assistant_msg, "i_attempt": attempts}, + ) + + if attempts < self._configuration.fix_failed_parse_tries: + openai_messages.append( + cast( + ChatCompletionAssistantMessageParam, + _assistant_msg.dict(exclude_none=True), + ) + ) + openai_messages.append( + { + "role": "system", + "content": ( + f"ERROR PARSING YOUR RESPONSE:\n\n{parse_errors_fmt}" + ), + } + ) + continue + else: + raise parse_errors[0] + + def _get_chat_completion_args( + self, + prompt_messages: list[ChatMessage], + model: _ModelName, + functions: Optional[list[CompletionModelFunction]] = None, + max_output_tokens: Optional[int] = None, + **kwargs, + ) -> tuple[ + list[ChatCompletionMessageParam], CompletionCreateParams, dict[str, Any] + ]: + """Prepare keyword arguments for a 
chat completion API call + + Args: + prompt_messages: List of ChatMessages + model: The model to use + functions (optional): List of functions available to the LLM + max_output_tokens (optional): Maximum number of tokens to generate + + Returns: + list[ChatCompletionMessageParam]: Prompt messages for the API call + CompletionCreateParams: Mapping of other kwargs for the API call + Mapping[str, Any]: Any keyword arguments to pass on to the completion parser + """ + kwargs = cast(CompletionCreateParams, kwargs) + + if max_output_tokens: + kwargs["max_tokens"] = max_output_tokens + + if functions: + kwargs["tools"] = [ # pyright: ignore - it fails to infer the dict type + {"type": "function", "function": format_function_def_for_openai(f)} + for f in functions + ] + if len(functions) == 1: + # force the model to call the only specified function + kwargs["tool_choice"] = { # pyright: ignore - type inference failure + "type": "function", + "function": {"name": functions[0].name}, + } + + if extra_headers := self._configuration.extra_request_headers: + # 'extra_headers' is not on CompletionCreateParams, but is on chat.create() + kwargs["extra_headers"] = kwargs.get("extra_headers", {}) # type: ignore + kwargs["extra_headers"].update(extra_headers.copy()) # type: ignore + + prepped_messages: list[ChatCompletionMessageParam] = [ + message.dict( # type: ignore + include={"role", "content", "tool_calls", "tool_call_id", "name"}, + exclude_none=True, + ) + for message in prompt_messages + ] + + if "messages" in kwargs: + prepped_messages += kwargs["messages"] + del kwargs["messages"] # type: ignore - messages are added back later + + return prepped_messages, kwargs, {} + + async def _create_chat_completion( + self, + model: _ModelName, + completion_kwargs: CompletionCreateParams, + ) -> tuple[ChatCompletion, float, int, int]: + """ + Create a chat completion using an OpenAI-like API with retry handling + + Params: + model: The model to use for the completion + completion_kwargs: All other arguments for the completion call + + Returns: + ChatCompletion: The chat completion response object + float: The cost ($) of this completion + int: Number of prompt tokens used + int: Number of completion tokens used + """ + completion_kwargs["model"] = completion_kwargs.get("model") or model + + @self._retry_api_request + async def _create_chat_completion_with_retry() -> ChatCompletion: + return await self._client.chat.completions.create( + **completion_kwargs, # type: ignore + ) + + completion = await _create_chat_completion_with_retry() + + if completion.usage: + prompt_tokens_used = completion.usage.prompt_tokens + completion_tokens_used = completion.usage.completion_tokens + else: + prompt_tokens_used = completion_tokens_used = 0 + + if self._budget: + cost = self._budget.update_usage_and_cost( + model_info=self.CHAT_MODELS[model], + input_tokens_used=prompt_tokens_used, + output_tokens_used=completion_tokens_used, + ) + else: + cost = 0 + + self._logger.debug( + f"{model} completion usage: {prompt_tokens_used} input, " + f"{completion_tokens_used} output - ${round(cost, 5)}" + ) + return completion, cost, prompt_tokens_used, completion_tokens_used + + def _parse_assistant_tool_calls( + self, assistant_message: ChatCompletionMessage, **kwargs + ) -> tuple[list[AssistantToolCall], list[Exception]]: + tool_calls: list[AssistantToolCall] = [] + parse_errors: list[Exception] = [] + + if assistant_message.tool_calls: + for _tc in assistant_message.tool_calls: + try: + parsed_arguments = 
json_loads(_tc.function.arguments) + except Exception as e: + err_message = ( + f"Decoding arguments for {_tc.function.name} failed: " + + str(e.args[0]) + ) + parse_errors.append( + type(e)(err_message, *e.args[1:]).with_traceback( + e.__traceback__ + ) + ) + continue + + tool_calls.append( + AssistantToolCall( + id=_tc.id, + type=_tc.type, + function=AssistantFunctionCall( + name=_tc.function.name, + arguments=parsed_arguments, + ), + ) + ) + + # If parsing of all tool calls succeeds in the end, we ignore any issues + if len(tool_calls) == len(assistant_message.tool_calls): + parse_errors = [] + + return tool_calls, parse_errors + + +class BaseOpenAIEmbeddingProvider( + _BaseOpenAIProvider[_ModelName, _ModelProviderSettings], + BaseEmbeddingModelProvider[_ModelName, _ModelProviderSettings], +): + EMBEDDING_MODELS: ClassVar[ + dict[_ModelName, EmbeddingModelInfo[_ModelName]] # type: ignore + ] + + def __init__( + self, + settings: Optional[_ModelProviderSettings] = None, + logger: Optional[logging.Logger] = None, + ): + if not getattr(self, "EMBEDDING_MODELS", None): + raise ValueError(f"{self.__class__.__name__}.EMBEDDING_MODELS is not set") + + super(BaseOpenAIEmbeddingProvider, self).__init__( + settings=settings, logger=logger + ) + + async def get_available_embedding_models( + self, + ) -> Sequence[EmbeddingModelInfo[_ModelName]]: + all_available_models = await self.get_available_models() + return [ + model + for model in all_available_models + if model.service == ModelProviderService.EMBEDDING + ] + + async def create_embedding( + self, + text: str, + model_name: _ModelName, + embedding_parser: Callable[[Embedding], Embedding], + **kwargs, + ) -> EmbeddingModelResponse: + """Create an embedding using an OpenAI-like API""" + embedding_kwargs = self._get_embedding_kwargs( + input=text, model=model_name, **kwargs + ) + response = await self._create_embedding(embedding_kwargs) + + return EmbeddingModelResponse( + embedding=embedding_parser(response.data[0].embedding), + model_info=self.EMBEDDING_MODELS[model_name], + prompt_tokens_used=response.usage.prompt_tokens, + ) + + def _get_embedding_kwargs( + self, input: str | list[str], model: _ModelName, **kwargs + ) -> EmbeddingCreateParams: + """Get kwargs for an embedding API call + + Params: + input: Text body or list of text bodies to create embedding(s) from + model: Embedding model to use + + Returns: + The kwargs for the embedding API call + """ + kwargs = cast(EmbeddingCreateParams, kwargs) + + kwargs["input"] = input + kwargs["model"] = model + + if extra_headers := self._configuration.extra_request_headers: + # 'extra_headers' is not on CompletionCreateParams, but is on embedding.create() # noqa + kwargs["extra_headers"] = kwargs.get("extra_headers", {}) # type: ignore + kwargs["extra_headers"].update(extra_headers.copy()) # type: ignore + + return kwargs + + def _create_embedding( + self, embedding_kwargs: EmbeddingCreateParams + ) -> Awaitable[CreateEmbeddingResponse]: + """Create an embedding using an OpenAI-like API with retry handling.""" + + @self._retry_api_request + async def _create_embedding_with_retry() -> CreateEmbeddingResponse: + return await self._client.embeddings.create(**embedding_kwargs) + + return _create_embedding_with_retry() + + +def format_function_def_for_openai(self: CompletionModelFunction) -> FunctionDefinition: + """Returns an OpenAI-consumable function definition""" + + return { + "name": self.name, + "description": self.description, + "parameters": { + "type": "object", + "properties": { + name: 
param.to_dict() for name, param in self.parameters.items() + }, + "required": [ + name for name, param in self.parameters.items() if param.required + ], + }, + } diff --git a/forge/forge/llm/providers/anthropic.py b/forge/forge/llm/providers/anthropic.py index 4da5ed070cb1..2f8d21571dbf 100644 --- a/forge/forge/llm/providers/anthropic.py +++ b/forge/forge/llm/providers/anthropic.py @@ -2,7 +2,7 @@ import enum import logging -from typing import TYPE_CHECKING, Any, Callable, Optional, ParamSpec, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Optional, ParamSpec, Sequence, TypeVar import sentry_sdk import tenacity @@ -10,7 +10,9 @@ from anthropic import APIConnectionError, APIStatusError from pydantic import SecretStr -from forge.llm.providers.schema import ( +from forge.models.config import UserConfigurable + +from .schema import ( AssistantChatMessage, AssistantFunctionCall, AssistantToolCall, @@ -27,8 +29,6 @@ ModelTokenizer, ToolResultMessage, ) -from forge.models.config import UserConfigurable - from .utils import validate_tool_calls if TYPE_CHECKING: @@ -77,10 +77,6 @@ class AnthropicModelName(str, enum.Enum): } -class AnthropicConfiguration(ModelProviderConfiguration): - fix_failed_parse_tries: int = UserConfigurable(3) - - class AnthropicCredentials(ModelProviderCredentials): """Credentials for Anthropic.""" @@ -101,7 +97,6 @@ def get_api_access_kwargs(self) -> dict[str, str]: class AnthropicSettings(ModelProviderSettings): - configuration: AnthropicConfiguration # type: ignore credentials: Optional[AnthropicCredentials] # type: ignore budget: ModelProviderBudget # type: ignore @@ -110,15 +105,12 @@ class AnthropicProvider(BaseChatModelProvider[AnthropicModelName, AnthropicSetti default_settings = AnthropicSettings( name="anthropic_provider", description="Provides access to Anthropic's API.", - configuration=AnthropicConfiguration( - retries_per_request=7, - ), + configuration=ModelProviderConfiguration(), credentials=None, budget=ModelProviderBudget(), ) _settings: AnthropicSettings - _configuration: AnthropicConfiguration _credentials: AnthropicCredentials _budget: ModelProviderBudget @@ -140,7 +132,12 @@ def __init__( **self._credentials.get_api_access_kwargs() # type: ignore ) - async def get_available_models(self) -> list[ChatModelInfo[AnthropicModelName]]: + async def get_available_models(self) -> Sequence[ChatModelInfo[AnthropicModelName]]: + return await self.get_available_chat_models() + + async def get_available_chat_models( + self, + ) -> Sequence[ChatModelInfo[AnthropicModelName]]: return list(ANTHROPIC_CHAT_MODELS.values()) def get_token_limit(self, model_name: AnthropicModelName) -> int: diff --git a/forge/forge/llm/providers/groq.py b/forge/forge/llm/providers/groq.py index 70996c132e0c..dc9e77e0b8d4 100644 --- a/forge/forge/llm/providers/groq.py +++ b/forge/forge/llm/providers/groq.py @@ -2,24 +2,16 @@ import enum import logging -from typing import TYPE_CHECKING, Any, Callable, Optional, ParamSpec, TypeVar +from typing import Any, Optional -import sentry_sdk -import tenacity import tiktoken -from groq import APIConnectionError, APIStatusError from pydantic import SecretStr -from forge.json.parsing import json_loads -from forge.llm.providers.schema import ( - AssistantChatMessage, - AssistantFunctionCall, - AssistantToolCall, - BaseChatModelProvider, - ChatMessage, +from forge.models.config import UserConfigurable + +from ._openai_base import BaseOpenAIChatProvider +from .schema import ( ChatModelInfo, - ChatModelResponse, - CompletionModelFunction, 
ModelProviderBudget, ModelProviderConfiguration, ModelProviderCredentials, @@ -27,18 +19,6 @@ ModelProviderSettings, ModelTokenizer, ) -from forge.models.config import UserConfigurable - -from .openai import format_function_def_for_openai -from .utils import validate_tool_calls - -if TYPE_CHECKING: - from groq.types.chat import ChatCompletion, CompletionCreateParams - from groq.types.chat.chat_completion_message import ChatCompletionMessage - from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam - -_T = TypeVar("_T") -_P = ParamSpec("_P") class GroqModelName(str, enum.Enum): @@ -87,10 +67,6 @@ class GroqModelName(str, enum.Enum): } -class GroqConfiguration(ModelProviderConfiguration): - fix_failed_parse_tries: int = UserConfigurable(3) - - class GroqCredentials(ModelProviderCredentials): """Credentials for Groq.""" @@ -111,24 +87,24 @@ def get_api_access_kwargs(self) -> dict[str, str]: class GroqSettings(ModelProviderSettings): - configuration: GroqConfiguration # type: ignore credentials: Optional[GroqCredentials] # type: ignore budget: ModelProviderBudget # type: ignore -class GroqProvider(BaseChatModelProvider[GroqModelName, GroqSettings]): +class GroqProvider(BaseOpenAIChatProvider[GroqModelName, GroqSettings]): + CHAT_MODELS = GROQ_CHAT_MODELS + MODELS = CHAT_MODELS + default_settings = GroqSettings( name="groq_provider", description="Provides access to Groq's API.", - configuration=GroqConfiguration( - retries_per_request=7, - ), + configuration=ModelProviderConfiguration(), credentials=None, budget=ModelProviderBudget(), ) _settings: GroqSettings - _configuration: GroqConfiguration + _configuration: ModelProviderConfiguration _credentials: GroqCredentials _budget: ModelProviderBudget @@ -137,11 +113,6 @@ def __init__( settings: Optional[GroqSettings] = None, logger: Optional[logging.Logger] = None, ): - if not settings: - settings = self.default_settings.copy(deep=True) - if not settings.credentials: - settings.credentials = GroqCredentials.from_env() - super(GroqProvider, self).__init__(settings=settings, logger=logger) from groq import AsyncGroq @@ -150,284 +121,6 @@ def __init__( **self._credentials.get_api_access_kwargs() # type: ignore ) - async def get_available_models(self) -> list[ChatModelInfo[GroqModelName]]: - _models = (await self._client.models.list()).data - return [GROQ_CHAT_MODELS[m.id] for m in _models if m.id in GROQ_CHAT_MODELS] - - def get_token_limit(self, model_name: GroqModelName) -> int: - """Get the token limit for a given model.""" - return GROQ_CHAT_MODELS[model_name].max_tokens - def get_tokenizer(self, model_name: GroqModelName) -> ModelTokenizer[Any]: # HACK: No official tokenizer is available for Groq return tiktoken.encoding_for_model("gpt-3.5-turbo") - - def count_tokens(self, text: str, model_name: GroqModelName) -> int: - return len(self.get_tokenizer(model_name).encode(text)) - - def count_message_tokens( - self, - messages: ChatMessage | list[ChatMessage], - model_name: GroqModelName, - ) -> int: - if isinstance(messages, ChatMessage): - messages = [messages] - # HACK: No official tokenizer (for text or messages) is available for Groq. - # Token overhead of messages is unknown and may be inaccurate. 
- return self.count_tokens( - "\n\n".join(f"{m.role.upper()}: {m.content}" for m in messages), model_name - ) - - async def create_chat_completion( - self, - model_prompt: list[ChatMessage], - model_name: GroqModelName, - completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None, - functions: Optional[list[CompletionModelFunction]] = None, - max_output_tokens: Optional[int] = None, - prefill_response: str = "", - **kwargs, - ) -> ChatModelResponse[_T]: - """Create a completion using the Groq API.""" - groq_messages, completion_kwargs = self._get_chat_completion_args( - prompt_messages=model_prompt, - functions=functions, - max_output_tokens=max_output_tokens, - **kwargs, - ) - - total_cost = 0.0 - attempts = 0 - while True: - completion_kwargs["messages"] = groq_messages.copy() - _response, _cost, t_input, t_output = await self._create_chat_completion( - model=model_name, - completion_kwargs=completion_kwargs, - ) - total_cost += _cost - - # If parsing the response fails, append the error to the prompt, and let the - # LLM fix its mistake(s). - attempts += 1 - parse_errors: list[Exception] = [] - - _assistant_msg = _response.choices[0].message - - tool_calls, _errors = self._parse_assistant_tool_calls(_assistant_msg) - parse_errors += _errors - - # Validate tool calls - if not parse_errors and tool_calls and functions: - parse_errors += validate_tool_calls(tool_calls, functions) - - assistant_msg = AssistantChatMessage( - content=_assistant_msg.content or "", - tool_calls=tool_calls or None, - ) - - parsed_result: _T = None # type: ignore - if not parse_errors: - try: - parsed_result = completion_parser(assistant_msg) - except Exception as e: - parse_errors.append(e) - - if not parse_errors: - if attempts > 1: - self._logger.debug( - f"Total cost for {attempts} attempts: ${round(total_cost, 5)}" - ) - - return ChatModelResponse( - response=AssistantChatMessage( - content=_assistant_msg.content or "", - tool_calls=tool_calls or None, - ), - parsed_result=parsed_result, - model_info=GROQ_CHAT_MODELS[model_name], - prompt_tokens_used=t_input, - completion_tokens_used=t_output, - ) - - else: - self._logger.debug( - f"Parsing failed on response: '''{_assistant_msg}'''" - ) - parse_errors_fmt = "\n\n".join( - f"{e.__class__.__name__}: {e}" for e in parse_errors - ) - self._logger.warning( - f"Parsing attempt #{attempts} failed: {parse_errors_fmt}" - ) - for e in parse_errors: - sentry_sdk.capture_exception( - error=e, - extras={"assistant_msg": _assistant_msg, "i_attempt": attempts}, - ) - - if attempts < self._configuration.fix_failed_parse_tries: - groq_messages.append( - _assistant_msg.dict(exclude_none=True) # type: ignore - ) - groq_messages.append( - { - "role": "system", - "content": ( - f"ERROR PARSING YOUR RESPONSE:\n\n{parse_errors_fmt}" - ), - } - ) - continue - else: - raise parse_errors[0] - - def _get_chat_completion_args( - self, - prompt_messages: list[ChatMessage], - functions: Optional[list[CompletionModelFunction]] = None, - max_output_tokens: Optional[int] = None, - **kwargs, # type: ignore - ) -> tuple[list[ChatCompletionMessageParam], CompletionCreateParams]: - """Prepare chat completion arguments and keyword arguments for API call. - - Args: - model_prompt: List of ChatMessages. - functions: Optional list of functions available to the LLM. - kwargs: Additional keyword arguments. 
- - Returns: - list[ChatCompletionMessageParam]: Prompt messages for the OpenAI call - dict[str, Any]: Any other kwargs for the OpenAI call - """ - kwargs: CompletionCreateParams = kwargs # type: ignore - if max_output_tokens: - kwargs["max_tokens"] = max_output_tokens - - if functions: - kwargs["tools"] = [ - {"type": "function", "function": format_function_def_for_openai(f)} - for f in functions - ] - if len(functions) == 1: - # force the model to call the only specified function - kwargs["tool_choice"] = { - "type": "function", - "function": {"name": functions[0].name}, - } - - if extra_headers := self._configuration.extra_request_headers: - # 'extra_headers' is not on CompletionCreateParams, but is on chat.create() - kwargs["extra_headers"] = kwargs.get("extra_headers", {}) # type: ignore - kwargs["extra_headers"].update(extra_headers.copy()) # type: ignore - - groq_messages: list[ChatCompletionMessageParam] = [ - message.dict( # type: ignore - include={"role", "content", "tool_calls", "tool_call_id", "name"}, - exclude_none=True, - ) - for message in prompt_messages - ] - - if "messages" in kwargs: - groq_messages += kwargs["messages"] - del kwargs["messages"] # type: ignore - messages are added back later - - return groq_messages, kwargs - - async def _create_chat_completion( - self, model: GroqModelName, completion_kwargs: CompletionCreateParams - ) -> tuple[ChatCompletion, float, int, int]: - """ - Create a chat completion using the Groq API with retry handling. - - Params: - completion_kwargs: Keyword arguments for an Groq Messages API call - - Returns: - Message: The message completion object - float: The cost ($) of this completion - int: Number of input tokens used - int: Number of output tokens used - """ - - @self._retry_api_request - async def _create_chat_completion_with_retry() -> ChatCompletion: - return await self._client.chat.completions.create( - model=model, **completion_kwargs # type: ignore - ) - - response = await _create_chat_completion_with_retry() - - if not response.usage: - self._logger.warning( - "Groq chat completion response does not contain a usage field", - response, - ) - return response, 0, 0, 0 - else: - cost = self._budget.update_usage_and_cost( - model_info=GROQ_CHAT_MODELS[model], - input_tokens_used=response.usage.prompt_tokens, - output_tokens_used=response.usage.completion_tokens, - ) - return ( - response, - cost, - response.usage.prompt_tokens, - response.usage.completion_tokens, - ) - - def _parse_assistant_tool_calls( - self, assistant_message: ChatCompletionMessage, compat_mode: bool = False - ): - tool_calls: list[AssistantToolCall] = [] - parse_errors: list[Exception] = [] - - if assistant_message.tool_calls: - for _tc in assistant_message.tool_calls: - try: - parsed_arguments = json_loads(_tc.function.arguments) - except Exception as e: - err_message = ( - f"Decoding arguments for {_tc.function.name} failed: " - + str(e.args[0]) - ) - parse_errors.append( - type(e)(err_message, *e.args[1:]).with_traceback( - e.__traceback__ - ) - ) - continue - - tool_calls.append( - AssistantToolCall( - id=_tc.id, - type=_tc.type, - function=AssistantFunctionCall( - name=_tc.function.name, - arguments=parsed_arguments, - ), - ) - ) - - # If parsing of all tool calls succeeds in the end, we ignore any issues - if len(tool_calls) == len(assistant_message.tool_calls): - parse_errors = [] - - return tool_calls, parse_errors - - def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]: - return tenacity.retry( - retry=( - 
tenacity.retry_if_exception_type(APIConnectionError) - | tenacity.retry_if_exception( - lambda e: isinstance(e, APIStatusError) and e.status_code >= 500 - ) - ), - wait=tenacity.wait_exponential(), - stop=tenacity.stop_after_attempt(self._configuration.retries_per_request), - after=tenacity.after_log(self._logger, logging.DEBUG), - )(func) - - def __repr__(self): - return "GroqProvider()" diff --git a/forge/forge/llm/providers/multi.py b/forge/forge/llm/providers/multi.py index 47892adb5b1f..93784cd2eae5 100644 --- a/forge/forge/llm/providers/multi.py +++ b/forge/forge/llm/providers/multi.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import Any, Callable, Iterator, Optional, TypeVar +from typing import Any, Callable, Iterator, Optional, Sequence, TypeVar from pydantic import ValidationError @@ -56,10 +56,14 @@ def __init__( self._provider_instances = {} - async def get_available_models(self) -> list[ChatModelInfo[ModelName]]: + async def get_available_models(self) -> Sequence[ChatModelInfo[ModelName]]: + # TODO: support embeddings + return await self.get_available_chat_models() + + async def get_available_chat_models(self) -> Sequence[ChatModelInfo[ModelName]]: models = [] for provider in self.get_available_providers(): - models.extend(await provider.get_available_models()) + models.extend(await provider.get_available_chat_models()) return models def get_token_limit(self, model_name: ModelName) -> int: diff --git a/forge/forge/llm/providers/openai.py b/forge/forge/llm/providers/openai.py index a4dc2cacf32b..024f1286a8d2 100644 --- a/forge/forge/llm/providers/openai.py +++ b/forge/forge/llm/providers/openai.py @@ -2,47 +2,33 @@ import logging import os from pathlib import Path -from typing import ( - Any, - Callable, - Coroutine, - Iterator, - Optional, - ParamSpec, - TypeVar, - cast, -) +from typing import Any, Callable, Iterator, Mapping, Optional, ParamSpec, TypeVar, cast -import sentry_sdk import tenacity import tiktoken import yaml from openai._exceptions import APIStatusError, RateLimitError -from openai.types import CreateEmbeddingResponse +from openai.types import EmbeddingCreateParams from openai.types.chat import ( - ChatCompletion, - ChatCompletionAssistantMessageParam, ChatCompletionMessage, ChatCompletionMessageParam, + CompletionCreateParams, ) -from openai.types.shared_params import FunctionDefinition from pydantic import SecretStr from forge.json.parsing import json_loads -from forge.llm.providers.schema import ( - AssistantChatMessage, - AssistantFunctionCall, +from forge.models.config import UserConfigurable +from forge.models.json_schema import JSONSchema + +from ._openai_base import BaseOpenAIChatProvider, BaseOpenAIEmbeddingProvider +from .schema import ( AssistantToolCall, AssistantToolCallDict, - BaseChatModelProvider, - BaseEmbeddingModelProvider, ChatMessage, ChatModelInfo, - ChatModelResponse, CompletionModelFunction, Embedding, EmbeddingModelInfo, - EmbeddingModelResponse, ModelProviderBudget, ModelProviderConfiguration, ModelProviderCredentials, @@ -50,10 +36,6 @@ ModelProviderSettings, ModelTokenizer, ) -from forge.models.config import UserConfigurable -from forge.models.json_schema import JSONSchema - -from .utils import validate_tool_calls _T = TypeVar("_T") _P = ParamSpec("_P") @@ -221,16 +203,15 @@ class OpenAIModelName(str, enum.Enum): copy_info.has_function_call_api = False -OPEN_AI_MODELS = { +OPEN_AI_MODELS: Mapping[ + OpenAIModelName, + ChatModelInfo[OpenAIModelName] | EmbeddingModelInfo[OpenAIModelName], +] = { 
**OPEN_AI_CHAT_MODELS, **OPEN_AI_EMBEDDING_MODELS, } -class OpenAIConfiguration(ModelProviderConfiguration): - fix_failed_parse_tries: int = UserConfigurable(3) - - class OpenAICredentials(ModelProviderCredentials): """Credentials for OpenAI.""" @@ -308,27 +289,28 @@ def _get_azure_access_kwargs(self, model: str) -> dict[str, str]: class OpenAISettings(ModelProviderSettings): - configuration: OpenAIConfiguration # type: ignore credentials: Optional[OpenAICredentials] # type: ignore budget: ModelProviderBudget # type: ignore class OpenAIProvider( - BaseChatModelProvider[OpenAIModelName, OpenAISettings], - BaseEmbeddingModelProvider[OpenAIModelName, OpenAISettings], + BaseOpenAIChatProvider[OpenAIModelName, OpenAISettings], + BaseOpenAIEmbeddingProvider[OpenAIModelName, OpenAISettings], ): + MODELS = OPEN_AI_MODELS + CHAT_MODELS = OPEN_AI_CHAT_MODELS + EMBEDDING_MODELS = OPEN_AI_EMBEDDING_MODELS + default_settings = OpenAISettings( name="openai_provider", description="Provides access to OpenAI's API.", - configuration=OpenAIConfiguration( - retries_per_request=7, - ), + configuration=ModelProviderConfiguration(), credentials=None, budget=ModelProviderBudget(), ) _settings: OpenAISettings - _configuration: OpenAIConfiguration + _configuration: ModelProviderConfiguration _credentials: OpenAICredentials _budget: ModelProviderBudget @@ -337,11 +319,6 @@ def __init__( settings: Optional[OpenAISettings] = None, logger: Optional[logging.Logger] = None, ): - if not settings: - settings = self.default_settings.copy(deep=True) - if not settings.credentials: - settings.credentials = OpenAICredentials.from_env() - super(OpenAIProvider, self).__init__(settings=settings, logger=logger) if self._credentials.api_type == SecretStr("azure"): @@ -359,21 +336,9 @@ def __init__( **self._credentials.get_api_access_kwargs() # type: ignore ) - async def get_available_models(self) -> list[ChatModelInfo[OpenAIModelName]]: - _models = (await self._client.models.list()).data - return [OPEN_AI_MODELS[m.id] for m in _models if m.id in OPEN_AI_MODELS] - - def get_token_limit(self, model_name: OpenAIModelName) -> int: - """Get the token limit for a given model.""" - return OPEN_AI_MODELS[model_name].max_tokens - def get_tokenizer(self, model_name: OpenAIModelName) -> ModelTokenizer[int]: return tiktoken.encoding_for_model(model_name) - def count_tokens(self, text: str, model_name: OpenAIModelName) -> int: - encoding = self.get_tokenizer(model_name) - return len(encoding.encode(text)) - def count_message_tokens( self, messages: ChatMessage | list[ChatMessage], @@ -387,338 +352,87 @@ def count_message_tokens( 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n ) tokens_per_name = -1 # if there's a name, the role is omitted - encoding_model = "gpt-3.5-turbo" + # TODO: check if this is still valid for gpt-4o elif model_name.startswith("gpt-4"): tokens_per_message = 3 tokens_per_name = 1 - encoding_model = "gpt-4" else: raise NotImplementedError( f"count_message_tokens() is not implemented for model {model_name}.\n" - " See https://github.com/openai/openai-python/blob/main/chatml.md for" - " information on how messages are converted to tokens." - ) - try: - encoding = tiktoken.encoding_for_model(encoding_model) - except KeyError: - logging.getLogger(__class__.__name__).warning( - f"Model {model_name} not found. Defaulting to cl100k_base encoding." 
+ "See https://github.com/openai/openai-python/blob/120d225b91a8453e15240a49fb1c6794d8119326/chatml.md " # noqa + "for information on how messages are converted to tokens." ) - encoding = tiktoken.get_encoding("cl100k_base") + tokenizer = self.get_tokenizer(model_name) num_tokens = 0 for message in messages: num_tokens += tokens_per_message for key, value in message.dict().items(): - num_tokens += len(encoding.encode(value)) + num_tokens += len(tokenizer.encode(value)) if key == "name": num_tokens += tokens_per_name num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> return num_tokens - async def create_chat_completion( - self, - model_prompt: list[ChatMessage], - model_name: OpenAIModelName, - completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None, - functions: Optional[list[CompletionModelFunction]] = None, - max_output_tokens: Optional[int] = None, - prefill_response: str = "", # not supported by OpenAI - **kwargs, - ) -> ChatModelResponse[_T]: - """Create a completion using the OpenAI API and parse it.""" - - openai_messages, completion_kwargs = self._get_chat_completion_args( - model_prompt=model_prompt, - model_name=model_name, - functions=functions, - max_tokens=max_output_tokens, - **kwargs, - ) - tool_calls_compat_mode = bool(functions and "tools" not in completion_kwargs) - - total_cost = 0.0 - attempts = 0 - while True: - _response, _cost, t_input, t_output = await self._create_chat_completion( - messages=openai_messages, - **completion_kwargs, - ) - total_cost += _cost - - # If parsing the response fails, append the error to the prompt, and let the - # LLM fix its mistake(s). - attempts += 1 - parse_errors: list[Exception] = [] - - _assistant_msg = _response.choices[0].message - - tool_calls, _errors = self._parse_assistant_tool_calls( - _assistant_msg, tool_calls_compat_mode - ) - parse_errors += _errors - - # Validate tool calls - if not parse_errors and tool_calls and functions: - parse_errors += validate_tool_calls(tool_calls, functions) - - assistant_msg = AssistantChatMessage( - content=_assistant_msg.content or "", - tool_calls=tool_calls or None, - ) - - parsed_result: _T = None # type: ignore - if not parse_errors: - try: - parsed_result = completion_parser(assistant_msg) - except Exception as e: - parse_errors.append(e) - - if not parse_errors: - if attempts > 1: - self._logger.debug( - f"Total cost for {attempts} attempts: ${round(total_cost, 5)}" - ) - - return ChatModelResponse( - response=AssistantChatMessage( - content=_assistant_msg.content or "", - tool_calls=tool_calls or None, - ), - parsed_result=parsed_result, - model_info=OPEN_AI_CHAT_MODELS[model_name], - prompt_tokens_used=t_input, - completion_tokens_used=t_output, - ) - - else: - self._logger.debug( - f"Parsing failed on response: '''{_assistant_msg}'''" - ) - parse_errors_fmt = "\n\n".join( - f"{e.__class__.__name__}: {e}" for e in parse_errors - ) - self._logger.warning( - f"Parsing attempt #{attempts} failed: {parse_errors_fmt}" - ) - for e in parse_errors: - sentry_sdk.capture_exception( - error=e, - extras={"assistant_msg": _assistant_msg, "i_attempt": attempts}, - ) - - if attempts < self._configuration.fix_failed_parse_tries: - openai_messages.append( - cast( - ChatCompletionAssistantMessageParam, - _assistant_msg.dict(exclude_none=True), - ) - ) - openai_messages.append( - { - "role": "system", - "content": ( - f"ERROR PARSING YOUR RESPONSE:\n\n{parse_errors_fmt}" - ), - } - ) - continue - else: - raise parse_errors[0] - - async def 
create_embedding( - self, - text: str, - model_name: OpenAIModelName, - embedding_parser: Callable[[Embedding], Embedding], - **kwargs, - ) -> EmbeddingModelResponse: - """Create an embedding using the OpenAI API.""" - embedding_kwargs = self._get_embedding_kwargs(model_name, **kwargs) - response = await self._create_embedding(text=text, **embedding_kwargs) - - response = EmbeddingModelResponse( - embedding=embedding_parser(response.data[0].embedding), - model_info=OPEN_AI_EMBEDDING_MODELS[model_name], - prompt_tokens_used=response.usage.prompt_tokens, - completion_tokens_used=0, - ) - self._budget.update_usage_and_cost( - model_info=response.model_info, - input_tokens_used=response.prompt_tokens_used, - ) - return response - def _get_chat_completion_args( self, - model_prompt: list[ChatMessage], - model_name: OpenAIModelName, + prompt_messages: list[ChatMessage], + model: OpenAIModelName, functions: Optional[list[CompletionModelFunction]] = None, + max_output_tokens: Optional[int] = None, **kwargs, - ) -> tuple[list[ChatCompletionMessageParam], dict[str, Any]]: - """Prepare chat completion arguments and keyword arguments for API call. + ) -> tuple[ + list[ChatCompletionMessageParam], CompletionCreateParams, dict[str, Any] + ]: + """Prepare keyword arguments for an OpenAI chat completion call Args: - model_prompt: List of ChatMessages. - model_name: The model to use. - functions: Optional list of functions available to the LLM. - kwargs: Additional keyword arguments. + prompt_messages: List of ChatMessages + model: The model to use + functions (optional): List of functions available to the LLM + max_output_tokens (optional): Maximum number of tokens to generate Returns: list[ChatCompletionMessageParam]: Prompt messages for the OpenAI call - dict[str, Any]: Any other kwargs for the OpenAI call + CompletionCreateParams: Mapping of other kwargs for the OpenAI call + Mapping[str, Any]: Any keyword arguments to pass on to the completion parser """ - kwargs.update(self._credentials.get_model_access_kwargs(model_name)) - + tools_compat_mode = False if functions: - if OPEN_AI_CHAT_MODELS[model_name].has_function_call_api: - kwargs["tools"] = [ - {"type": "function", "function": format_function_def_for_openai(f)} - for f in functions - ] - if len(functions) == 1: - # force the model to call the only specified function - kwargs["tool_choice"] = { - "type": "function", - "function": {"name": functions[0].name}, - } - else: + if not OPEN_AI_CHAT_MODELS[model].has_function_call_api: # Provide compatibility with older models - _functions_compat_fix_kwargs(functions, kwargs) - - if extra_headers := self._configuration.extra_request_headers: - kwargs["extra_headers"] = kwargs.get("extra_headers", {}) - kwargs["extra_headers"].update(extra_headers.copy()) - - if "messages" in kwargs: - model_prompt += kwargs["messages"] - del kwargs["messages"] - - openai_messages = [ - cast( - ChatCompletionMessageParam, - message.dict( - include={"role", "content", "tool_calls", "name"}, - exclude_none=True, - ), - ) - for message in model_prompt - ] - - return openai_messages, kwargs - - def _get_embedding_kwargs( - self, - model_name: OpenAIModelName, - **kwargs, - ) -> dict: - """Get kwargs for embedding API call. - - Args: - model: The model to use. - kwargs: Keyword arguments to override the default values. - - Returns: - The kwargs for the embedding API call. 
+ _functions_compat_fix_kwargs(functions, prompt_messages) + tools_compat_mode = True + functions = None - """ - kwargs.update(self._credentials.get_model_access_kwargs(model_name)) - - if extra_headers := self._configuration.extra_request_headers: - kwargs["extra_headers"] = kwargs.get("extra_headers", {}) - kwargs["extra_headers"].update(extra_headers.copy()) - - return kwargs - - async def _create_chat_completion( - self, - messages: list[ChatCompletionMessageParam], - model: OpenAIModelName, - *_, - **kwargs, - ) -> tuple[ChatCompletion, float, int, int]: - """ - Create a chat completion using the OpenAI API with retry handling. - - Params: - openai_messages: List of OpenAI-consumable message dict objects - model: The model to use for the completion - - Returns: - ChatCompletion: The chat completion response object - float: The cost ($) of this completion - int: Number of prompt tokens used - int: Number of completion tokens used - """ - - @self._retry_api_request - async def _create_chat_completion_with_retry( - messages: list[ChatCompletionMessageParam], **kwargs - ) -> ChatCompletion: - return await self._client.chat.completions.create( - messages=messages, # type: ignore - **kwargs, - ) - - completion = await _create_chat_completion_with_retry( - messages, model=model, **kwargs + openai_messages, kwargs, parse_kwargs = super()._get_chat_completion_args( + prompt_messages=prompt_messages, + model=model, + functions=functions, + max_output_tokens=max_output_tokens, + **kwargs, ) + kwargs.update(self._credentials.get_model_access_kwargs(model)) # type: ignore - if completion.usage: - prompt_tokens_used = completion.usage.prompt_tokens - completion_tokens_used = completion.usage.completion_tokens - else: - prompt_tokens_used = completion_tokens_used = 0 + if tools_compat_mode: + parse_kwargs["compat_mode"] = True - cost = self._budget.update_usage_and_cost( - model_info=OPEN_AI_CHAT_MODELS[model], - input_tokens_used=prompt_tokens_used, - output_tokens_used=completion_tokens_used, - ) - self._logger.debug( - f"Completion usage: {prompt_tokens_used} input, " - f"{completion_tokens_used} output - ${round(cost, 5)}" - ) - return completion, cost, prompt_tokens_used, completion_tokens_used + return openai_messages, kwargs, parse_kwargs def _parse_assistant_tool_calls( - self, assistant_message: ChatCompletionMessage, compat_mode: bool = False + self, + assistant_message: ChatCompletionMessage, + compat_mode: bool = False, + **kwargs, ) -> tuple[list[AssistantToolCall], list[Exception]]: tool_calls: list[AssistantToolCall] = [] parse_errors: list[Exception] = [] - if assistant_message.tool_calls: - for _tc in assistant_message.tool_calls: - try: - parsed_arguments = json_loads(_tc.function.arguments) - except Exception as e: - err_message = ( - f"Decoding arguments for {_tc.function.name} failed: " - + str(e.args[0]) - ) - parse_errors.append( - type(e)(err_message, *e.args[1:]).with_traceback( - e.__traceback__ - ) - ) - continue - - tool_calls.append( - AssistantToolCall( - id=_tc.id, - type=_tc.type, - function=AssistantFunctionCall( - name=_tc.function.name, - arguments=parsed_arguments, - ), - ) - ) - - # If parsing of all tool calls succeeds in the end, we ignore any issues - if len(tool_calls) == len(assistant_message.tool_calls): - parse_errors = [] - - elif compat_mode and assistant_message.content: + if not compat_mode: + return super()._parse_assistant_tool_calls( + assistant_message=assistant_message, compat_mode=compat_mode, **kwargs + ) + elif assistant_message.content: try: 
tool_calls = list( _tool_calls_compat_extract_calls(assistant_message.content) @@ -728,21 +442,16 @@ def _parse_assistant_tool_calls( return tool_calls, parse_errors - def _create_embedding( - self, text: str, *_, **kwargs - ) -> Coroutine[None, None, CreateEmbeddingResponse]: - """Create an embedding using the OpenAI API with retry handling.""" - - @self._retry_api_request - async def _create_embedding_with_retry( - text: str, *_, **kwargs - ) -> CreateEmbeddingResponse: - return await self._client.embeddings.create( - input=[text], - **kwargs, - ) + def _get_embedding_kwargs( + self, input: str | list[str], model: OpenAIModelName, **kwargs + ) -> EmbeddingCreateParams: + kwargs = super()._get_embedding_kwargs(input=input, model=model, **kwargs) + kwargs.update(self._credentials.get_model_access_kwargs(model)) # type: ignore + return kwargs - return _create_embedding_with_retry(text, *_, **kwargs) + _get_embedding_kwargs.__doc__ = ( + BaseOpenAIEmbeddingProvider._get_embedding_kwargs.__doc__ + ) def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]: _log_retry_debug_message = tenacity.after_log(self._logger, logging.DEBUG) @@ -777,24 +486,6 @@ def __repr__(self): return "OpenAIProvider()" -def format_function_def_for_openai(self: CompletionModelFunction) -> FunctionDefinition: - """Returns an OpenAI-consumable function definition""" - - return { - "name": self.name, - "description": self.description, - "parameters": { - "type": "object", - "properties": { - name: param.to_dict() for name, param in self.parameters.items() - }, - "required": [ - name for name, param in self.parameters.items() if param.required - ], - }, - } - - def format_function_specs_as_typescript_ns( functions: list[CompletionModelFunction], ) -> str: @@ -871,7 +562,7 @@ def count_openai_functions_tokens( def _functions_compat_fix_kwargs( functions: list[CompletionModelFunction], - completion_kwargs: dict, + prompt_messages: list[ChatMessage], ): function_definitions = format_function_specs_as_typescript_ns(functions) function_call_schema = JSONSchema( @@ -902,7 +593,7 @@ def _functions_compat_fix_kwargs( }, ), ) - completion_kwargs["messages"] = [ + prompt_messages.append( ChatMessage.system( "# tool usage instructions\n\n" "Specify a '```tool_calls' block in your response," @@ -915,7 +606,7 @@ def _functions_compat_fix_kwargs( "For the function call itself, use one of the following" f" functions:\n\n{function_definitions}" ), - ] + ) def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCall]: diff --git a/forge/forge/llm/providers/schema.py b/forge/forge/llm/providers/schema.py index 0858823ddce2..11b522e77960 100644 --- a/forge/forge/llm/providers/schema.py +++ b/forge/forge/llm/providers/schema.py @@ -12,11 +12,12 @@ Literal, Optional, Protocol, + Sequence, TypedDict, TypeVar, ) -from pydantic import BaseModel, Field, SecretStr, validator +from pydantic import BaseModel, Field, SecretStr from forge.logging.utils import fmt_kwargs from forge.models.config import ( @@ -189,7 +190,8 @@ class ModelResponse(BaseModel): class ModelProviderConfiguration(SystemConfiguration): - retries_per_request: int = UserConfigurable() + retries_per_request: int = UserConfigurable(7) + fix_failed_parse_tries: int = UserConfigurable(3) extra_request_headers: dict[str, str] = Field(default_factory=dict) @@ -296,6 +298,12 @@ def __init__( self._logger = logger or logging.getLogger(self.__module__) + @abc.abstractmethod + async def get_available_models( + self, + ) -> 
Sequence["ChatModelInfo[_ModelName] | EmbeddingModelInfo[_ModelName]"]: + ... + @abc.abstractmethod def count_tokens(self, text: str, model_name: _ModelName) -> int: ... @@ -339,7 +347,7 @@ def decode(self, tokens: list[_T]) -> str: class EmbeddingModelInfo(ModelInfo[_ModelName]): """Struct for embedding model information.""" - service = ModelProviderService.EMBEDDING + service: Literal[ModelProviderService.EMBEDDING] = ModelProviderService.EMBEDDING # type: ignore # noqa max_tokens: int embedding_dimensions: int @@ -348,15 +356,16 @@ class EmbeddingModelResponse(ModelResponse): """Standard response struct for a response from an embedding model.""" embedding: Embedding = Field(default_factory=list) - - @validator("completion_tokens_used") - def _verify_no_completion_tokens_used(cls, v: int): - if v > 0: - raise ValueError("Embeddings should not have completion tokens used.") - return v + completion_tokens_used: int = Field(default=0, const=True) class BaseEmbeddingModelProvider(BaseModelProvider[_ModelName, _ModelProviderSettings]): + @abc.abstractmethod + async def get_available_embedding_models( + self, + ) -> Sequence[EmbeddingModelInfo[_ModelName]]: + ... + @abc.abstractmethod async def create_embedding( self, @@ -376,7 +385,7 @@ async def create_embedding( class ChatModelInfo(ModelInfo[_ModelName]): """Struct for language model information.""" - service = ModelProviderService.CHAT + service: Literal[ModelProviderService.CHAT] = ModelProviderService.CHAT # type: ignore # noqa max_tokens: int has_function_call_api: bool = False @@ -390,7 +399,7 @@ class ChatModelResponse(ModelResponse, Generic[_T]): class BaseChatModelProvider(BaseModelProvider[_ModelName, _ModelProviderSettings]): @abc.abstractmethod - async def get_available_models(self) -> list[ChatModelInfo[_ModelName]]: + async def get_available_chat_models(self) -> Sequence[ChatModelInfo[_ModelName]]: ... @abc.abstractmethod