feat: basic support for openai/o1-preview and openai/o1-mini (#117)
* feat: basic support for openai/o1-preview and openai/o1-mini

* fix: made it possible to use both O1 and non-O1 models
ErikBjare authored Sep 16, 2024
1 parent ae3ea89 commit cf13bae
Showing 3 changed files with 28 additions and 4 deletions.
gptme/cli.py (4 additions, 0 deletions)
@@ -217,6 +217,10 @@ def chat(
# init
init(model, interactive)

if model and model.startswith("openai/o1") and stream:
logger.info("Disabled streaming for OpenAI's O1 (streaming not supported)")
stream = False

# we need to run this before checking stdin, since the interactive doesn't work with the switch back to interactive mode
logfile = get_logfile(
name, interactive=(not prompt_msgs and interactive) and sys.stdin.isatty()
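To make the effect of the new guard concrete, here is a minimal, self-contained sketch (the helper name resolve_stream is hypothetical, not part of the commit): streaming is left alone for ordinary models but forced off, with a log message, when an O1 model is selected, since the O1 API did not support streaming.

import logging

logger = logging.getLogger(__name__)


def resolve_stream(model: str | None, stream: bool) -> bool:
    # Mirrors the check added to gptme/cli.py: O1 models
    # ("openai/o1-preview", "openai/o1-mini") don't support streaming,
    # so the stream flag is turned off for them and returned unchanged otherwise.
    if model and model.startswith("openai/o1") and stream:
        logger.info("Disabled streaming for OpenAI's O1 (streaming not supported)")
        return False
    return stream


assert resolve_stream("openai/o1-preview", True) is False
assert resolve_stream("openai/gpt-4o", True) is True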
gptme/llm_openai.py (23 additions, 4 deletions)
@@ -51,15 +51,32 @@ def get_client() -> "OpenAI | None":
return openai


def _prep_o1(msgs: list[Message]) -> Generator[Message, None, None]:
# prepare messages for OpenAI O1, which doesn't support the system role
# and requires the first message to be from the user
for msg in msgs:
if msg.role == "system":
msg.role = "user"
msg.content = f"<system>\n{msg.content}\n</system>"
yield msg


def chat(messages: list[Message], model: str) -> str:
# This will generate code and such, so we need appropriate temperature and top_p params
# top_p controls diversity, temperature controls randomness
assert openai, "LLM not initialized"
is_o1 = model.startswith("o1")
if is_o1:
messages = list(_prep_o1(messages))

# noreorder
from openai._types import NOT_GIVEN # fmt: skip

response = openai.chat.completions.create(
model=model,
messages=msgs2dicts(messages, openai=True), # type: ignore
-temperature=TEMPERATURE,
-top_p=TOP_P,
+temperature=TEMPERATURE if not is_o1 else NOT_GIVEN,
+top_p=TOP_P if not is_o1 else NOT_GIVEN,
extra_headers=(
openrouter_headers if "openrouter.ai" in str(openai.base_url) else {}
),
@@ -74,13 +91,15 @@ def stream(messages: list[Message], model: str) -> Generator[str, None, None]:
stop_reason = None
for chunk in openai.chat.completions.create(
model=model,
-messages=msgs2dicts(messages, openai=True), # type: ignore
+messages=msgs2dicts(_prep_o1(messages), openai=True), # type: ignore
temperature=TEMPERATURE,
top_p=TOP_P,
stream=True,
# the llama-cpp-python server needs this explicitly set, otherwise unreliable results
# TODO: make this better
-max_tokens=1000 if not model.startswith("gpt-") else 4096,
+# max_tokens=(
+#     (1000 if not model.startswith("gpt-") else 4096)
+# ),
extra_headers=(
openrouter_headers if "openrouter.ai" in str(openai.base_url) else {}
),
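For illustration, here is a self-contained sketch of what the new _prep_o1 helper does to a conversation, using a simplified stand-in for gptme's Message class (the real class in gptme/message.py has more fields and is only assumed here to expose role and content):

from dataclasses import dataclass
from collections.abc import Generator


# Simplified stand-in for gptme's Message class (the real one has more fields).
@dataclass
class Message:
    role: str
    content: str


def _prep_o1(msgs: list[Message]) -> Generator[Message, None, None]:
    # O1 models reject the "system" role and expect the conversation to
    # start with a user message, so system prompts are rewritten as user
    # messages wrapped in <system> tags.
    for msg in msgs:
        if msg.role == "system":
            msg.role = "user"
            msg.content = f"<system>\n{msg.content}\n</system>"
        yield msg


msgs = [
    Message("system", "You are a helpful assistant."),
    Message("user", "Write a hello world script."),
]
prepped = list(_prep_o1(msgs))
assert prepped[0].role == "user"
assert prepped[0].content.startswith("<system>")

Note that the messages are modified in place rather than copied, which is presumably part of what motivates the TODO comment added to gptme/message.py below.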
gptme/message.py (1 addition, 0 deletions)
@@ -23,6 +23,7 @@
logger = logging.getLogger(__name__)


# TODO: make immutable/dataclass
class Message:
"""A message in the assistant conversation."""

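The TODO added above suggests that Message is still a plain mutable class, which is what lets _prep_o1 rewrite role and content in place. A minimal sketch of the immutable-dataclass direction the TODO points toward (hypothetical, not part of this commit):

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Message:
    """A message in the assistant conversation."""
    role: str
    content: str


# With a frozen dataclass, _prep_o1 would have to return modified copies
# instead of mutating messages in place:
msg = Message(role="system", content="You are a helpful assistant.")
msg = replace(msg, role="user", content=f"<system>\n{msg.content}\n</system>")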
