Skip to content

Commit

Permalink
add deepseek provider
Browse files Browse the repository at this point in the history
  • Loading branch information
djcopley committed Jan 26, 2025
1 parent 52e11b3 commit a7d6dd6
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/shelloracle/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,13 @@ def __get__(self, instance: Provider, owner: type[Provider]) -> T:


def _providers() -> dict[str, type[Provider]]:
from shelloracle.providers.deepseek import Deepseek
from shelloracle.providers.localai import LocalAI
from shelloracle.providers.ollama import Ollama
from shelloracle.providers.openai import OpenAI
from shelloracle.providers.xai import XAI

return {Ollama.name: Ollama, OpenAI.name: OpenAI, LocalAI.name: LocalAI, XAI.name: XAI}
return {Ollama.name: Ollama, OpenAI.name: OpenAI, LocalAI.name: LocalAI, XAI.name: XAI, Deepseek.name: Deepseek}


def get_provider(name: str) -> type[Provider]:
Expand Down
32 changes: 32 additions & 0 deletions src/shelloracle/providers/deepseek.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from collections.abc import AsyncIterator

from openai import APIError, AsyncOpenAI

from shelloracle.providers import Provider, ProviderError, Setting, system_prompt


class Deepseek(Provider):
name = "Deepseek"

api_key = Setting(default="")
model = Setting(default="deepseek-chat")

def __init__(self):
if not self.api_key:
msg = "No API key provided"
raise ProviderError(msg)
self.client = AsyncOpenAI(base_url="https://api.deepseek.com/v1", api_key=self.api_key)

async def generate(self, prompt: str) -> AsyncIterator[str]:
try:
stream = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
stream=True,
)
async for chunk in stream:
if chunk.choices[0].delta.content is not None:
yield chunk.choices[0].delta.content
except APIError as e:
msg = f"Something went wrong while querying Deepseek: {e}"
raise ProviderError(msg) from e
41 changes: 41 additions & 0 deletions tests/providers/test_deepseek.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pytest

from shelloracle.providers.deepseek import Deepseek


class TestOpenAI:
@pytest.fixture
def deepseek_config(self, set_config):
config = {
"shelloracle": {"provider": "Deepseek"},
"provider": {
"Deepseek": {
"api_key": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"model": "grok-beta",
}
},
}
set_config(config)

@pytest.fixture
def deepseek_instance(self, deepseek_config):
return Deepseek()

def test_name(self):
assert Deepseek.name == "Deepseek"

def test_api_key(self, deepseek_instance):
assert (
deepseek_instance.api_key
== "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
)

def test_model(self, deepseek_instance):
assert deepseek_instance.model == "grok-beta"

@pytest.mark.asyncio
async def test_generate(self, mock_asyncopenai, deepseek_instance):
result = ""
async for response in deepseek_instance.generate(""):
result += response
assert result == "head -c 100 /dev/urandom | hexdump -C"

0 comments on commit a7d6dd6

Please sign in to comment.