feat: basic support for openai/o1-preview and openai/o1-mini #117

Merged · 2 commits · Sep 16, 2024
4 changes: 4 additions & 0 deletions gptme/cli.py
@@ -217,6 +217,10 @@ def chat(
     # init
     init(model, interactive)
 
+    if model and model.startswith("openai/o1") and stream:
+        logger.info("Disabled streaming for OpenAI's O1 (streaming not supported)")
+        stream = False
+
     # we need to run this before checking stdin, since the interactive doesn't work with the switch back to interactive mode
     logfile = get_logfile(
         name, interactive=(not prompt_msgs and interactive) and sys.stdin.isatty()
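For reference outside the diff context, the added guard simply forces streaming off before the rest of chat() runs, since the o1 models did not support streaming at the time. A minimal standalone sketch of the same check (the resolve_stream helper name is illustrative, not part of the PR):

```python
import logging

logger = logging.getLogger(__name__)


def resolve_stream(model: str | None, stream: bool) -> bool:
    """Return the effective stream flag, forcing it off for OpenAI o1 models."""
    # gptme's CLI sees provider-prefixed model names, hence the "openai/o1" check
    if model and model.startswith("openai/o1") and stream:
        logger.info("Disabled streaming for OpenAI's O1 (streaming not supported)")
        return False
    return stream
```

Note that llm_openai.py below checks for the bare "o1" prefix instead, presumably because the provider prefix has been stripped by the time the model name reaches that module.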
27 changes: 23 additions & 4 deletions gptme/llm_openai.py
@@ -51,15 +51,32 @@ def get_client() -> "OpenAI | None":
     return openai
 
 
+def _prep_o1(msgs: list[Message]) -> Generator[Message, None, None]:
+    # prepare messages for OpenAI O1, which doesn't support the system role
+    # and requires the first message to be from the user
+    for msg in msgs:
+        if msg.role == "system":
+            msg.role = "user"
+            msg.content = f"<system>\n{msg.content}\n</system>"
+        yield msg
+
+
 def chat(messages: list[Message], model: str) -> str:
     # This will generate code and such, so we need appropriate temperature and top_p params
     # top_p controls diversity, temperature controls randomness
     assert openai, "LLM not initialized"
+    is_o1 = model.startswith("o1")
+    if is_o1:
+        messages = list(_prep_o1(messages))
+
+    # noreorder
+    from openai._types import NOT_GIVEN  # fmt: skip
+
     response = openai.chat.completions.create(
         model=model,
         messages=msgs2dicts(messages, openai=True),  # type: ignore
-        temperature=TEMPERATURE,
-        top_p=TOP_P,
+        temperature=TEMPERATURE if not is_o1 else NOT_GIVEN,
+        top_p=TOP_P if not is_o1 else NOT_GIVEN,
         extra_headers=(
             openrouter_headers if "openrouter.ai" in str(openai.base_url) else {}
         ),
@@ -74,13 +91,15 @@ def stream(messages: list[Message], model: str) -> Generator[str, None, None]:
     stop_reason = None
     for chunk in openai.chat.completions.create(
         model=model,
-        messages=msgs2dicts(messages, openai=True),  # type: ignore
+        messages=msgs2dicts(_prep_o1(messages), openai=True),  # type: ignore
         temperature=TEMPERATURE,
         top_p=TOP_P,
         stream=True,
         # the llama-cpp-python server needs this explicitly set, otherwise unreliable results
         # TODO: make this better
-        max_tokens=1000 if not model.startswith("gpt-") else 4096,
+        # max_tokens=(
+        #     (1000 if not model.startswith("gpt-") else 4096)
+        # ),
         extra_headers=(
             openrouter_headers if "openrouter.ai" in str(openai.base_url) else {}
         ),
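To make the new _prep_o1 helper concrete, here is a hedged, self-contained example of the transformation it applies; the Message dataclass below is a minimal stand-in for gptme.message.Message (only the role and content attributes the helper touches), not the real class:

```python
from collections.abc import Generator
from dataclasses import dataclass


@dataclass
class Message:  # minimal stand-in for gptme.message.Message
    role: str
    content: str


def _prep_o1(msgs: list[Message]) -> Generator[Message, None, None]:
    # o1 has no system role, so system messages become user messages
    # wrapped in <system> tags, as in the diff above
    for msg in msgs:
        if msg.role == "system":
            msg.role = "user"
            msg.content = f"<system>\n{msg.content}\n</system>"
        yield msg


msgs = [
    Message("system", "You are a helpful assistant."),
    Message("user", "Write a hello world in Python."),
]
for msg in _prep_o1(msgs):
    print(msg.role, repr(msg.content))
# user '<system>\nYou are a helpful assistant.\n</system>'
# user 'Write a hello world in Python.'
```

The NOT_GIVEN import in chat() is the openai-python sentinel for leaving a field out of the request entirely, which matters here because o1-preview and o1-mini only accept their default temperature and top_p.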
1 change: 1 addition & 0 deletions gptme/message.py
@@ -23,6 +23,7 @@
 logger = logging.getLogger(__name__)
 
 
+# TODO: make immutable/dataclass
 class Message:
     """A message in the assistant conversation."""
 
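The TODO above is relevant to this change: _prep_o1 mutates msg.role and msg.content in place, which only works while Message remains mutable. A hedged sketch of how the same rewrite could look if Message became a frozen dataclass (field names assumed, not an actual gptme API):

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Message:  # hypothetical immutable version of gptme.message.Message
    role: str
    content: str


def _prep_o1(msgs: list[Message]) -> list[Message]:
    # build new Message objects instead of mutating the originals
    return [
        replace(msg, role="user", content=f"<system>\n{msg.content}\n</system>")
        if msg.role == "system"
        else msg
        for msg in msgs
    ]
```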