diff --git a/gptme/constants.py b/gptme/constants.py
index df3d58ec..6c569573 100644
--- a/gptme/constants.py
+++ b/gptme/constants.py
@@ -1,3 +1,13 @@
+"""
+Constants
+"""
+
+# Optimized for code
+# Discussion here: https://community.openai.com/t/cheat-sheet-mastering-temperature-and-top-p-in-chatgpt-api-a-few-tips-and-tricks-on-controlling-the-creativity-deterministic-output-of-prompt-responses/172683
+# TODO: make these configurable
+TEMPERATURE = 0
+TOP_P = 0.1
+
 # prefix for commands, e.g. /help
 CMDFIX = "/"
 
diff --git a/gptme/llm.py b/gptme/llm.py
index 77bda650..c461f8d4 100644
--- a/gptme/llm.py
+++ b/gptme/llm.py
@@ -1,67 +1,45 @@
 import logging
 import shutil
 import sys
-from collections.abc import Generator, Iterator
+from collections.abc import Iterator
+from typing import Literal
 
-from anthropic import Anthropic
-from openai import AzureOpenAI, OpenAI
 from rich import print
 
+from .llm_anthropic import chat as chat_anthropic
+from .llm_anthropic import get_client as get_anthropic_client
+from .llm_anthropic import init as init_anthropic
+from .llm_anthropic import stream as stream_anthropic
+from .llm_openai import chat as chat_openai
+from .llm_openai import get_client as get_openai_client
+from .llm_openai import init as init_openai
+from .llm_openai import stream as stream_openai
+
 from .config import get_config
 from .constants import PROMPT_ASSISTANT
-from .message import Message, len_tokens, msgs2dicts
+from .message import Message, len_tokens
 from .models import MODELS, get_summary_model
 from .util import extract_codeblocks
 
-# Optimized for code
-# Discussion here: https://community.openai.com/t/cheat-sheet-mastering-temperature-and-top-p-in-chatgpt-api-a-few-tips-and-tricks-on-controlling-the-creativity-deterministic-output-of-prompt-responses/172683
-# TODO: make these configurable
-temperature = 0
-top_p = 0.1
-
 logger = logging.getLogger(__name__)
 
-oai_client: OpenAI | None = None
-anthropic_client: Anthropic | None = None
+Provider = Literal["openai", "anthropic", "azure", "openrouter", "local"]
 
 
-def init_llm(llm: str):
-    global oai_client, anthropic_client
-
+def init_llm(llm: str):
     # set up API_KEY (if openai) and API_BASE (if local)
     config = get_config()
 
-    if llm == "openai":
-        api_key = config.get_env_required("OPENAI_API_KEY")
-        oai_client = OpenAI(api_key=api_key)
-    elif llm == "azure":
-        api_key = config.get_env_required("AZURE_OPENAI_API_KEY")
-        azure_endpoint = config.get_env_required("AZURE_OPENAI_ENDPOINT")
-        oai_client = AzureOpenAI(
-            api_key=api_key,
-            api_version="2023-07-01-preview",
-            azure_endpoint=azure_endpoint,
-        )
+    if llm in ["openai", "azure", "openrouter", "local"]:
+        init_openai(llm, config)
+        assert get_openai_client()
     elif llm == "anthropic":
-        api_key = config.get_env_required("ANTHROPIC_API_KEY")
-        anthropic_client = Anthropic(
-            api_key=api_key,
-            max_retries=5,
-        )
-    elif llm == "openrouter":
-        api_key = config.get_env_required("OPENROUTER_API_KEY")
-        oai_client = OpenAI(api_key=api_key, base_url="https://openrouter.ai/api/v1")
-    elif llm == "local":
-        api_base = config.get_env_required("OPENAI_API_BASE")
-        api_key = config.get_env("OPENAI_API_KEY") or "ollama"
-        oai_client = OpenAI(api_key=api_key, base_url=api_base)
+        init_anthropic(config)
+        assert get_anthropic_client()
     else:
         print(f"Error: Unknown LLM: {llm}")
         sys.exit(1)
 
-    # ensure we have initialized the client
-    assert oai_client or anthropic_client
-
 
 def reply(messages: list[Message], model: str, stream: bool = False) -> Message:
     if stream:
@@ -74,128 +52,26 @@ def reply(messages: list[Message], model: str, stream: bool = False) -> Message:
     return Message("assistant", response)
 
 
-def _chat_complete_openai(messages: list[Message], model: str) -> str:
-    # This will generate code and such, so we need appropriate temperature and top_p params
-    # top_p controls diversity, temperature controls randomness
-    assert oai_client, "LLM not initialized"
-    response = oai_client.chat.completions.create(
-        model=model,
-        messages=msgs2dicts(messages, openai=True),  # type: ignore
-        temperature=temperature,
-        top_p=top_p,
-    )
-    content = response.choices[0].message.content
-    assert content
-    return content
-
-
-def _chat_complete_anthropic(messages: list[Message], model: str) -> str:
-    assert anthropic_client, "LLM not initialized"
-    messages, system_message = _transform_system_messages_anthropic(messages)
-    response = anthropic_client.messages.create(
-        model=model,
-        messages=msgs2dicts(messages, anthropic=True),  # type: ignore
-        system=system_message,
-        temperature=temperature,
-        top_p=top_p,
-        max_tokens=4096,
-    )
-    content = response.content
-    assert content
-    assert len(content) == 1
-    return content[0].text  # type: ignore
-
-
 def _chat_complete(messages: list[Message], model: str) -> str:
-    if oai_client:
-        return _chat_complete_openai(messages, model)
-    elif anthropic_client:
-        return _chat_complete_anthropic(messages, model)
+    provider = _client_to_provider()
+    if provider == "openai":
+        return chat_openai(messages, model)
+    elif provider == "anthropic":
+        return chat_anthropic(messages, model)
     else:
         raise ValueError("LLM not initialized")
 
 
-def _transform_system_messages_anthropic(
-    messages: list[Message],
-) -> tuple[list[Message], str]:
-    # transform system messages into system kwarg for anthropic
-    # for first system message, transform it into a system kwarg
-    assert messages[0].role == "system"
-    system_prompt = messages[0].content
-    messages.pop(0)
-
-    # for any subsequent system messages, transform them into a message
-    for i, message in enumerate(messages):
-        if message.role == "system":
-            messages[i] = Message(
-                "user",
-                content=f"{message.content}",
-            )
-
-    # find consecutive user role messages and merge them into a single message
-    messages_new: list[Message] = []
-    while messages:
-        message = messages.pop(0)
-        if messages_new and messages_new[-1].role == "user":
-            messages_new[-1] = Message(
-                "user",
-                content=f"{messages_new[-1].content}\n{message.content}",
-            )
-        else:
-            messages_new.append(message)
-    messages = messages_new
-
-    return messages, system_prompt
-
-
 def _stream(messages: list[Message], model: str) -> Iterator[str]:
-    if oai_client:
-        return _stream_openai(messages, model)
-    elif anthropic_client:
-        return _stream_anthropic(messages, model)
+    provider = _client_to_provider()
+    if provider == "openai":
+        return stream_openai(messages, model)
+    elif provider == "anthropic":
+        return stream_anthropic(messages, model)
     else:
         raise ValueError("LLM not initialized")
 
 
-def _stream_openai(messages: list[Message], model: str) -> Generator[str, None, None]:
-    assert oai_client, "LLM not initialized"
-    stop_reason = None
-    for chunk in oai_client.chat.completions.create(
-        model=model,
-        messages=msgs2dicts(messages, openai=True),  # type: ignore
-        temperature=temperature,
-        top_p=top_p,
-        stream=True,
-        # the llama-cpp-python server needs this explicitly set, otherwise unreliable results
-        # TODO: make this better
-        max_tokens=1000 if not model.startswith("gpt-") else 4096,
-    ):
-        if not chunk.choices:  # type: ignore
-            # Got a chunk with no choices, Azure always sends one of these at the start
-            continue
-        stop_reason = chunk.choices[0].finish_reason  # type: ignore
-        content = chunk.choices[0].delta.content  # type: ignore
-        if content:
-            yield content
-    logger.debug(f"Stop reason: {stop_reason}")
-
-
-def _stream_anthropic(
-    messages: list[Message], model: str
-) -> Generator[str, None, None]:
-    messages, system_prompt = _transform_system_messages_anthropic(messages)
-    assert anthropic_client, "LLM not initialized"
-    with anthropic_client.messages.stream(
-        model=model,
-        messages=msgs2dicts(messages, anthropic=True),  # type: ignore
-        system=system_prompt,
-        temperature=temperature,
-        top_p=top_p,
-        max_tokens=4096,
-    ) as stream:
-        yield from stream.text_stream
-
-
 def _reply_stream(messages: list[Message], model: str) -> Message:
     print(f"{PROMPT_ASSISTANT}: Thinking...", end="\r")
 
@@ -236,11 +112,14 @@ def print_clear():
     return Message("assistant", output)
 
 
-def _client_to_provider() -> str:
-    if oai_client:
-        if "openai" in oai_client.base_url.host:
+def _client_to_provider() -> Provider:
+    openai_client = get_openai_client()
+    anthropic_client = get_anthropic_client()
+    assert openai_client or anthropic_client, "No client initialized"
+    if openai_client:
+        if "openai" in openai_client.base_url.host:
             return "openai"
-        elif "openrouter" in oai_client.base_url.host:
+        elif "openrouter" in openai_client.base_url.host:
             return "openrouter"
         else:
             return "azure"
@@ -265,8 +144,9 @@ def summarize(content: str) -> str:
         Message("user", content=f"Summarize this:\n{content}"),
     ]
 
-    model = get_summary_model(_client_to_provider())
-    context_limit = MODELS["openai" if oai_client else "anthropic"][model]["context"]
+    provider = _client_to_provider()
+    model = get_summary_model(provider)
+    context_limit = MODELS[provider][model]["context"]
     if len_tokens(messages) > context_limit:
         raise ValueError(
             f"Cannot summarize more than {context_limit} tokens, got {len_tokens(messages)}"
diff --git a/gptme/llm_anthropic.py b/gptme/llm_anthropic.py
new file mode 100644
index 00000000..c65f1044
--- /dev/null
+++ b/gptme/llm_anthropic.py
@@ -0,0 +1,85 @@
+from collections.abc import Generator
+
+from anthropic import Anthropic
+
+from .constants import TEMPERATURE, TOP_P
+from .message import Message, msgs2dicts
+
+anthropic: Anthropic | None = None
+
+
+def init(config):
+    global anthropic
+    api_key = config.get_env_required("ANTHROPIC_API_KEY")
+    anthropic = Anthropic(
+        api_key=api_key,
+        max_retries=5,
+    )
+
+
+def get_client() -> Anthropic | None:
+    return anthropic
+
+
+def chat(messages: list[Message], model: str) -> str:
+    assert anthropic, "LLM not initialized"
+    messages, system_messages = _transform_system_messages(messages)
+    response = anthropic.messages.create(
+        model=model,
+        messages=msgs2dicts(messages, anthropic=True),  # type: ignore
+        system=system_messages,
+        temperature=TEMPERATURE,
+        top_p=TOP_P,
+        max_tokens=4096,
+    )
+    content = response.content
+    assert content
+    assert len(content) == 1
+    return content[0].text  # type: ignore
+
+
+def stream(messages: list[Message], model: str) -> Generator[str, None, None]:
+    messages, system_messages = _transform_system_messages(messages)
+    assert anthropic, "LLM not initialized"
+    with anthropic.messages.stream(
+        model=model,
+        messages=msgs2dicts(messages, anthropic=True),  # type: ignore
+        system=system_messages,
+        temperature=TEMPERATURE,
+        top_p=TOP_P,
+        max_tokens=4096,
+    ) as stream:
+        yield from stream.text_stream
+
+
+def _transform_system_messages(
+    messages: list[Message],
+) -> tuple[list[Message], str]:
+    # transform system messages into system kwarg for anthropic
+    # for first system message, transform it into a system kwarg
+    assert messages[0].role == "system"
+    system_prompt = messages[0].content
+    messages.pop(0)
+
+    # for any subsequent system messages, transform them into a message
+    for i, message in enumerate(messages):
+        if message.role == "system":
+            messages[i] = Message(
+                "user",
+                content=f"{message.content}",
+            )
+
+    # find consecutive user role messages and merge them into a single message
+    messages_new: list[Message] = []
+    while messages:
+        message = messages.pop(0)
+        if messages_new and messages_new[-1].role == "user":
+            messages_new[-1] = Message(
+                "user",
+                content=f"{messages_new[-1].content}\n{message.content}",
+            )
+        else:
+            messages_new.append(message)
+    messages = messages_new
+
+    return messages, system_prompt
diff --git a/gptme/llm_openai.py b/gptme/llm_openai.py
new file mode 100644
index 00000000..a9a0ab0e
--- /dev/null
+++ b/gptme/llm_openai.py
@@ -0,0 +1,79 @@
+import logging
+from collections.abc import Generator
+
+from openai import AzureOpenAI, OpenAI
+
+from .constants import TEMPERATURE, TOP_P
+from .message import Message, msgs2dicts
+
+openai: OpenAI | None = None
+logger = logging.getLogger(__name__)
+
+
+def init(llm: str, config):
+    global openai
+
+    if llm == "openai":
+        api_key = config.get_env_required("OPENAI_API_KEY")
+        openai = OpenAI(api_key=api_key)
+    elif llm == "azure":
+        api_key = config.get_env_required("AZURE_OPENAI_API_KEY")
+        azure_endpoint = config.get_env_required("AZURE_OPENAI_ENDPOINT")
+        openai = AzureOpenAI(
+            api_key=api_key,
+            api_version="2023-07-01-preview",
+            azure_endpoint=azure_endpoint,
+        )
+    elif llm == "openrouter":
+        api_key = config.get_env_required("OPENROUTER_API_KEY")
+        openai = OpenAI(api_key=api_key, base_url="https://openrouter.ai/api/v1")
+    elif llm == "local":
+        api_base = config.get_env_required("OPENAI_API_BASE")
+        api_key = config.get_env("OPENAI_API_KEY") or "ollama"
+        openai = OpenAI(api_key=api_key, base_url=api_base)
+    else:
+        raise ValueError(f"Unknown LLM: {llm}")
+
+    assert openai, "LLM not initialized"
+
+
+def get_client() -> OpenAI | None:
+    return openai
+
+
+def chat(messages: list[Message], model: str) -> str:
+    # This will generate code and such, so we need appropriate temperature and top_p params
+    # top_p controls diversity, temperature controls randomness
+    assert openai, "LLM not initialized"
+    response = openai.chat.completions.create(
+        model=model,
+        messages=msgs2dicts(messages, openai=True),  # type: ignore
+        temperature=TEMPERATURE,
+        top_p=TOP_P,
+    )
+    content = response.choices[0].message.content
+    assert content
+    return content
+
+
+def stream(messages: list[Message], model: str) -> Generator[str, None, None]:
+    assert openai, "LLM not initialized"
+    stop_reason = None
+    for chunk in openai.chat.completions.create(
+        model=model,
+        messages=msgs2dicts(messages, openai=True),  # type: ignore
+        temperature=TEMPERATURE,
+        top_p=TOP_P,
+        stream=True,
+        # the llama-cpp-python server needs this explicitly set, otherwise unreliable results
+        # TODO: make this better
+        max_tokens=1000 if not model.startswith("gpt-") else 4096,
+    ):
+        if not chunk.choices:  # type: ignore
+            # Got a chunk with no choices, Azure always sends one of these at the start
+            continue
+        stop_reason = chunk.choices[0].finish_reason  # type: ignore
+        content = chunk.choices[0].delta.content  # type: ignore
+        if content:
+            yield content
+    logger.debug(f"Stop reason: {stop_reason}")
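
Usage sketch for reviewers (not part of the patch): a minimal example of how the refactored entry points fit together, using only functions shown in the diff above. The model name and message contents are illustrative assumptions, and API keys are still read from the gptme config/env as before.

# sketch.py -- assumes gptme is installed and OPENAI_API_KEY is configured
from gptme.llm import init_llm, reply
from gptme.message import Message

# init_llm() now dispatches to llm_openai.init() / llm_anthropic.init()
# instead of constructing the client inline
init_llm("openai")

messages = [
    Message("system", "You are a helpful coding assistant."),
    Message("user", "Write a function that reverses a string."),
]

# reply() routes to chat_openai/stream_openai or the anthropic equivalents
# via _client_to_provider(), based on which client was initialized
response = reply(messages, model="gpt-4", stream=False)
print(response.content)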