refactor: refactored provider-specific code into new files llm_openai.py and llm_anthropic.py
ErikBjare committed Aug 14, 2024
1 parent 580cc36 commit eec8215
Showing 4 changed files with 212 additions and 158 deletions.
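
In short: the provider-specific client handling moves out of gptme/llm.py into per-provider modules that share a small interface — init(config), get_client(), chat(messages, model), and stream(messages, model) — which llm.py now dispatches to. A minimal usage sketch of that interface, based on the imports and signatures visible in the diffs below (the model name and message contents are illustrative, not part of this commit):

# Sketch only: drives the new llm_anthropic module directly.
# Model string and message contents are made up for illustration.
from gptme.config import get_config
from gptme.llm_anthropic import chat, get_client, init
from gptme.message import Message

init(get_config())           # reads ANTHROPIC_API_KEY from config/env
assert get_client()          # client is stored as a module-level global

msgs = [
    Message("system", "You are a helpful assistant."),  # first message must be a system message
    Message("user", "Say hello."),
]
print(chat(msgs, model="claude-3-5-sonnet-20240620"))    # blocking completion; stream() yields text deltas
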
10 changes: 10 additions & 0 deletions gptme/constants.py
@@ -1,3 +1,13 @@
"""
Constants
"""

# Optimized for code
# Discussion here: https://community.openai.com/t/cheat-sheet-mastering-temperature-and-top-p-in-chatgpt-api-a-few-tips-and-tricks-on-controlling-the-creativity-deterministic-output-of-prompt-responses/172683
# TODO: make these configurable
TEMPERATURE = 0
TOP_P = 0.1

# prefix for commands, e.g. /help
CMDFIX = "/"

196 changes: 38 additions & 158 deletions gptme/llm.py
@@ -1,67 +1,45 @@
import logging
import shutil
import sys
from collections.abc import Generator, Iterator
from collections.abc import Iterator
from typing import Literal

from anthropic import Anthropic
from openai import AzureOpenAI, OpenAI
from rich import print

from .llm_anthropic import chat as chat_anthropic
from .llm_anthropic import get_client as get_anthropic_client
from .llm_anthropic import init as init_anthropic
from .llm_anthropic import stream as stream_anthropic
from .llm_openai import chat as chat_openai
from .llm_openai import get_client as get_openai_client
from .llm_openai import init as init_openai
from .llm_openai import stream as stream_openai
from .config import get_config
from .constants import PROMPT_ASSISTANT
from .message import Message, len_tokens, msgs2dicts
from .message import Message, len_tokens
from .models import MODELS, get_summary_model
from .util import extract_codeblocks

# Optimized for code
# Discussion here: https://community.openai.com/t/cheat-sheet-mastering-temperature-and-top-p-in-chatgpt-api-a-few-tips-and-tricks-on-controlling-the-creativity-deterministic-output-of-prompt-responses/172683
# TODO: make these configurable
temperature = 0
top_p = 0.1

logger = logging.getLogger(__name__)

oai_client: OpenAI | None = None
anthropic_client: Anthropic | None = None

Provider = Literal["openai", "anthropic", "azure", "openrouter", "local"]

def init_llm(llm: str):
    global oai_client, anthropic_client

def init_llm(llm: str):
    # set up API_KEY (if openai) and API_BASE (if local)
    config = get_config()

    if llm == "openai":
        api_key = config.get_env_required("OPENAI_API_KEY")
        oai_client = OpenAI(api_key=api_key)
    elif llm == "azure":
        api_key = config.get_env_required("AZURE_OPENAI_API_KEY")
        azure_endpoint = config.get_env_required("AZURE_OPENAI_ENDPOINT")
        oai_client = AzureOpenAI(
            api_key=api_key,
            api_version="2023-07-01-preview",
            azure_endpoint=azure_endpoint,
        )
    if llm in ["openai", "azure", "openrouter", "local"]:
        init_openai(llm, config)
        assert get_openai_client()
    elif llm == "anthropic":
        api_key = config.get_env_required("ANTHROPIC_API_KEY")
        anthropic_client = Anthropic(
            api_key=api_key,
            max_retries=5,
        )
    elif llm == "openrouter":
        api_key = config.get_env_required("OPENROUTER_API_KEY")
        oai_client = OpenAI(api_key=api_key, base_url="https://openrouter.ai/api/v1")
    elif llm == "local":
        api_base = config.get_env_required("OPENAI_API_BASE")
        api_key = config.get_env("OPENAI_API_KEY") or "ollama"
        oai_client = OpenAI(api_key=api_key, base_url=api_base)
        init_anthropic(config)
        assert get_anthropic_client()
    else:
        print(f"Error: Unknown LLM: {llm}")
        sys.exit(1)

    # ensure we have initialized the client
    assert oai_client or anthropic_client


def reply(messages: list[Message], model: str, stream: bool = False) -> Message:
    if stream:
@@ -74,128 +52,26 @@ def reply(messages: list[Message], model: str, stream: bool = False) -> Message:
return Message("assistant", response)


def _chat_complete_openai(messages: list[Message], model: str) -> str:
    # This will generate code and such, so we need appropriate temperature and top_p params
    # top_p controls diversity, temperature controls randomness
    assert oai_client, "LLM not initialized"
    response = oai_client.chat.completions.create(
        model=model,
        messages=msgs2dicts(messages, openai=True),  # type: ignore
        temperature=temperature,
        top_p=top_p,
    )
    content = response.choices[0].message.content
    assert content
    return content


def _chat_complete_anthropic(messages: list[Message], model: str) -> str:
    assert anthropic_client, "LLM not initialized"
    messages, system_message = _transform_system_messages_anthropic(messages)
    response = anthropic_client.messages.create(
        model=model,
        messages=msgs2dicts(messages, anthropic=True),  # type: ignore
        system=system_message,
        temperature=temperature,
        top_p=top_p,
        max_tokens=4096,
    )
    content = response.content
    assert content
    assert len(content) == 1
    return content[0].text  # type: ignore


def _chat_complete(messages: list[Message], model: str) -> str:
    if oai_client:
        return _chat_complete_openai(messages, model)
    elif anthropic_client:
        return _chat_complete_anthropic(messages, model)
    provider = _client_to_provider()
    if provider == "openai":
        return chat_openai(messages, model)
    elif provider == "anthropic":
        return chat_anthropic(messages, model)
    else:
        raise ValueError("LLM not initialized")


def _transform_system_messages_anthropic(
    messages: list[Message],
) -> tuple[list[Message], str]:
    # transform system messages into system kwarg for anthropic
    # for first system message, transform it into a system kwarg
    assert messages[0].role == "system"
    system_prompt = messages[0].content
    messages.pop(0)

    # for any subsequent system messages, transform them into a <system> message
    for i, message in enumerate(messages):
        if message.role == "system":
            messages[i] = Message(
                "user",
                content=f"<system>{message.content}</system>",
            )

    # find consecutive user role messages and merge them into a single <system> message
    messages_new: list[Message] = []
    while messages:
        message = messages.pop(0)
        if messages_new and messages_new[-1].role == "user":
            messages_new[-1] = Message(
                "user",
                content=f"{messages_new[-1].content}\n{message.content}",
            )
        else:
            messages_new.append(message)
    messages = messages_new

    return messages, system_prompt


def _stream(messages: list[Message], model: str) -> Iterator[str]:
    if oai_client:
        return _stream_openai(messages, model)
    elif anthropic_client:
        return _stream_anthropic(messages, model)
    provider = _client_to_provider()
    if provider == "openai":
        return stream_openai(messages, model)
    elif provider == "anthropic":
        return stream_anthropic(messages, model)
    else:
        raise ValueError("LLM not initialized")


def _stream_openai(messages: list[Message], model: str) -> Generator[str, None, None]:
    assert oai_client, "LLM not initialized"
    stop_reason = None
    for chunk in oai_client.chat.completions.create(
        model=model,
        messages=msgs2dicts(messages, openai=True),  # type: ignore
        temperature=temperature,
        top_p=top_p,
        stream=True,
        # the llama-cpp-python server needs this explicitly set, otherwise unreliable results
        # TODO: make this better
        max_tokens=1000 if not model.startswith("gpt-") else 4096,
    ):
        if not chunk.choices:  # type: ignore
            # Got a chunk with no choices, Azure always sends one of these at the start
            continue
        stop_reason = chunk.choices[0].finish_reason  # type: ignore
        content = chunk.choices[0].delta.content  # type: ignore
        if content:
            yield content
    logger.debug(f"Stop reason: {stop_reason}")


def _stream_anthropic(
    messages: list[Message], model: str
) -> Generator[str, None, None]:
    messages, system_prompt = _transform_system_messages_anthropic(messages)
    assert anthropic_client, "LLM not initialized"
    with anthropic_client.messages.stream(
        model=model,
        messages=msgs2dicts(messages, anthropic=True),  # type: ignore
        system=system_prompt,
        temperature=temperature,
        top_p=top_p,
        max_tokens=4096,
    ) as stream:
        yield from stream.text_stream


def _reply_stream(messages: list[Message], model: str) -> Message:
    print(f"{PROMPT_ASSISTANT}: Thinking...", end="\r")

@@ -236,11 +112,14 @@ def print_clear():
return Message("assistant", output)


def _client_to_provider() -> str:
    if oai_client:
        if "openai" in oai_client.base_url.host:
def _client_to_provider() -> Provider:
    openai_client = get_openai_client()
    anthropic_client = get_anthropic_client()
    assert openai_client or anthropic_client, "No client initialized"
    if openai_client:
        if "openai" in openai_client.base_url.host:
            return "openai"
        elif "openrouter" in oai_client.base_url.host:
        elif "openrouter" in openai_client.base_url.host:
            return "openrouter"
        else:
            return "azure"
@@ -265,8 +144,9 @@ def summarize(content: str) -> str:
Message("user", content=f"Summarize this:\n{content}"),
]

model = get_summary_model(_client_to_provider())
context_limit = MODELS["openai" if oai_client else "anthropic"][model]["context"]
provider = _client_to_provider()
model = get_summary_model(provider)
context_limit = MODELS[provider][model]["context"]
if len_tokens(messages) > context_limit:
raise ValueError(
f"Cannot summarize more than {context_limit} tokens, got {len_tokens(messages)}"
85 changes: 85 additions & 0 deletions gptme/llm_anthropic.py
@@ -0,0 +1,85 @@
from collections.abc import Generator

from anthropic import Anthropic

from .constants import TEMPERATURE, TOP_P
from .message import Message, msgs2dicts

anthropic: Anthropic | None = None


def init(config):
    global anthropic
    api_key = config.get_env_required("ANTHROPIC_API_KEY")
    anthropic = Anthropic(
        api_key=api_key,
        max_retries=5,
    )


def get_client() -> Anthropic | None:
    return anthropic


def chat(messages: list[Message], model: str) -> str:
    assert anthropic, "LLM not initialized"
    messages, system_messages = _transform_system_messages(messages)
    response = anthropic.messages.create(
        model=model,
        messages=msgs2dicts(messages, anthropic=True),  # type: ignore
        system=system_messages,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        max_tokens=4096,
    )
    content = response.content
    assert content
    assert len(content) == 1
    return content[0].text  # type: ignore


def stream(messages: list[Message], model: str) -> Generator[str, None, None]:
    messages, system_messages = _transform_system_messages(messages)
    assert anthropic, "LLM not initialized"
    with anthropic.messages.stream(
        model=model,
        messages=msgs2dicts(messages, anthropic=True),  # type: ignore
        system=system_messages,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        max_tokens=4096,
    ) as stream:
        yield from stream.text_stream


def _transform_system_messages(
    messages: list[Message],
) -> tuple[list[Message], str]:
    # transform system messages into system kwarg for anthropic
    # for first system message, transform it into a system kwarg
    assert messages[0].role == "system"
    system_prompt = messages[0].content
    messages.pop(0)

    # for any subsequent system messages, transform them into a <system> message
    for i, message in enumerate(messages):
        if message.role == "system":
            messages[i] = Message(
                "user",
                content=f"<system>{message.content}</system>",
            )

    # find consecutive user role messages and merge them into a single <system> message
    messages_new: list[Message] = []
    while messages:
        message = messages.pop(0)
        if messages_new and messages_new[-1].role == "user":
            messages_new[-1] = Message(
                "user",
                content=f"{messages_new[-1].content}\n{message.content}",
            )
        else:
            messages_new.append(message)
    messages = messages_new

    return messages, system_prompt
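
For reference, a small worked example of what _transform_system_messages does with a hypothetical conversation (the message contents are made up): the leading system message becomes the system kwarg, any later system message is wrapped in <system> tags as a user message, and consecutive user messages are merged.

# Hypothetical input; not part of this commit.
msgs = [
    Message("system", "You are a helpful assistant."),
    Message("user", "Hello"),
    Message("system", "Be concise."),
]
msgs, system = _transform_system_messages(msgs)
assert system == "You are a helpful assistant."
# msgs is now a single merged user message:
# [Message("user", "Hello\n<system>Be concise.</system>")]
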
(diff for the new gptme/llm_openai.py not loaded in this view)
