diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py
index 36f5d7ae7e..3f42cdc9df 100644
--- a/metagpt/configs/llm_config.py
+++ b/metagpt/configs/llm_config.py
@@ -26,6 +26,8 @@ class LLMType(Enum):
     OLLAMA = "ollama"
     QIANFAN = "qianfan"  # Baidu BCE
     DASHSCOPE = "dashscope"  # Aliyun LingJi DashScope
+    MOONSHOT = "moonshot"
+    MISTRAL = "mistral"
 
     def __missing__(self, key):
         return self.OPENAI
diff --git a/metagpt/context.py b/metagpt/context.py
index 3dfd52d588..0add4c71ae 100644
--- a/metagpt/context.py
+++ b/metagpt/context.py
@@ -12,10 +12,14 @@
 from pydantic import BaseModel, ConfigDict
 
 from metagpt.config2 import Config
-from metagpt.configs.llm_config import LLMConfig
+from metagpt.configs.llm_config import LLMConfig, LLMType
 from metagpt.provider.base_llm import BaseLLM
 from metagpt.provider.llm_provider_registry import create_llm_instance
-from metagpt.utils.cost_manager import CostManager
+from metagpt.utils.cost_manager import (
+    CostManager,
+    FireworksCostManager,
+    TokenCostManager,
+)
 from metagpt.utils.git_repository import GitRepository
 from metagpt.utils.project_repo import ProjectRepo
 
@@ -80,12 +84,21 @@ def new_environ(self):
     #     self._llm = None
     #     return self._llm
 
+    def _select_costmanager(self, llm_config: LLMConfig) -> CostManager:
+        """Return a CostManager instance"""
+        if llm_config.api_type == LLMType.FIREWORKS:
+            return FireworksCostManager()
+        elif llm_config.api_type == LLMType.OPEN_LLM:
+            return TokenCostManager()
+        else:
+            return self.cost_manager
+
     def llm(self) -> BaseLLM:
         """Return a LLM instance, fixme: support cache"""
         # if self._llm is None:
         self._llm = create_llm_instance(self.config.llm)
         if self._llm.cost_manager is None:
-            self._llm.cost_manager = self.cost_manager
+            self._llm.cost_manager = self._select_costmanager(self.config.llm)
         return self._llm
 
     def llm_with_cost_manager_from_llm_config(self, llm_config: LLMConfig) -> BaseLLM:
@@ -93,5 +106,5 @@ def llm_with_cost_manager_from_llm_config(self, llm_config: LLMConfig) -> BaseLL
         # if self._llm is None:
         llm = create_llm_instance(llm_config)
         if llm.cost_manager is None:
-            llm.cost_manager = self.cost_manager
+            llm.cost_manager = self._select_costmanager(llm_config)
         return llm
diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py
index 44e6d3f3bb..ed49d01c96 100644
--- a/metagpt/provider/__init__.py
+++ b/metagpt/provider/__init__.py
@@ -6,10 +6,8 @@
 @File    : __init__.py
 """
 
-from metagpt.provider.fireworks_api import FireworksLLM
 from metagpt.provider.google_gemini_api import GeminiLLM
 from metagpt.provider.ollama_api import OllamaLLM
-from metagpt.provider.open_llm_api import OpenLLM
 from metagpt.provider.openai_api import OpenAILLM
 from metagpt.provider.zhipuai_api import ZhiPuAILLM
 from metagpt.provider.azure_openai_api import AzureOpenAILLM
@@ -20,9 +18,7 @@
 from metagpt.provider.dashscope_api import DashScopeLLM
 
 __all__ = [
-    "FireworksLLM",
     "GeminiLLM",
-    "OpenLLM",
     "OpenAILLM",
     "ZhiPuAILLM",
     "AzureOpenAILLM",
diff --git a/metagpt/provider/fireworks_api.py b/metagpt/provider/fireworks_api.py
deleted file mode 100644
index f356c23c4e..0000000000
--- a/metagpt/provider/fireworks_api.py
+++ /dev/null
@@ -1,121 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# @Desc    : fireworks.ai's api
-
-import re
-
-from openai import APIConnectionError, AsyncStream
-from openai.types import CompletionUsage
-from openai.types.chat import ChatCompletionChunk
-from tenacity import (
-    after_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_random_exponential,
-)
-
-from metagpt.configs.llm_config import LLMConfig, LLMType
-from metagpt.logs import log_llm_stream, logger
-from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.provider.openai_api import OpenAILLM, log_and_reraise
-from metagpt.utils.cost_manager import CostManager
-
-MODEL_GRADE_TOKEN_COSTS = {
-    "-1": {"prompt": 0.0, "completion": 0.0},  # abnormal condition
-    "16": {"prompt": 0.2, "completion": 0.8},  # 16 means model size <= 16B; 0.2 means $0.2/1M tokens
-    "80": {"prompt": 0.7, "completion": 2.8},  # 80 means 16B < model size <= 80B
-    "mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
-}
-
-
-class FireworksCostManager(CostManager):
-    def model_grade_token_costs(self, model: str) -> dict[str, float]:
-        def _get_model_size(model: str) -> float:
-            size = re.findall(".*-([0-9.]+)b", model)
-            size = float(size[0]) if len(size) > 0 else -1
-            return size
-
-        if "mixtral-8x7b" in model:
-            token_costs = MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"]
-        else:
-            model_size = _get_model_size(model)
-            if 0 < model_size <= 16:
-                token_costs = MODEL_GRADE_TOKEN_COSTS["16"]
-            elif 16 < model_size <= 80:
-                token_costs = MODEL_GRADE_TOKEN_COSTS["80"]
-            else:
-                token_costs = MODEL_GRADE_TOKEN_COSTS["-1"]
-        return token_costs
-
-    def update_cost(self, prompt_tokens: int, completion_tokens: int, model: str):
-        """
-        Refs to `https://app.fireworks.ai/pricing` **Developer pricing**
-        Update the total cost, prompt tokens, and completion tokens.
-
-        Args:
-            prompt_tokens (int): The number of tokens used in the prompt.
-            completion_tokens (int): The number of tokens used in the completion.
-            model (str): The model used for the API call.
-        """
-        self.total_prompt_tokens += prompt_tokens
-        self.total_completion_tokens += completion_tokens
-
-        token_costs = self.model_grade_token_costs(model)
-        cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000
-        self.total_cost += cost
-        logger.info(
-            f"Total running cost: ${self.total_cost:.4f}"
-            f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
-        )
-
-
-@register_provider(LLMType.FIREWORKS)
-class FireworksLLM(OpenAILLM):
-    def __init__(self, config: LLMConfig):
-        super().__init__(config=config)
-        self.auto_max_tokens = False
-        self.cost_manager = FireworksCostManager()
-
-    def _make_client_kwargs(self) -> dict:
-        kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url)
-        return kwargs
-
-    async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
-        response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
-            **self._cons_kwargs(messages), stream=True
-        )
-
-        collected_content = []
-        usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
-        # iterate through the stream of events
-        async for chunk in response:
-            if chunk.choices:
-                choice = chunk.choices[0]
-                choice_delta = choice.delta
-                finish_reason = choice.finish_reason if hasattr(choice, "finish_reason") else None
-                if choice_delta.content:
-                    collected_content.append(choice_delta.content)
-                    log_llm_stream(choice_delta.content)
-                if finish_reason:
-                    # fireworks api return usage when finish_reason is not None
-                    usage = CompletionUsage(**chunk.usage)
-                    log_llm_stream("\n")
-
-        full_content = "".join(collected_content)
-        self._update_costs(usage)
-        return full_content
-
-    @retry(
-        wait=wait_random_exponential(min=1, max=60),
-        stop=stop_after_attempt(6),
-        after=after_log(logger, logger.level("WARNING").name),
-        retry=retry_if_exception_type(APIConnectionError),
-        retry_error_callback=log_and_reraise,
-    )
-    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
-        """when streaming, print each token in place."""
-        if stream:
-            return await self._achat_completion_stream(messages)
-        rsp = await self._achat_completion(messages)
-        return self.get_choice_text(rsp)
diff --git a/metagpt/provider/llm_provider_registry.py b/metagpt/provider/llm_provider_registry.py
index df89d36aae..4fd2b19783 100644
--- a/metagpt/provider/llm_provider_registry.py
+++ b/metagpt/provider/llm_provider_registry.py
@@ -21,11 +21,15 @@ def get_provider(self, enum: LLMType):
         return self.providers[enum]
 
 
-def register_provider(key):
+def register_provider(keys):
     """register provider to registry"""
 
     def decorator(cls):
-        LLM_REGISTRY.register(key, cls)
+        if isinstance(keys, list):
+            for key in keys:
+                LLM_REGISTRY.register(key, cls)
+        else:
+            LLM_REGISTRY.register(keys, cls)
         return cls
 
     return decorator
diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py
deleted file mode 100644
index 69371e3794..0000000000
--- a/metagpt/provider/open_llm_api.py
+++ /dev/null
@@ -1,36 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# @Desc    : self-host open llm model with openai-compatible interface
-
-from openai.types import CompletionUsage
-
-from metagpt.configs.llm_config import LLMConfig, LLMType
-from metagpt.logs import logger
-from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.provider.openai_api import OpenAILLM
-from metagpt.utils.cost_manager import TokenCostManager
-from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
-
-
-@register_provider(LLMType.OPEN_LLM)
-class OpenLLM(OpenAILLM):
-    def __init__(self, config: LLMConfig):
-        super().__init__(config)
-        self._cost_manager = TokenCostManager()
-
-    def _make_client_kwargs(self) -> dict:
-        kwargs = dict(api_key="sk-xxx", base_url=self.config.base_url)
-        return kwargs
-
-    def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
-        usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
-        if not self.config.calc_usage:
-            return usage
-
-        try:
-            usage.prompt_tokens = count_message_tokens(messages, "open-llm-model")
-            usage.completion_tokens = count_string_tokens(rsp, "open-llm-model")
-        except Exception as e:
-            logger.error(f"usage calculation failed!: {e}")
-
-        return usage
diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py
index 36d6f6d778..0a423f210f 100644
--- a/metagpt/provider/openai_api.py
+++ b/metagpt/provider/openai_api.py
@@ -30,7 +30,7 @@
 from metagpt.provider.llm_provider_registry import register_provider
 from metagpt.schema import Message
 from metagpt.utils.common import CodeParser, decode_image
-from metagpt.utils.cost_manager import CostManager
+from metagpt.utils.cost_manager import CostManager, Costs, TokenCostManager
 from metagpt.utils.exceptions import handle_exception
 from metagpt.utils.token_counter import (
     count_message_tokens,
@@ -50,7 +50,7 @@ def log_and_reraise(retry_state):
     raise retry_state.outcome.exception()
 
 
-@register_provider(LLMType.OPENAI)
+@register_provider([LLMType.OPENAI, LLMType.FIREWORKS, LLMType.OPEN_LLM, LLMType.MOONSHOT])
 class OpenAILLM(BaseLLM):
     """Check https://platform.openai.com/examples for examples"""
 
