fix: remove o1 streaming restriction
The o1 models now support streaming in the OpenAI API, so we can remove
the special handling. The supports_streaming field was also not being used
correctly (it checked for an exact match instead of a prefix).
ErikBjare committed Nov 22, 2024
1 parent 990965f commit 7fd6ad5
Showing 2 changed files with 9 additions and 19 deletions.
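For context, the bug called out in the commit message is an exact-match-vs-prefix mistake: "o1" names a model family ("o1-preview", "o1-mini"), not a single model. A minimal sketch of the difference, using the expression from the removed create_meta_model shown in the diff below (function names here are illustrative, not gptme's API):

def supports_streaming_old(provider: str, model: str) -> bool:
    # Exact-match check from the removed code: only the literal "o1"
    # is caught, so "o1-preview" and "o1-mini" slip through.
    return provider != "openai" or model != "o1"

def is_o1(model: str) -> bool:
    # Prefix check, as used in llm_openai.py, covers the whole family.
    return model.startswith("o1")

assert supports_streaming_old("openai", "o1-preview")  # wrongly True pre-fix
assert not supports_streaming_old("openai", "o1")      # only the exact name caught
assert is_o1("o1") and is_o1("o1-preview") and is_o1("o1-mini")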
11 changes: 6 additions & 5 deletions gptme/llm/llm_openai.py
@@ -109,11 +109,10 @@ def chat(messages: list[Message], model: str) -> str:
     if is_o1:
         messages = list(_prep_o1(messages))
 
-    # noreorder
-    from openai._types import NOT_GIVEN  # fmt: skip
-
     messages_dicts = handle_files(msgs2dicts(messages))
 
+    from openai._types import NOT_GIVEN  # fmt: skip
+
     response = openai.chat.completions.create(
         model=model,
         messages=messages_dicts,  # type: ignore
@@ -138,11 +137,13 @@ def stream(messages: list[Message], model: str) -> Generator[str, None, None]:
 
     messages_dicts = handle_files(msgs2dicts(messages))
 
+    from openai._types import NOT_GIVEN  # fmt: skip
+
     for chunk in openai.chat.completions.create(
         model=model,
         messages=messages_dicts,  # type: ignore
-        temperature=TEMPERATURE,
-        top_p=TOP_P,
+        temperature=TEMPERATURE if not is_o1 else NOT_GIVEN,
+        top_p=TOP_P if not is_o1 else NOT_GIVEN,
         stream=True,
         # the llama-cpp-python server needs this explicitly set, otherwise unreliable results
         # TODO: make this better
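The temperature/top_p change relies on the OpenAI SDK's NOT_GIVEN sentinel: o1 models reject non-default sampling parameters, and NOT_GIVEN makes the client omit a field from the request entirely, unlike None, which would send an explicit null. A self-contained sketch of the same pattern (model name and parameter values are placeholders, not gptme's configuration):

import openai
from openai._types import NOT_GIVEN

TEMPERATURE = 0.0  # placeholder values; gptme defines its own constants
TOP_P = 0.1

model = "o1-mini"
is_o1 = model.startswith("o1")

# Omit temperature/top_p for o1 models instead of sending them.
stream = openai.chat.completions.create(
    model=model,
    messages=[{"role": "user", "content": "Say hi"}],
    temperature=TEMPERATURE if not is_o1 else NOT_GIVEN,
    top_p=TOP_P if not is_o1 else NOT_GIVEN,
    stream=True,
)
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")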
17 changes: 3 additions & 14 deletions gptme/llm/models.py
@@ -39,6 +39,7 @@ class ModelMeta:
 class _ModelDictMeta(TypedDict):
     context: int
     max_output: NotRequired[int]
+    supports_streaming: NotRequired[bool]
 
     # price in USD per 1M tokens
     price_input: NotRequired[float]
@@ -96,18 +97,6 @@ def set_default_model(model: str) -> None:
     DEFAULT_MODEL = modelmeta
 
 
-def create_meta_model(provider, model, **kwargs):
-    if provider not in PROVIDERS_OPENAI:
-        return ModelMeta(
-            provider=provider,
-            model=model,
-            supports_streaming=provider != "openai" or model != "o1",
-            **kwargs,
-        )
-    else:
-        return ModelMeta(provider=provider, model=model, **kwargs)
-
-
 def get_model(model: str | None = None) -> ModelMeta:
     if model is None:
         assert DEFAULT_MODEL, "Default model not set, set it with set_default_model()"
@@ -126,7 +115,7 @@ def get_model(model: str | None = None) -> ModelMeta:
             logger.warning(
                 f"Unknown model {model} from {provider}, using fallback metadata"
             )
-            return create_meta_model(provider, model, context=128_000)
+            return ModelMeta(provider, model, context=128_000)
     else:
         # try to find model in all providers
         for provider in MODELS:
@@ -136,7 +125,7 @@ def get_model(model: str | None = None) -> ModelMeta:
             logger.warning(f"Unknown model {model}, using fallback metadata")
             return ModelMeta(provider="unknown", model=model, context=128_000)
 
-    return create_meta_model(provider, model, **MODELS[provider][model])
+    return ModelMeta(provider, model, **MODELS[provider][model])
 
 
 def get_recommended_model(provider: Provider) -> str:  # pragma: no cover
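With create_meta_model gone, metadata construction is uniform: known models splat their _ModelDictMeta entry (which may now carry supports_streaming) into ModelMeta, and unknown models get a plain fallback. A simplified sketch of that flow (the field set is trimmed down, and the default values are assumptions based on the diff, not the full models.py):

from dataclasses import dataclass
from typing import NotRequired, TypedDict

@dataclass
class ModelMeta:
    provider: str
    model: str
    context: int
    supports_streaming: bool = True  # assumed default

class _ModelDictMeta(TypedDict):
    context: int
    max_output: NotRequired[int]
    supports_streaming: NotRequired[bool]

MODELS: dict[str, dict[str, _ModelDictMeta]] = {
    "openai": {"o1-preview": {"context": 128_000}},
}

# Known model: the dict entry is splatted straight into ModelMeta.
meta = ModelMeta("openai", "o1-preview", **MODELS["openai"]["o1-preview"])

# Unknown model: plain fallback, no wrapper function needed anymore.
fallback = ModelMeta(provider="unknown", model="mystery", context=128_000)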
