fix: fixed support for groq (and deepseek?) (#231)
ErikBjare authored Oct 27, 2024
1 parent 65bdb8a commit cea30cf
Showing 3 changed files with 34 additions and 17 deletions.
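In short: msgs2dicts() and Message.to_dict() previously took openai/anthropic boolean flags to decide whether to embed files in message content. This commit replaces the flags with a single provider parameter, adds a get_provider() helper that detects the active provider from the OpenAI client's base URL, and restricts file embedding to providers known to support it (openai, anthropic, openrouter). Otherwise OpenAI-compatible endpoints such as groq (and possibly deepseek) now receive plain string content instead.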
gptme/llm_anthropic.py (4 changes: 2 additions & 2 deletions)
@@ -38,7 +38,7 @@ class MessagePart(TypedDict, total=False):
 def chat(messages: list[Message], model: str) -> str:
     assert anthropic, "LLM not initialized"
     messages, system_messages = _transform_system_messages(messages)
-    messages_dicts = msgs2dicts(messages, anthropic=True)
+    messages_dicts = msgs2dicts(messages, provider="anthropic")
     response = anthropic.beta.prompt_caching.messages.create(
         model=model,
         messages=messages_dicts,  # type: ignore
@@ -56,7 +56,7 @@ def chat(messages: list[Message], model: str) -> str:
 def stream(messages: list[Message], model: str) -> Generator[str, None, None]:
     assert anthropic, "LLM not initialized"
     messages, system_messages = _transform_system_messages(messages)
-    messages_dicts = msgs2dicts(messages, anthropic=True)
+    messages_dicts = msgs2dicts(messages, provider="anthropic")
     with anthropic.beta.prompt_caching.messages.stream(
         model=model,
         messages=messages_dicts,  # type: ignore
gptme/llm_openai.py (18 changes: 16 additions & 2 deletions)
@@ -60,6 +60,20 @@ def init(provider: Provider, config: Config):
     assert openai, "Provider not initialized"
 
 
+def get_provider() -> Provider | None:
+    if not openai:
+        return None
+    if "openai.com" in str(openai.base_url):
+        return "openai"
+    if "openrouter.ai" in str(openai.base_url):
+        return "openrouter"
+    if "groq.com" in str(openai.base_url):
+        return "groq"
+    if "x.ai" in str(openai.base_url):
+        return "xai"
+    return None
+
+
 def get_client() -> "OpenAI | None":
     return openai
 
@@ -88,7 +102,7 @@ def chat(messages: list[Message], model: str) -> str:
 
     response = openai.chat.completions.create(
         model=model,
-        messages=msgs2dicts(messages, openai=True),  # type: ignore
+        messages=msgs2dicts(messages, provider=get_provider()),  # type: ignore
         temperature=TEMPERATURE if not is_o1 else NOT_GIVEN,
         top_p=TOP_P if not is_o1 else NOT_GIVEN,
         extra_headers=(
@@ -105,7 +119,7 @@ def stream(messages: list[Message], model: str) -> Generator[str, None, None]:
     stop_reason = None
     for chunk in openai.chat.completions.create(
         model=model,
-        messages=msgs2dicts(_prep_o1(messages), openai=True),  # type: ignore
+        messages=msgs2dicts(_prep_o1(messages), provider=get_provider()),  # type: ignore
         temperature=TEMPERATURE,
         top_p=TOP_P,
         stream=True,
gptme/message.py (29 changes: 16 additions & 13 deletions)
@@ -16,6 +16,7 @@
 
 from .codeblock import Codeblock
 from .constants import ROLE_COLOR
+from .models import Provider
 from .util import console, get_tokenizer, rich_to_str
 
 logger = logging.getLogger(__name__)
@@ -26,6 +27,9 @@
 max_system_len = 20000
 
 
+ProvidersWithFiles: list[Provider] = ["openai", "anthropic", "openrouter"]
+
+
 @dataclass(frozen=True, eq=False)
 class Message:
     """
@@ -75,10 +79,11 @@ def replace(self, **kwargs) -> Self:
         return dataclasses.replace(self, **kwargs)
 
     def _content_files_list(
-        self, openai: bool = False, anthropic: bool = False
+        self,
+        provider: Provider,
     ) -> list[dict[str, Any]]:
         # only these providers support files in the content
-        if not openai and not anthropic:
+        if provider not in ProvidersWithFiles:
             raise ValueError("Provider does not support files in the content")
 
         # combines a content message with a list of files
@@ -121,7 +126,7 @@ def _content_files_list(
                 )
                 continue
 
-            if anthropic:
+            if provider == "anthropic":
                 content.append(
                     {
                         "type": "image",
@@ -132,7 +137,7 @@ def _content_files_list(
                         },
                     }
                 )
-            elif openai:
+            elif provider == "openai":
                 # OpenAI format
                 content.append(
                     {
@@ -147,12 +152,13 @@ def _content_files_list(
 
         return content
 
-    def to_dict(self, keys=None, openai=False, anthropic=False) -> dict:
+    def to_dict(self, keys=None, provider: Provider | None = None) -> dict:
         """Return a dict representation of the message, serializable to JSON."""
         content: str | list[dict[str, Any]]
-        if anthropic or openai:
-            # OpenAI format or Anthropic format should include files in the content
-            content = self._content_files_list(openai=openai, anthropic=anthropic)
+        if provider in ProvidersWithFiles:
+            # OpenAI/Anthropic format should include files in the content
+            # Some otherwise OpenAI-compatible providers (groq, deepseek?) do not support this
+            content = self._content_files_list(provider)
         else:
             # storage/wire format should keep the content as a string
             content = self.content
@@ -346,12 +352,9 @@ def toml_to_msgs(toml: str) -> list[Message]:
     ]
 
 
-def msgs2dicts(msgs: list[Message], openai=False, anthropic=False) -> list[dict]:
+def msgs2dicts(msgs: list[Message], provider: Provider) -> list[dict]:
     """Convert a list of Message objects to a list of dicts ready to pass to an LLM."""
-    return [
-        msg.to_dict(keys=["role", "content"], openai=openai, anthropic=anthropic)
-        for msg in msgs
-    ]
+    return [msg.to_dict(keys=["role", "content"], provider=provider) for msg in msgs]
 
 
 # TODO: remove model assumption
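Note: why this fixes groq (and likely deepseek): only providers in ProvidersWithFiles get file parts embedded in content; any other provider, including the None returned by get_provider() for unrecognized endpoints, keeps plain string content. A minimal sketch of that dispatch follows; content_for() is a simplified illustrative stand-in for to_dict()/_content_files_list(), using the OpenAI-style "image_url" shape from the diff above.

from typing import Any

ProvidersWithFiles = ["openai", "anthropic", "openrouter"]

def content_for(provider: str | None, text: str, files: list[str]) -> str | list[dict[str, Any]]:
    if provider in ProvidersWithFiles:
        # the real code builds provider-specific parts in _content_files_list()
        return [{"type": "text", "text": text}] + [
            {"type": "image_url", "image_url": {"url": f}} for f in files
        ]
    # groq, deepseek, and unknown providers get the plain string
    return text

print(content_for("openai", "hi", ["img.png"]))  # text part + image part
print(content_for("groq", "hi", ["img.png"]))    # "hi"
print(content_for(None, "hi", ["img.png"]))      # "hi"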