@@ -84,20 +84,39 @@ def _get_proxy_params(self) -> dict:
 
         return params
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> AsyncIterator[str]:
+    async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
         response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
             **self._cons_kwargs(messages, timeout=timeout), stream=True
         )
-
+        usage = None
+        collected_messages = []
         async for chunk in response:
             chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else ""  # extract the message
-            yield chunk_message
+            finish_reason = chunk.choices[0].finish_reason if hasattr(chunk.choices[0], "finish_reason") else None
+            log_llm_stream(chunk_message)
+            collected_messages.append(chunk_message)
+            if finish_reason:
+                if hasattr(chunk, "usage"):
+                    # Some services have usage as an attribute of the chunk, such as Fireworks
+                    usage = CompletionUsage(**chunk.usage)
+                elif hasattr(chunk.choices[0], "usage"):
+                    # The usage of some services is an attribute of chunk.choices[0], such as Moonshot
+                    usage = CompletionUsage(**chunk.choices[0].usage)
+
+        log_llm_stream("\n")
+        full_reply_content = "".join(collected_messages)
+        if not usage:
+            # Some services do not provide the usage attribute, such as OpenAI or OpenLLM
+            usage = self._calc_usage(messages, full_reply_content)
+
+        self._update_costs(usage)
+        return full_reply_content
 
     def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict:
         kwargs = {
             "messages": messages,
             "max_tokens": self._get_max_tokens(messages),
-            "n": 1,
+            # "n": 1,  # Some services do not provide this parameter, such as mistral
             # "stop": None,  # default it's None and gpt4-v can't have this one
             "temperature": self.config.temperature,
             "model": self.model,
@@ -126,18 +145,7 @@ async def acompletion(self, messages: list[dict], timeout=3) -> ChatCompletion:
     async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
         """when streaming, print each token in place."""
         if stream:
-            resp = self._achat_completion_stream(messages, timeout=timeout)
-
-            collected_messages = []
-            async for i in resp:
-                log_llm_stream(i)
-                collected_messages.append(i)
-            log_llm_stream("\n")
-
-            full_reply_content = "".join(collected_messages)
-            usage = self._calc_usage(messages, full_reply_content)
-            self._update_costs(usage)
-            return full_reply_content
+            return await self._achat_completion_stream(messages, timeout=timeout)
 
         rsp = await self._achat_completion(messages, timeout=timeout)
         return self.get_choice_text(rsp)
@@ -261,9 +269,10 @@ def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
         if not self.config.calc_usage:
             return usage
 
+        model = self.model if not isinstance(self.cost_manager, TokenCostManager) else "open-llm-model"
         try:
-            usage.prompt_tokens = count_message_tokens(messages, self.model)
-            usage.completion_tokens = count_string_tokens(rsp, self.model)
+            usage.prompt_tokens = count_message_tokens(messages, model)
+            usage.completion_tokens = count_string_tokens(rsp, model)
         except Exception as e:
             logger.warning(f"usage calculation failed: {e}")
 
diff --git a/metagpt/utils/cost_manager.py b/metagpt/utils/cost_manager.py
index efff07ae14..b871cef3ba 100644
--- a/metagpt/utils/cost_manager.py
+++ b/metagpt/utils/cost_manager.py
@@ -6,12 +6,13 @@
 @Desc   : mashenquan, 2023/8/28. Separate the `CostManager` class to support user-level cost accounting.
""" +import re from typing import NamedTuple from pydantic import BaseModel from metagpt.logs import logger -from metagpt.utils.token_counter import TOKEN_COSTS +from metagpt.utils.token_counter import FIREWORKS_GRADE_TOKEN_COSTS, TOKEN_COSTS class Costs(NamedTuple): @@ -103,3 +104,44 @@ def update_cost(self, prompt_tokens, completion_tokens, model): self.total_prompt_tokens += prompt_tokens self.total_completion_tokens += completion_tokens logger.info(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}") + + +class FireworksCostManager(CostManager): + def model_grade_token_costs(self, model: str) -> dict[str, float]: + def _get_model_size(model: str) -> float: + size = re.findall(".*-([0-9.]+)b", model) + size = float(size[0]) if len(size) > 0 else -1 + return size + + if "mixtral-8x7b" in model: + token_costs = FIREWORKS_GRADE_TOKEN_COSTS["mixtral-8x7b"] + else: + model_size = _get_model_size(model) + if 0 < model_size <= 16: + token_costs = FIREWORKS_GRADE_TOKEN_COSTS["16"] + elif 16 < model_size <= 80: + token_costs = FIREWORKS_GRADE_TOKEN_COSTS["80"] + else: + token_costs = FIREWORKS_GRADE_TOKEN_COSTS["-1"] + return token_costs + + def update_cost(self, prompt_tokens: int, completion_tokens: int, model: str): + """ + Refs to `https://app.fireworks.ai/pricing` **Developer pricing** + Update the total cost, prompt tokens, and completion tokens. + + Args: + prompt_tokens (int): The number of tokens used in the prompt. + completion_tokens (int): The number of tokens used in the completion. + model (str): The model used for the API call. + """ + self.total_prompt_tokens += prompt_tokens + self.total_completion_tokens += completion_tokens + + token_costs = self.model_grade_token_costs(model) + cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000 + self.total_cost += cost + logger.info( + f"Total running cost: ${self.total_cost:.4f}" + f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" + ) diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index 167a1d7559..c20caa8e1c 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -35,6 +35,14 @@ "glm-3-turbo": {"prompt": 0.0007, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens "glm-4": {"prompt": 0.014, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens "gemini-pro": {"prompt": 0.00025, "completion": 0.0005}, + "moonshot-v1-8k": {"prompt": 0.012, "completion": 0.012}, # prompt + completion tokens=0.012¥/k-tokens + "moonshot-v1-32k": {"prompt": 0.024, "completion": 0.024}, + "moonshot-v1-128k": {"prompt": 0.06, "completion": 0.06}, + "open-mistral-7b": {"prompt": 0.00025, "completion": 0.00025}, + "open-mixtral-8x7b": {"prompt": 0.0007, "completion": 0.0007}, + "mistral-small-latest": {"prompt": 0.002, "completion": 0.006}, + "mistral-medium-latest": {"prompt": 0.0027, "completion": 0.0081}, + "mistral-large-latest": {"prompt": 0.008, "completion": 0.024}, } @@ -120,6 +128,14 @@ } +FIREWORKS_GRADE_TOKEN_COSTS = { + "-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition + "16": {"prompt": 0.2, "completion": 0.8}, # 16 means model size <= 16B; 0.2 means $0.2/1M tokens + "80": {"prompt": 0.7, "completion": 2.8}, # 80 means 16B < model size <= 80B + "mixtral-8x7b": {"prompt": 0.4, "completion": 1.6}, +} + + TOKEN_MAX = { "gpt-3.5-turbo": 4096, "gpt-3.5-turbo-0301": 4096, @@ -143,6 +159,14 @@ "glm-3-turbo": 
128000, "glm-4": 128000, "gemini-pro": 32768, + "moonshot-v1-8k": 8192, + "moonshot-v1-32k": 32768, + "moonshot-v1-128k": 128000, + "open-mistral-7b": 8192, + "open-mixtral-8x7b": 32768, + "mistral-small-latest": 32768, + "mistral-medium-latest": 32768, + "mistral-large-latest": 32768, } diff --git a/tests/metagpt/provider/test_fireworks_llm.py b/tests/metagpt/provider/test_fireworks_llm.py deleted file mode 100644 index 1c1aa9caa9..0000000000 --- a/tests/metagpt/provider/test_fireworks_llm.py +++ /dev/null @@ -1,74 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# @Desc : the unittest of fireworks api - -import pytest -from openai.types.chat.chat_completion_chunk import ChatCompletionChunk -from openai.types.completion_usage import CompletionUsage - -from metagpt.provider.fireworks_api import ( - MODEL_GRADE_TOKEN_COSTS, - FireworksCostManager, - FireworksLLM, -) -from metagpt.utils.cost_manager import Costs -from tests.metagpt.provider.mock_llm_config import mock_llm_config -from tests.metagpt.provider.req_resp_const import ( - get_openai_chat_completion, - get_openai_chat_completion_chunk, - llm_general_chat_funcs_test, - messages, - prompt, - resp_cont_tmpl, -) - -name = "fireworks" -resp_cont = resp_cont_tmpl.format(name=name) -default_resp = get_openai_chat_completion(name) -default_resp_chunk = get_openai_chat_completion_chunk(name, usage_as_dict=True) - - -def test_fireworks_costmanager(): - cost_manager = FireworksCostManager() - assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("test") - assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("xxx-81b-chat") - assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("llama-v2-13b-chat") - assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-15.5b-chat") - assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-16b-chat") - assert MODEL_GRADE_TOKEN_COSTS["80"] == cost_manager.model_grade_token_costs("xxx-80b-chat") - assert MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"] == cost_manager.model_grade_token_costs("mixtral-8x7b-chat") - - cost_manager.update_cost(prompt_tokens=500000, completion_tokens=500000, model="llama-v2-13b-chat") - assert cost_manager.total_cost == 0.5 - - -async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk: - if stream: - - class Iterator(object): - async def __aiter__(self): - yield default_resp_chunk - - return Iterator() - else: - return default_resp - - -@pytest.mark.asyncio -async def test_fireworks_acompletion(mocker): - mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create) - - fireworks_llm = FireworksLLM(mock_llm_config) - fireworks_llm.model = "llama-v2-13b-chat" - - fireworks_llm._update_costs( - usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000) - ) - assert fireworks_llm.get_costs() == Costs( - total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0 - ) - - resp = await fireworks_llm.acompletion(messages) - assert resp.choices[0].message.content in resp_cont - - await llm_general_chat_funcs_test(fireworks_llm, prompt, messages, resp_cont) diff --git a/tests/metagpt/provider/test_open_llm_api.py b/tests/metagpt/provider/test_open_llm_api.py deleted file mode 100644 index aa38b95a64..0000000000 --- a/tests/metagpt/provider/test_open_llm_api.py +++ /dev/null @@ -1,56 +0,0 @@ -#!/usr/bin/env python -# 
-*- coding: utf-8 -*- -# @Desc : - -import pytest -from openai.types.chat.chat_completion_chunk import ChatCompletionChunk -from openai.types.completion_usage import CompletionUsage - -from metagpt.provider.open_llm_api import OpenLLM -from metagpt.utils.cost_manager import CostManager, Costs -from tests.metagpt.provider.mock_llm_config import mock_llm_config -from tests.metagpt.provider.req_resp_const import ( - get_openai_chat_completion, - get_openai_chat_completion_chunk, - llm_general_chat_funcs_test, - messages, - prompt, - resp_cont_tmpl, -) - -name = "llama2-7b" -resp_cont = resp_cont_tmpl.format(name=name) -default_resp = get_openai_chat_completion(name) - -default_resp_chunk = get_openai_chat_completion_chunk(name) - - -async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk: - if stream: - - class Iterator(object): - async def __aiter__(self): - yield default_resp_chunk - - return Iterator() - else: - return default_resp - - -@pytest.mark.asyncio -async def test_openllm_acompletion(mocker): - mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create) - - openllm_llm = OpenLLM(mock_llm_config) - openllm_llm.model = "llama-v2-13b-chat" - - openllm_llm.cost_manager = CostManager() - openllm_llm._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200)) - assert openllm_llm.get_costs() == Costs( - total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0 - ) - - resp = await openllm_llm.acompletion(messages) - assert resp.choices[0].message.content in resp_cont - - await llm_general_chat_funcs_test(openllm_llm, prompt, messages, resp_cont) diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index 96c08a867a..3ce38d2a5a 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -1,10 +1,11 @@ import pytest from openai.types.chat import ( ChatCompletion, + ChatCompletionChunk, ChatCompletionMessage, ChatCompletionMessageToolCall, ) -from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion import Choice, CompletionUsage from openai.types.chat.chat_completion_message_tool_call import Function from PIL import Image @@ -16,6 +17,22 @@ mock_llm_config, mock_llm_config_proxy, ) +from tests.metagpt.provider.req_resp_const import ( + get_openai_chat_completion, + get_openai_chat_completion_chunk, + llm_general_chat_funcs_test, + messages, + prompt, + resp_cont_tmpl, +) + +name = "AI assistant" +resp_cont = resp_cont_tmpl.format(name=name) +default_resp = get_openai_chat_completion(name) + +default_resp_chunk = get_openai_chat_completion_chunk(name, usage_as_dict=True) + +usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202) @pytest.mark.asyncio @@ -121,3 +138,29 @@ async def test_gen_image(): images: list[Image] = await llm.gen_image(model=model, prompt=prompt, resp_format="b64_json") assert images[0].size == (1024, 1024) + + +async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk: + if stream: + + class Iterator(object): + async def __aiter__(self): + yield default_resp_chunk + + return Iterator() + else: + return default_resp + + +@pytest.mark.asyncio +async def test_openai_acompletion(mocker): + mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create) + + llm = OpenAILLM(mock_llm_config) + + resp = 
await llm.acompletion(messages) + assert resp.choices[0].finish_reason == "stop" + assert resp.choices[0].message.content == resp_cont + assert resp.usage == usage + + await llm_general_chat_funcs_test(llm, prompt, messages, resp_cont)