From 949eaeec87f5eb941f2c40dda58224de174a924b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?=
Date: Sun, 27 Oct 2024 22:05:38 +0100
Subject: [PATCH] feat: added deepseek support (#180)

---
 gptme/llm.py        | 7 ++++---
 gptme/llm_openai.py | 8 +++++++-
 gptme/models.py     | 6 ++++--
 3 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/gptme/llm.py b/gptme/llm.py
index bc4f17d8..cc738cd5 100644
--- a/gptme/llm.py
+++ b/gptme/llm.py
@@ -25,6 +25,7 @@
     get_summary_model,
 )
 from .tools import ToolUse
+from .util import console
 
 logger = logging.getLogger(__name__)
 
@@ -41,7 +42,7 @@ def init_llm(llm: str):
         init_anthropic(config)
         assert get_anthropic_client()
     else:
-        print(f"Error: Unknown LLM: {llm}")
+        console.log(f"Error: Unknown LLM: {llm}")
         sys.exit(1)
 
 
@@ -58,7 +59,7 @@ def reply(messages: list[Message], model: str, stream: bool = False) -> Message:
 
 def _chat_complete(messages: list[Message], model: str) -> str:
     provider = _client_to_provider()
-    if provider in ["openai", "azure", "openrouter"]:
+    if provider in PROVIDERS_OPENAI:
         return chat_openai(messages, model)
     elif provider == "anthropic":
         return chat_anthropic(messages, model)
@@ -68,7 +69,7 @@ def _chat_complete(messages: list[Message], model: str) -> str:
 
 def _stream(messages: list[Message], model: str) -> Iterator[str]:
     provider = _client_to_provider()
-    if provider in ["openai", "azure", "openrouter"]:
+    if provider in PROVIDERS_OPENAI:
         return stream_openai(messages, model)
     elif provider == "anthropic":
         return stream_anthropic(messages, model)
diff --git a/gptme/llm_openai.py b/gptme/llm_openai.py
index e4dc7f1e..a06b8e27 100644
--- a/gptme/llm_openai.py
+++ b/gptme/llm_openai.py
@@ -46,6 +46,9 @@ def init(provider: Provider, config: Config):
     elif provider == "groq":
         api_key = config.get_env_required("GROQ_API_KEY")
         openai = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1")
+    elif provider == "deepseek":
+        api_key = config.get_env_required("DEEPSEEK_API_KEY")
+        openai = OpenAI(api_key=api_key, base_url="https://api.deepseek.com/v1")
     elif provider == "local":
         # OPENAI_API_BASE renamed to OPENAI_BASE_URL: https://github.com/openai/openai-python/issues/745
         api_base = config.get_env("OPENAI_API_BASE")
@@ -61,6 +64,7 @@ def init(provider: Provider, config: Config):
 
 
 def get_provider() -> Provider | None:
+    # used when checking for provider-specific capabilities
     if not openai:
         return None
     if "openai.com" in str(openai.base_url):
@@ -71,7 +75,9 @@ def get_provider() -> Provider | None:
         return "groq"
     if "x.ai" in str(openai.base_url):
         return "xai"
-    return None
+    if "deepseek.com" in str(openai.base_url):
+        return "deepseek"
+    return "local"
 
 
 def get_client() -> "OpenAI | None":
diff --git a/gptme/models.py b/gptme/models.py
index 7b787140..3a5c1ba1 100644
--- a/gptme/models.py
+++ b/gptme/models.py
@@ -37,10 +37,12 @@ class _ModelDictMeta(TypedDict):
 
 
 # available providers
-Provider = Literal["openai", "anthropic", "azure", "openrouter", "groq", "xai", "local"]
+Provider = Literal[
+    "openai", "anthropic", "azure", "openrouter", "groq", "xai", "deepseek", "local"
+]
 PROVIDERS: list[Provider] = cast(list[Provider], get_args(Provider))
 PROVIDERS_OPENAI: list[Provider]
-PROVIDERS_OPENAI = ["openai", "azure", "openrouter", "xai", "groq", "deepseek", "local"]
+PROVIDERS_OPENAI = ["openai", "azure", "openrouter", "xai", "groq", "deepseek", "local"]
 
 # default model
 DEFAULT_MODEL: ModelMeta | None = None
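
Not part of the patch: a minimal usage sketch of the new provider branch, to show how the pieces added above fit together. It assumes gptme exposes get_config() in gptme.config (as the config usage elsewhere in the codebase suggests) and uses a placeholder API key; exact entry points may differ.

    # Hypothetical sketch, not from the patch itself.
    import os
    from gptme.config import get_config          # assumed config accessor
    from gptme.llm_openai import init, get_provider, get_client

    os.environ["DEEPSEEK_API_KEY"] = "sk-..."    # placeholder key read via get_env_required()

    # init() builds an OpenAI-compatible client pointed at https://api.deepseek.com/v1
    init("deepseek", get_config())

    # get_provider() now detects deepseek from the client's base_url instead of falling
    # back to None, and unknown OpenAI-compatible endpoints fall back to "local"
    assert get_provider() == "deepseek"
    client = get_client()                        # regular openai.OpenAI instance

Because "deepseek" is included in PROVIDERS_OPENAI, _chat_complete() and _stream() in gptme/llm.py route it through the existing chat_openai()/stream_openai() paths without any further special-casing.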