Quest arg generation should always use gpt-4 #84

Open · wants to merge 12 commits into base: main
@@ -8,6 +8,7 @@
 from generators.server_settings_field_generator import ServerSettingsFieldGenerator
 from generators.utils import safe_format
 from schema.quest import QuestDescription
+from utils.context_utils import get_quest_arc_generator


 class AdventureFixedQuestArcGenerator(ServerSettingsFieldGenerator):
@@ -54,6 +55,11 @@ def inner_generate(
         context: AgentContext,
         generation_config: Optional[dict] = None,
     ) -> Block:
+
+        # NOTE: We use a dedicated generator here because it's important that this
+        # runs on gpt-4; with other models the prompts have been a bit too shaky.
+        generator = get_quest_arc_generator(context)
+
         prompt = self.prompt

         task = generator.generate(
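For context, here is a minimal sketch of how the pinned generator might be consumed at this call site (the actual call is truncated in the diff above). It assumes Steamship's `PluginInstance.generate(text=...)` returns a `Task` whose output blocks are available after `task.wait()`; treat the signature as an assumption, not the PR's code.

```python
# Sketch only: the generate() signature and Task API are assumptions about the
# Steamship SDK; the PR's actual call site is truncated in the diff above.
# Assumes the imports from the module above (get_quest_arc_generator, etc.).
def generate_quest_arc(context, prompt: str):
    generator = get_quest_arc_generator(context)  # pinned to gpt-4 (see NOTE above)
    task = generator.generate(text=prompt)        # assumed signature
    task.wait()                                   # wait for the plugin run to complete
    return task.output.blocks[0]                  # first generated Block
```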
src/schema/server_settings.py — 2 changes: 1 addition & 1 deletion
@@ -229,7 +229,7 @@ class ServerSettings(BaseModel):
     """

     # Language Generation Settings - Function calling
-    default_function_capable_llm_model: str = Field("gpt-3.5-turbo", description="")
+    default_function_capable_llm_model: str = Field("gpt-4", description="")
     default_function_capable_llm_temperature: float = Field(0.0, description="")
     default_function_capable_llm_max_tokens: int = Field(512, description="")
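The one-line change above flips the default function-calling model from gpt-3.5-turbo to gpt-4. A minimal, self-contained sketch of the same Pydantic pattern — field names are taken from the diff, but the real `ServerSettings` has many more fields:

```python
from pydantic import BaseModel, Field

class ServerSettingsSketch(BaseModel):
    # Field names mirror the diff; the real ServerSettings defines many more.
    default_function_capable_llm_model: str = Field("gpt-4", description="")
    default_function_capable_llm_temperature: float = Field(0.0, description="")
    default_function_capable_llm_max_tokens: int = Field(512, description="")

settings = ServerSettingsSketch()
assert settings.default_function_capable_llm_model == "gpt-4"

# Callers can still opt back into the cheaper model explicitly:
override = ServerSettingsSketch(default_function_capable_llm_model="gpt-3.5-turbo")
assert override.default_function_capable_llm_model == "gpt-3.5-turbo"
```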
src/utils/context_utils.py — 69 changes: 60 additions & 9 deletions
@@ -21,7 +21,6 @@
 from typing import List, Optional, Union

 from steamship import Block, PluginInstance
-from steamship.agents.llms.openai import ChatOpenAI
 from steamship.agents.logging import AgentLogging
 from steamship.agents.schema import ChatHistory, ChatLLM, FinishAction
 from steamship.agents.schema.agent import AgentContext
@@ -38,6 +37,7 @@
 )
 _BACKGROUND_MUSIC_GENERATOR_KEY = "background-music-generator"
 _NARRATION_GENERATOR_KEY = "narration-generator"
+_QUEST_GENERATOR_KEY = "quest-arc-generator"
 _SERVER_SETTINGS_KEY = "server-settings"
 _GAME_STATE_KEY = "user-settings"
@@ -101,6 +101,33 @@ def get_story_text_generator(
     return generator


+def get_quest_arc_generator(
+    context: AgentContext, default: Optional[PluginInstance] = None
+) -> Optional[PluginInstance]:
+    generator = context.metadata.get(_QUEST_GENERATOR_KEY, default)
+
+    if not generator:
+        # Lazily create and cache on the context.
+        server_settings: ServerSettings = get_server_settings(context)
+
+        # Hard-coded: quest arc generation always runs on gpt-4.
+        model_name = "gpt-4"
+        plugin_handle = "gpt-4"
+
+        generator = context.client.use_plugin(
+            plugin_handle,
+            config={
+                "model": model_name,
+                "max_tokens": server_settings.default_story_max_tokens,
+                "temperature": server_settings.default_story_temperature,
+            },
+        )
+
+        context.metadata[_QUEST_GENERATOR_KEY] = generator
+
+    return generator
+
+
 def get_background_music_generator(
     context: AgentContext, default: Optional[PluginInstance] = None
 ) -> Optional[PluginInstance]:
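The new helper follows the same lazy, per-context caching pattern as the other `get_*_generator` helpers: create the plugin instance once, stash it in `context.metadata`, and return the cached instance on subsequent calls. A stripped-down sketch of just that pattern — `FakeContext` and `make_plugin` are hypothetical stand-ins for `AgentContext` and `client.use_plugin`:

```python
from typing import Any, Dict, Optional

_QUEST_GENERATOR_KEY = "quest-arc-generator"

class FakeContext:
    """Hypothetical stand-in for AgentContext; only the metadata dict matters here."""
    def __init__(self) -> None:
        self.metadata: Dict[str, Any] = {}

def make_plugin() -> object:
    """Hypothetical stand-in for context.client.use_plugin(...)."""
    return object()

def get_quest_arc_generator(context: FakeContext, default: Optional[object] = None) -> object:
    generator = context.metadata.get(_QUEST_GENERATOR_KEY, default)
    if not generator:
        generator = make_plugin()                           # created once...
        context.metadata[_QUEST_GENERATOR_KEY] = generator  # ...then cached
    return generator

ctx = FakeContext()
assert get_quest_arc_generator(ctx) is get_quest_arc_generator(ctx)  # same instance
```

One consequence worth noting: a cached instance lives as long as the context's metadata, so settings changes made after the first call won't be picked up by an already-cached generator.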
@@ -343,14 +370,38 @@ def switch_history_to_current_quest(


 def get_function_capable_llm(
-    context: AgentContext, default: Optional[ChatLLM] = None  # noqa: F821
-) -> Optional[ChatLLM]:  # noqa: F821
-    llm = context.metadata.get(_FUNCTION_CAPABLE_LLM, default)
-    if not llm:
-        # Lazy create
-        llm = ChatOpenAI(context.client)
-        context.metadata[_FUNCTION_CAPABLE_LLM] = llm
-    return llm
+    context: AgentContext, default: Optional[PluginInstance] = None
+) -> Optional[PluginInstance]:
+    generator = context.metadata.get(_FUNCTION_CAPABLE_LLM, default)
+
+    if not generator:
+        # Lazily create and cache on the context.
+        server_settings: ServerSettings = get_server_settings(context)
+
+        open_ai_models = ["gpt-3.5-turbo", "gpt-4-1106-preview", "gpt-4"]
+
+        model_name = server_settings._select_model(
+            open_ai_models,
+            default=server_settings.default_function_capable_llm_model,
+            preferred="gpt-4",
+        )
+
+        plugin_handle = None
+        if model_name in open_ai_models:
+            plugin_handle = "gpt-4"
+
+        generator = context.client.use_plugin(
+            plugin_handle,
+            config={
+                "model": model_name,
+                "max_tokens": server_settings.default_function_capable_llm_max_tokens,
+                "temperature": server_settings.default_function_capable_llm_temperature,
+            },
+        )
+
+        context.metadata[_FUNCTION_CAPABLE_LLM] = generator
+
+    return generator


 def _key_for_question(blocks: List[Block], key: Optional[str] = None) -> str:
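`_select_model` isn't shown in this diff, so its semantics have to be inferred from the call above. A hypothetical reading, based purely on the argument names `default` and `preferred`: prefer `preferred` when it is among the allowed options, otherwise fall back to `default`.

```python
from typing import List, Optional

# Hypothetical: _select_model is not shown in the diff; this sketch matches only
# the call shape used in get_function_capable_llm above.
def _select_model(
    options: List[str],
    default: Optional[str] = None,
    preferred: Optional[str] = None,
) -> Optional[str]:
    if preferred and preferred in options:
        return preferred
    return default
```

Note also that all three OpenAI model names map to the single `gpt-4` plugin handle; the actual model is chosen via the plugin's `model` config value.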