Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added deepseek support #180

Merged
merged 1 commit into the base branch from the contributor's branch on
Oct 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions gptme/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
get_summary_model,
)
from .tools import ToolUse
from .util import console

logger = logging.getLogger(__name__)

Expand All @@ -41,7 +42,7 @@ def init_llm(llm: str):
init_anthropic(config)
assert get_anthropic_client()
else:
print(f"Error: Unknown LLM: {llm}")
console.log(f"Error: Unknown LLM: {llm}")
sys.exit(1)


Expand All @@ -58,7 +59,7 @@ def reply(messages: list[Message], model: str, stream: bool = False) -> Message:

def _chat_complete(messages: list[Message], model: str) -> str:
provider = _client_to_provider()
if provider in ["openai", "azure", "openrouter"]:
if provider in PROVIDERS_OPENAI:
return chat_openai(messages, model)
elif provider == "anthropic":
return chat_anthropic(messages, model)
Expand All @@ -68,7 +69,7 @@ def _chat_complete(messages: list[Message], model: str) -> str:

def _stream(messages: list[Message], model: str) -> Iterator[str]:
provider = _client_to_provider()
if provider in ["openai", "azure", "openrouter"]:
if provider in PROVIDERS_OPENAI:
return stream_openai(messages, model)
elif provider == "anthropic":
return stream_anthropic(messages, model)
Expand Down
8 changes: 7 additions & 1 deletion gptme/llm_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def init(provider: Provider, config: Config):
elif provider == "groq":
api_key = config.get_env_required("GROQ_API_KEY")
openai = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1")
elif provider == "deepseek":
api_key = config.get_env_required("DEEPSEEK_API_KEY")
openai = OpenAI(api_key=api_key, base_url="https://api.deepseek.com/v1")
elif provider == "local":
# OPENAI_API_BASE renamed to OPENAI_BASE_URL: https://github.com/openai/openai-python/issues/745
api_base = config.get_env("OPENAI_API_BASE")
Expand All @@ -61,6 +64,7 @@ def init(provider: Provider, config: Config):


def get_provider() -> Provider | None:
# used when checking for provider-specific capabilities
if not openai:
return None
if "openai.com" in str(openai.base_url):
Expand All @@ -71,7 +75,9 @@ def get_provider() -> Provider | None:
return "groq"
if "x.ai" in str(openai.base_url):
return "xai"
return None
if "deepseek.com" in str(openai.base_url):
return "deepseek"
return "local"


def get_client() -> "OpenAI | None":
Expand Down
6 changes: 4 additions & 2 deletions gptme/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ class _ModelDictMeta(TypedDict):


# available providers
Provider = Literal["openai", "anthropic", "azure", "openrouter", "groq", "xai", "local"]
Provider = Literal[
"openai", "anthropic", "azure", "openrouter", "groq", "xai", "deepseek", "local"
]
PROVIDERS: list[Provider] = cast(list[Provider], get_args(Provider))
PROVIDERS_OPENAI: list[Provider]
PROVIDERS_OPENAI = ["openai", "azure", "openrouter", "xai", "groq", "local"]
PROVIDERS_OPENAI = ["openai", "azure", "openrouter", "xai", "groq", "deepseek", "local"]

# default model
DEFAULT_MODEL: ModelMeta | None = None
Expand Down
Loading