diff --git a/gptme/llm/models.py b/gptme/llm/models.py index 595fb86f..1a2bd76b 100644 --- a/gptme/llm/models.py +++ b/gptme/llm/models.py @@ -156,6 +156,10 @@ class _ModelDictMeta(TypedDict): } +def get_default_model() -> ModelMeta | None: + return DEFAULT_MODEL + + def set_default_model(model: str) -> None: modelmeta = get_model(model) assert modelmeta diff --git a/gptme/util/reduce.py b/gptme/util/reduce.py index e88e8bf2..1600d63f 100644 --- a/gptme/util/reduce.py +++ b/gptme/util/reduce.py @@ -8,7 +8,7 @@ from collections.abc import Generator from ..codeblock import Codeblock -from ..llm.models import DEFAULT_MODEL, get_model +from ..llm.models import get_default_model, get_model from ..message import Message, len_tokens logger = logging.getLogger(__name__) @@ -21,7 +21,7 @@ def reduce_log( ) -> Generator[Message, None, None]: """Reduces log until it is below `limit` tokens by continually summarizing the longest messages until below the limit.""" # get the token limit - model = DEFAULT_MODEL or get_model("gpt-4") + model = get_default_model() or get_model("gpt-4") if limit is None: limit = 0.9 * model.context