Feat: Wrap OpenLLM, Fireworks and other services into the OpenAILLM class #946

Merged: 22 commits, Mar 1, 2024
2 changes: 2 additions & 0 deletions metagpt/configs/llm_config.py
@@ -26,6 +26,8 @@ class LLMType(Enum):
OLLAMA = "ollama"
QIANFAN = "qianfan" # Baidu BCE
DASHSCOPE = "dashscope" # Aliyun LingJi DashScope
MOONSHOT = "moonshot"
MISTRAL = "mistral"

def __missing__(self, key):
return self.OPENAI
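The two new members are ordinary string-valued enum entries, so a config can select them by name and route them through the OpenAI-compatible provider. A minimal sketch of how they are used; the LLMConfig field names (api_key, base_url, model) are assumed from the usual layout, and the key, URL, and model name are placeholders:

from metagpt.configs.llm_config import LLMConfig, LLMType

# Hypothetical Moonshot configuration served over an OpenAI-compatible endpoint.
moonshot_cfg = LLMConfig(
    api_type=LLMType.MOONSHOT,              # member added in this hunk
    api_key="sk-placeholder",               # placeholder key
    base_url="https://api.example.com/v1",  # placeholder endpoint
    model="moonshot-placeholder-model",     # placeholder model name
)

assert LLMType("mistral") is LLMType.MISTRAL  # enum lookup by its string value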
21 changes: 17 additions & 4 deletions metagpt/context.py
@@ -12,10 +12,14 @@
from pydantic import BaseModel, ConfigDict

from metagpt.config2 import Config
from metagpt.configs.llm_config import LLMConfig
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import create_llm_instance
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.cost_manager import (
CostManager,
FireworksCostManager,
TokenCostManager,
)
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.project_repo import ProjectRepo

@@ -80,18 +84,27 @@ def new_environ(self):
# self._llm = None
# return self._llm

def _select_costmanager(self, llm_config: LLMConfig) -> CostManager:
"""Return a CostManager instance"""
if llm_config.api_type == LLMType.FIREWORKS:
return FireworksCostManager()
elif llm_config.api_type == LLMType.OPEN_LLM:
return TokenCostManager()
else:
return self.cost_manager

def llm(self) -> BaseLLM:
"""Return a LLM instance, fixme: support cache"""
# if self._llm is None:
self._llm = create_llm_instance(self.config.llm)
if self._llm.cost_manager is None:
self._llm.cost_manager = self.cost_manager
self._llm.cost_manager = self._select_costmanager(self.config.llm)
return self._llm

def llm_with_cost_manager_from_llm_config(self, llm_config: LLMConfig) -> BaseLLM:
"""Return a LLM instance, fixme: support cache"""
# if self._llm is None:
llm = create_llm_instance(llm_config)
if llm.cost_manager is None:
llm.cost_manager = self.cost_manager
llm.cost_manager = self._select_costmanager(llm_config)
return llm
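The new _select_costmanager keeps provider-specific accounting out of the LLM classes: Fireworks gets a FireworksCostManager, OpenLLM endpoints get a plain TokenCostManager, and every other provider keeps sharing the context-wide cost manager. A small usage sketch under those assumptions (config values are placeholders):

from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.context import Context
from metagpt.utils.cost_manager import FireworksCostManager

ctx = Context()
fireworks_cfg = LLMConfig(
    api_type=LLMType.FIREWORKS,
    api_key="sk-placeholder",       # placeholder key
    model="fireworks-placeholder",  # placeholder model name
)
llm = ctx.llm_with_cost_manager_from_llm_config(fireworks_cfg)

# A provider-specific manager is attached instead of ctx.cost_manager.
assert isinstance(llm.cost_manager, FireworksCostManager)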
4 changes: 0 additions & 4 deletions metagpt/provider/__init__.py
@@ -6,10 +6,8 @@
@File : __init__.py
"""

from metagpt.provider.fireworks_api import FireworksLLM
from metagpt.provider.google_gemini_api import GeminiLLM
from metagpt.provider.ollama_api import OllamaLLM
from metagpt.provider.open_llm_api import OpenLLM
from metagpt.provider.openai_api import OpenAILLM
from metagpt.provider.zhipuai_api import ZhiPuAILLM
from metagpt.provider.azure_openai_api import AzureOpenAILLM
@@ -20,9 +18,7 @@
from metagpt.provider.dashscope_api import DashScopeLLM

__all__ = [
"FireworksLLM",
"GeminiLLM",
"OpenLLM",
"OpenAILLM",
"ZhiPuAILLM",
"AzureOpenAILLM",
121 changes: 0 additions & 121 deletions metagpt/provider/fireworks_api.py

This file was deleted.

8 changes: 6 additions & 2 deletions metagpt/provider/llm_provider_registry.py
@@ -21,11 +21,15 @@ def get_provider(self, enum: LLMType):
return self.providers[enum]


def register_provider(key):
def register_provider(keys):
"""register provider to registry"""

def decorator(cls):
LLM_REGISTRY.register(key, cls)
if isinstance(keys, list):
for key in keys:
LLM_REGISTRY.register(key, cls)
else:
LLM_REGISTRY.register(keys, cls)
return cls

return decorator
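Since register_provider now accepts either a single key or a list, several LLMType keys can map to one provider class, which is exactly how OpenAILLM is registered below. A quick check of the resulting registry state, assuming LLM_REGISTRY stores the class itself as the register call above suggests:

from metagpt.configs.llm_config import LLMType
from metagpt.provider.llm_provider_registry import LLM_REGISTRY
from metagpt.provider.openai_api import OpenAILLM

# Importing openai_api runs its @register_provider([...]) decorator, so all of
# these keys resolve to the same class.
assert LLM_REGISTRY.get_provider(LLMType.FIREWORKS) is OpenAILLM
assert LLM_REGISTRY.get_provider(LLMType.OPEN_LLM) is OpenAILLM
assert LLM_REGISTRY.get_provider(LLMType.MOONSHOT) is OpenAILLM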
36 changes: 0 additions & 36 deletions metagpt/provider/open_llm_api.py

This file was deleted.

49 changes: 29 additions & 20 deletions metagpt/provider/openai_api.py
@@ -30,7 +30,7 @@
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.schema import Message
from metagpt.utils.common import CodeParser, decode_image
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.cost_manager import CostManager, Costs, TokenCostManager
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.token_counter import (
count_message_tokens,
@@ -50,7 +50,7 @@ def log_and_reraise(retry_state):
raise retry_state.outcome.exception()


@register_provider(LLMType.OPENAI)
@register_provider([LLMType.OPENAI, LLMType.FIREWORKS, LLMType.OPEN_LLM, LLMType.MOONSHOT])
class OpenAILLM(BaseLLM):
"""Check https://platform.openai.com/examples for examples"""

@@ -84,20 +84,39 @@ def _get_proxy_params(self) -> dict:

return params

async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> AsyncIterator[str]:
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
**self._cons_kwargs(messages, timeout=timeout), stream=True
)

usage = None
collected_messages = []
async for chunk in response:
chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else "" # extract the message
yield chunk_message
finish_reason = chunk.choices[0].finish_reason if hasattr(chunk.choices[0], "finish_reason") else None
log_llm_stream(chunk_message)
collected_messages.append(chunk_message)
if finish_reason:
if hasattr(chunk, "usage"):
# Some services have usage as an attribute of the chunk, such as Fireworks
usage = CompletionUsage(**chunk.usage)
elif hasattr(chunk.choices[0], "usage"):
# The usage of some services is an attribute of chunk.choices[0], such as Moonshot
usage = CompletionUsage(**chunk.choices[0].usage)

log_llm_stream("\n")
full_reply_content = "".join(collected_messages)
if not usage:
# Some services do not provide the usage attribute, such as OpenAI or OpenLLM
usage = self._calc_usage(messages, full_reply_content)

self._update_costs(usage)
return full_reply_content

def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict:
kwargs = {
"messages": messages,
"max_tokens": self._get_max_tokens(messages),
"n": 1,
# "n": 1, # Some services do not provide this parameter, such as mistral
# "stop": None, # default it's None and gpt4-v can't have this one
"temperature": self.config.temperature,
"model": self.model,
@@ -126,18 +145,7 @@ async def acompletion(self, messages: list[dict], timeout=3) -> ChatCompletion:
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
"""when streaming, print each token in place."""
if stream:
resp = self._achat_completion_stream(messages, timeout=timeout)

collected_messages = []
async for i in resp:
log_llm_stream(i)
collected_messages.append(i)
log_llm_stream("\n")

full_reply_content = "".join(collected_messages)
usage = self._calc_usage(messages, full_reply_content)
self._update_costs(usage)
return full_reply_content
return await self._achat_completion_stream(messages, timeout=timeout)

rsp = await self._achat_completion(messages, timeout=timeout)
return self.get_choice_text(rsp)
@@ -261,9 +269,10 @@ def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
if not self.config.calc_usage:
return usage

model = self.model if not isinstance(self.cost_manager, TokenCostManager) else "open-llm-model"
try:
usage.prompt_tokens = count_message_tokens(messages, self.model)
usage.completion_tokens = count_string_tokens(rsp, self.model)
usage.prompt_tokens = count_message_tokens(messages, model)
usage.completion_tokens = count_string_tokens(rsp, model)
except Exception as e:
logger.warning(f"usage calculation failed: {e}")

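The reworked _achat_completion_stream resolves usage in a fixed order: usage reported on the final chunk (Fireworks), usage attached to choices[0] (Moonshot), then a local token count for services that report nothing. A minimal standalone sketch of that order; the helper name and the calc_usage callback are illustrative, not part of the PR:

from typing import Callable

from openai.types import CompletionUsage


def resolve_stream_usage(last_chunk, full_reply: str,
                         calc_usage: Callable[[str], CompletionUsage]) -> CompletionUsage:
    """Mirror the usage fallback order used in _achat_completion_stream."""
    if getattr(last_chunk, "usage", None):
        # e.g. Fireworks: usage is reported on the chunk itself (a dict here).
        return CompletionUsage(**last_chunk.usage)
    if getattr(last_chunk.choices[0], "usage", None):
        # e.g. Moonshot: usage hangs off choices[0].
        return CompletionUsage(**last_chunk.choices[0].usage)
    # e.g. OpenAI or OpenLLM streams without usage: count tokens locally.
    return calc_usage(full_reply)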
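Similarly, the _calc_usage change swaps in the generic "open-llm-model" name whenever a TokenCostManager is attached, since self-hosted models usually have no tiktoken mapping of their own. A hedged sketch of that fallback as a standalone helper, assuming the token_counter helpers accept the generic name as the PR's _calc_usage does (the function name is illustrative):

from openai.types import CompletionUsage

from metagpt.utils.cost_manager import TokenCostManager
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens


def estimate_usage(messages: list[dict], reply: str, cost_manager, model: str) -> CompletionUsage:
    """Count tokens locally, using a generic model name for open-source LLMs."""
    name = model if not isinstance(cost_manager, TokenCostManager) else "open-llm-model"
    prompt = count_message_tokens(messages, name)
    completion = count_string_tokens(reply, name)
    return CompletionUsage(prompt_tokens=prompt, completion_tokens=completion,
                           total_tokens=prompt + completion)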