Improve model token limit detection #3292

Merged: 4 commits, Nov 30, 2024
4 changes: 3 additions & 1 deletion backend/danswer/configs/model_configs.py
@@ -70,7 +70,9 @@
 )

 # Typically, GenAI models nowadays are at least 4K tokens
-GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
+GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int(
+    os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096
+)

 # Number of tokens from chat history to include at maximum
 # 3000 should be enough context regardless of use, no need to include as much as possible
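Since the fallback is now read from the environment, deployments can raise it without a code change. A minimal sketch of the parsing behavior (the helper name below is hypothetical, used only for illustration): because of the `or 4096`, an unset or empty variable both resolve to the default, while any numeric value overrides it.

import os

def read_fallback_max_tokens() -> int:  # hypothetical helper, for illustration only
    # `or 4096` treats unset *and* empty values as "use the default", whereas
    # int(os.environ.get(..., 4096)) would raise ValueError on an empty string.
    return int(os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096)

os.environ.pop("GEN_AI_MODEL_FALLBACK_MAX_TOKENS", None)
assert read_fallback_max_tokens() == 4096

os.environ["GEN_AI_MODEL_FALLBACK_MAX_TOKENS"] = ""
assert read_fallback_max_tokens() == 4096  # empty string still falls back

os.environ["GEN_AI_MODEL_FALLBACK_MAX_TOKENS"] = "8192"
assert read_fallback_max_tokens() == 8192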
11 changes: 8 additions & 3 deletions backend/danswer/llm/chat_llm.py
@@ -26,7 +26,9 @@
 from langchain_core.prompt_values import PromptValue

 from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
-from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
+from danswer.configs.model_configs import (
+    DISABLE_LITELLM_STREAMING,
+)
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.configs.model_configs import LITELLM_EXTRA_BODY
 from danswer.llm.interfaces import LLM
@@ -161,7 +163,9 @@ def _convert_delta_to_message_chunk(

     if role == "user":
         return HumanMessageChunk(content=content)
-    elif role == "assistant":
+    # NOTE: if tool calls are present, then it's an assistant.
+    # In Ollama, the role will be None for tool-calls
+    elif role == "assistant" or tool_calls:
         if tool_calls:
             tool_call = tool_calls[0]
             tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or ""
@@ -236,6 +240,7 @@ def __init__(
         custom_config: dict[str, str] | None = None,
         extra_headers: dict[str, str] | None = None,
         extra_body: dict | None = LITELLM_EXTRA_BODY,
+        model_kwargs: dict[str, Any] | None = None,
         long_term_logger: LongTermLogger | None = None,
     ):
         self._timeout = timeout
@@ -268,7 +273,7 @@ def __init__(
             for k, v in custom_config.items():
                 os.environ[k] = v

-        model_kwargs: dict[str, Any] = {}
+        model_kwargs = model_kwargs or {}
         if extra_headers:
             model_kwargs.update({"extra_headers": extra_headers})
         if extra_body:
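The widened role check matters for streamed tool-call deltas where the provider leaves the role unset, as the note about Ollama describes. A standalone sketch of the same decision logic, using simplified stand-in dataclasses rather than the real langchain/litellm types (all names here are hypothetical):

from dataclasses import dataclass, field

@dataclass
class FakeToolCall:          # stand-in for a streamed tool call
    name: str

@dataclass
class FakeDelta:             # stand-in for a streamed completion delta
    role: str | None = None
    content: str = ""
    tool_calls: list[FakeToolCall] = field(default_factory=list)

def classify_delta(delta: FakeDelta) -> str:
    # Mirrors the updated branch: a delta carrying tool calls is treated as an
    # assistant message even when the provider leaves the role unset.
    if delta.role == "user":
        return "human"
    elif delta.role == "assistant" or delta.tool_calls:
        return "ai"
    elif delta.role == "system":
        return "system"
    return "unknown"

assert classify_delta(FakeDelta(role=None, tool_calls=[FakeToolCall("search")])) == "ai"
assert classify_delta(FakeDelta(role="assistant", content="hi")) == "ai"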
13 changes: 13 additions & 0 deletions backend/danswer/llm/factory.py
@@ -1,5 +1,8 @@
+from typing import Any
+
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.chat_configs import QA_TIMEOUT
+from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.db.engine import get_session_context_manager
 from danswer.db.llm import fetch_default_provider
@@ -13,6 +16,15 @@
 from danswer.utils.long_term_log import LongTermLogger


+def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
+    """Ollama requires us to specify the max context window.
+
+    For now, just using the GEN_AI_MODEL_FALLBACK_MAX_TOKENS value.
+    TODO: allow model-specific values to be configured via the UI.
+    """
+    return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {}
+
+
 def get_main_llm_from_tuple(
     llms: tuple[LLM, LLM],
 ) -> LLM:
@@ -132,5 +144,6 @@ def get_llm(
         temperature=temperature,
         custom_config=custom_config,
         extra_headers=build_llm_extra_headers(additional_headers),
+        model_kwargs=_build_extra_model_kwargs(provider),
         long_term_logger=long_term_logger,
     )
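With this wiring, the Ollama context-window hint is derived from the same fallback constant the token-budget code uses. A rough sketch of how the extra kwargs are assembled and merged, mirroring the helper above and the `model_kwargs = model_kwargs or {}` change in chat_llm.py (the function names and the 8192 value are illustrative assumptions, not part of the PR):

from typing import Any

GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 8192  # e.g. the env override from model_configs.py

def build_extra_model_kwargs(provider: str) -> dict[str, Any]:
    # Same shape as the helper above: only Ollama needs an explicit
    # context-window hint (num_ctx); other providers get no extras.
    return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {}

def build_completion_kwargs(
    provider: str,
    extra_headers: dict[str, str] | None = None,
) -> dict[str, Any]:
    # Mirrors the constructor change in chat_llm.py: start from the
    # caller-supplied kwargs instead of a fresh empty dict, then layer
    # optional extras on top.
    model_kwargs = build_extra_model_kwargs(provider) or {}
    if extra_headers:
        model_kwargs["extra_headers"] = extra_headers
    return model_kwargs

assert build_completion_kwargs("ollama") == {"num_ctx": 8192}
assert build_completion_kwargs("openai", {"X-Team": "search"}) == {"extra_headers": {"X-Team": "search"}}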
91 changes: 74 additions & 17 deletions backend/danswer/llm/utils.py
@@ -1,3 +1,4 @@
+import copy
 import io
 import json
 from collections.abc import Callable
@@ -385,6 +386,62 @@ def test_llm(llm: LLM) -> str | None:
     return error_msg


+def get_model_map() -> dict:
+    starting_map = copy.deepcopy(cast(dict, litellm.model_cost))
+
+    # NOTE: we could add additional models here in the future,
+    # but for now there is no point. Ollama allows the user
+    # to specify their desired max context window, and it's
+    # unlikely to be standard across users even for the same model
+    # (it heavily depends on their hardware). For now, we'll just
+    # rely on GEN_AI_MODEL_FALLBACK_MAX_TOKENS to cover this.
+    # for model_name in [
+    #     "llama3.2",
+    #     "llama3.2:1b",
+    #     "llama3.2:3b",
+    #     "llama3.2:11b",
+    #     "llama3.2:90b",
+    # ]:
+    #     starting_map[f"ollama/{model_name}"] = {
+    #         "max_tokens": 128000,
+    #         "max_input_tokens": 128000,
+    #         "max_output_tokens": 128000,
+    #     }
+
+    return starting_map
+
+
+def _strip_extra_provider_from_model_name(model_name: str) -> str:
+    return model_name.split("/")[1] if "/" in model_name else model_name
+
+
+def _strip_colon_from_model_name(model_name: str) -> str:
+    return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name
+
+
+def _find_model_obj(
+    model_map: dict, provider: str, model_names: list[str | None]
+) -> dict | None:
+    # Filter out None values and deduplicate model names
+    filtered_model_names = [name for name in model_names if name]
+
+    # First try all model names with provider prefix
+    for model_name in filtered_model_names:
+        model_obj = model_map.get(f"{provider}/{model_name}")
+        if model_obj:
+            logger.debug(f"Using model object for {provider}/{model_name}")
+            return model_obj
+
+    # Then try all model names without provider prefix
+    for model_name in filtered_model_names:
+        model_obj = model_map.get(model_name)
+        if model_obj:
+            logger.debug(f"Using model object for {model_name}")
+            return model_obj
+
+    return None
+
+
 def get_llm_max_tokens(
     model_map: dict,
     model_name: str,
@@ -397,22 +454,22 @@ def get_llm_max_tokens(
         return GEN_AI_MAX_TOKENS

     try:
-        model_obj = model_map.get(f"{model_provider}/{model_name}")
-        if model_obj:
-            logger.debug(f"Using model object for {model_provider}/{model_name}")
-
-        if not model_obj:
-            model_obj = model_map.get(model_name)
-            if model_obj:
-                logger.debug(f"Using model object for {model_name}")
-
-        if not model_obj:
-            model_name_split = model_name.split("/")
-            if len(model_name_split) > 1:
-                model_obj = model_map.get(model_name_split[1])
-                if model_obj:
-                    logger.debug(f"Using model object for {model_name_split[1]}")
-
+        extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
+            model_name
+        )
+        model_obj = _find_model_obj(
+            model_map,
+            model_provider,
+            [
+                model_name,
+                # Remove leading extra provider. Usually for cases where user has a
+                # custom model proxy which prepends another prefix
+                extra_provider_stripped_model_name,
+                # remove :XXXX from the end, if present. Needed for ollama.
+                _strip_colon_from_model_name(model_name),
+                _strip_colon_from_model_name(extra_provider_stripped_model_name),
+            ],
+        )
         if not model_obj:
             raise RuntimeError(
                 f"No litellm entry found for {model_provider}/{model_name}"
@@ -488,7 +545,7 @@ def get_max_input_tokens(
     # `model_cost` dict is a named public interface:
     # https://litellm.vercel.app/docs/completion/token_usage#7-model_cost
     # model_map is litellm.model_cost
-    litellm_model_map = litellm.model_cost
+    litellm_model_map = get_model_map()

    input_toks = (
        get_llm_max_tokens(
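The new lookup tries progressively normalized variants of the model name, first with and then without the provider prefix. A small self-contained sketch against a toy model map (the map entries and the helper names below are made up for illustration; the real map comes from `litellm.model_cost` via `get_model_map()`):

# Toy stand-in for litellm.model_cost; entries are invented for illustration.
TOY_MODEL_MAP = {
    "ollama/llama3.1": {"max_input_tokens": 128000},
    "gpt-4o": {"max_input_tokens": 128000},
}

def strip_extra_provider(model_name: str) -> str:
    # "customproxy/llama3.1:8b" -> "llama3.1:8b"
    return model_name.split("/")[1] if "/" in model_name else model_name

def strip_colon(model_name: str) -> str:
    # "llama3.1:8b" -> "llama3.1"; names without a colon pass through unchanged
    return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name

def find_model_obj(model_map: dict, provider: str, model_names: list[str | None]) -> dict | None:
    names = [name for name in model_names if name]
    # Provider-prefixed keys win over bare names, matching the PR's lookup order.
    for name in names:
        if obj := model_map.get(f"{provider}/{name}"):
            return obj
    for name in names:
        if obj := model_map.get(name):
            return obj
    return None

model_name = "customproxy/llama3.1:8b"
stripped = strip_extra_provider(model_name)
candidates = [model_name, stripped, strip_colon(model_name), strip_colon(stripped)]
# The provider-stripped, colon-stripped variant "llama3.1" matches "ollama/llama3.1".
assert find_model_obj(TOY_MODEL_MAP, "ollama", candidates) == {"max_input_tokens": 128000}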
4 changes: 2 additions & 2 deletions backend/requirements/default.txt
@@ -29,7 +29,7 @@ trafilatura==1.12.2
 langchain==0.1.17
 langchain-core==0.1.50
 langchain-text-splitters==0.0.1
-litellm==1.50.2
+litellm==1.53.1
 lxml==5.3.0
 lxml_html_clean==0.2.2
 llama-index==0.9.45
@@ -38,7 +38,7 @@ msal==1.28.0
 nltk==3.8.1
 Office365-REST-Python-Client==2.5.9
 oauthlib==3.2.2
-openai==1.52.2
+openai==1.55.3
 openpyxl==3.1.2
 playwright==1.41.2
 psutil==5.9.5