Skip to content

Commit

Permalink
feat: wip deepseek support
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBjare committed Oct 27, 2024
1 parent cea30cf commit 2488404
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
7 changes: 4 additions & 3 deletions gptme/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
get_summary_model,
)
from .tools import ToolUse
from .util import console

logger = logging.getLogger(__name__)

Expand All @@ -41,7 +42,7 @@ def init_llm(llm: str):
init_anthropic(config)
assert get_anthropic_client()
else:
print(f"Error: Unknown LLM: {llm}")
console.log(f"Error: Unknown LLM: {llm}")
sys.exit(1)


Expand All @@ -58,7 +59,7 @@ def reply(messages: list[Message], model: str, stream: bool = False) -> Message:

def _chat_complete(messages: list[Message], model: str) -> str:
provider = _client_to_provider()
if provider in ["openai", "azure", "openrouter"]:
if provider in PROVIDERS_OPENAI:
return chat_openai(messages, model)
elif provider == "anthropic":
return chat_anthropic(messages, model)
Expand All @@ -68,7 +69,7 @@ def _chat_complete(messages: list[Message], model: str) -> str:

def _stream(messages: list[Message], model: str) -> Iterator[str]:
provider = _client_to_provider()
if provider in ["openai", "azure", "openrouter"]:
if provider in PROVIDERS_OPENAI:
return stream_openai(messages, model)
elif provider == "anthropic":
return stream_anthropic(messages, model)
Expand Down
8 changes: 7 additions & 1 deletion gptme/llm_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def init(provider: Provider, config: Config):
elif provider == "groq":
api_key = config.get_env_required("GROQ_API_KEY")
openai = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1")
elif provider == "deepseek":
api_key = config.get_env_required("DEEPSEEK_API_KEY")
openai = OpenAI(api_key=api_key, base_url="https://api.deepseek.com/v1")
elif provider == "local":
# OPENAI_API_BASE renamed to OPENAI_BASE_URL: https://github.com/openai/openai-python/issues/745
api_base = config.get_env("OPENAI_API_BASE")
Expand All @@ -61,6 +64,7 @@ def init(provider: Provider, config: Config):


def get_provider() -> Provider | None:
# used when checking for provider-specific capabilities
if not openai:
return None
if "openai.com" in str(openai.base_url):
Expand All @@ -71,7 +75,9 @@ def get_provider() -> Provider | None:
return "groq"
if "x.ai" in str(openai.base_url):
return "xai"
return None
if "deepseek.com" in str(openai.base_url):
return "deepseek"
return "local"


def get_client() -> "OpenAI | None":
Expand Down
4 changes: 3 additions & 1 deletion gptme/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ class _ModelDictMeta(TypedDict):


# available providers
Provider = Literal["openai", "anthropic", "azure", "openrouter", "groq", "xai", "local"]
Provider = Literal[
"openai", "anthropic", "azure", "openrouter", "groq", "xai", "deepseek", "local"
]
PROVIDERS: list[Provider] = cast(list[Provider], get_args(Provider))
PROVIDERS_OPENAI: list[Provider]
PROVIDERS_OPENAI = ["openai", "azure", "openrouter", "xai", "groq", "local"]
Expand Down

0 comments on commit 2488404

Please sign in to comment.