feat: added support for groq provider
ErikBjare committed Oct 25, 2024
1 parent 2d8b602 commit 4299cd0
Showing 4 changed files with 36 additions and 5 deletions.
18 changes: 18 additions & 0 deletions docs/providers.rst
@@ -41,6 +41,24 @@ To use OpenRouter, set your API key:
    export OPENROUTER_API_KEY="your-api-key"

Groq
----

To use Groq, set your API key:

.. code-block:: sh

    export GROQ_API_KEY="your-api-key"

xAI
---

To use xAI, set your API key:

.. code-block:: sh

    export XAI_API_KEY="your-api-key"

Local
-----

11 changes: 9 additions & 2 deletions gptme/llm.py
@@ -3,6 +3,7 @@
import sys
from collections.abc import Iterator
from functools import lru_cache
from typing import cast

from rich import print

@@ -17,7 +18,12 @@
from .llm_openai import init as init_openai
from .llm_openai import stream as stream_openai
from .message import Message, format_msgs, len_tokens
from .models import MODELS, Provider, get_summary_model
from .models import (
    MODELS,
    PROVIDERS_OPENAI,
    Provider,
    get_summary_model,
)
from .tools import ToolUse

logger = logging.getLogger(__name__)
@@ -27,7 +33,8 @@ def init_llm(llm: str):
    # set up API_KEY (if openai) and API_BASE (if local)
    config = get_config()

    if llm in ["openai", "azure", "openrouter", "local", "xai"]:
    llm = cast(Provider, llm)
    if llm in PROVIDERS_OPENAI:
        init_openai(llm, config)
        assert get_openai_client()
    elif llm == "anthropic":
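The cast(Provider, llm) call above is purely a static-typing hint; at runtime the string passes through unchanged, and the membership test against PROVIDERS_OPENAI is what actually routes every OpenAI-compatible provider, including the new "groq", through init_openai. A minimal standalone sketch of that dispatch pattern (simplified stand-in functions, not gptme's actual code):

from typing import Literal, cast

Provider = Literal["openai", "anthropic", "azure", "openrouter", "groq", "xai", "local"]
PROVIDERS_OPENAI: list[Provider] = ["openai", "azure", "openrouter", "xai", "groq", "local"]

def init_llm_sketch(llm: str) -> str:
    provider = cast(Provider, llm)  # no runtime check, only narrows the type for the checker
    if provider in PROVIDERS_OPENAI:
        # all OpenAI-compatible providers share one init path
        return f"init OpenAI-compatible client for {provider}"
    elif provider == "anthropic":
        return "init Anthropic client"
    raise ValueError(f"unsupported provider: {llm}")

print(init_llm_sketch("groq"))  # -> init OpenAI-compatible client for groq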
6 changes: 5 additions & 1 deletion gptme/llm_openai.py
@@ -5,6 +5,7 @@
from .config import Config
from .constants import TEMPERATURE, TOP_P
from .message import Message, msgs2dicts
from .models import Provider

if TYPE_CHECKING:
from openai import OpenAI
@@ -21,7 +22,7 @@
}


def init(provider: str, config: Config):
def init(provider: Provider, config: Config):
global openai
from openai import AzureOpenAI, OpenAI # fmt: skip

@@ -42,6 +43,9 @@ def init(provider: str, config: Config):
    elif provider == "xai":
        api_key = config.get_env_required("XAI_API_KEY")
        openai = OpenAI(api_key=api_key, base_url="https://api.x.ai/v1")
    elif provider == "groq":
        api_key = config.get_env_required("GROQ_API_KEY")
        openai = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1")
    elif provider == "local":
        # OPENAI_API_BASE renamed to OPENAI_BASE_URL: https://github.com/openai/openai-python/issues/745
        api_base = config.get_env("OPENAI_API_BASE")
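The new groq branch works because Groq exposes an OpenAI-compatible API, so the stock openai client is reused with only the base URL and API key swapped, exactly like the xai branch above it. A minimal standalone sketch of that pattern (the model name is an illustrative assumption, not taken from this commit):

import os

from openai import OpenAI

# Point the standard OpenAI client at Groq's OpenAI-compatible endpoint
# and authenticate with the GROQ_API_KEY environment variable.
client = OpenAI(
    api_key=os.environ["GROQ_API_KEY"],
    base_url="https://api.groq.com/openai/v1",
)

response = client.chat.completions.create(
    model="llama-3.1-8b-instant",  # illustrative; use any model your Groq account serves
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
)
print(response.choices[0].message.content)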
6 changes: 4 additions & 2 deletions gptme/models.py
@@ -37,8 +37,10 @@ class _ModelDictMeta(TypedDict):


# available providers
Provider = Literal["openai", "anthropic", "azure", "openrouter", "xai", "local"]
PROVIDERS = get_args(Provider)
Provider = Literal["openai", "anthropic", "azure", "openrouter", "groq", "xai", "local"]
PROVIDERS: list[Provider] = cast(list[Provider], get_args(Provider))
PROVIDERS_OPENAI: list[Provider]
PROVIDERS_OPENAI = ["openai", "azure", "openrouter", "xai", "groq", "local"]

# default model
DEFAULT_MODEL: ModelMeta | None = None
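The models.py change keeps the runtime provider list derived from the Provider Literal via get_args, so adding "groq" to the Literal is enough for it to appear in PROVIDERS; PROVIDERS_OPENAI is then the subset routed through the OpenAI-compatible client. A minimal standalone sketch of the Literal/get_args pattern:

from typing import Literal, cast, get_args

Provider = Literal["openai", "anthropic", "azure", "openrouter", "groq", "xai", "local"]

# get_args() returns the Literal's values as a tuple at runtime, so this
# list can never drift out of sync with the type annotation above.
PROVIDERS: list[Provider] = cast(list[Provider], list(get_args(Provider)))

assert "groq" in PROVIDERS
print(PROVIDERS)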
