Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom tool headers #2773

Merged
merged 6 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions backend/danswer/chat/chat_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
from typing import cast

from fastapi.datastructures import Headers
from sqlalchemy.orm import Session

from danswer.chat.models import CitationInfo
Expand Down Expand Up @@ -166,3 +167,31 @@ def slack_link_format(match: re.Match) -> str:
new_citation_info[citation.citation_num] = citation

return new_answer, list(new_citation_info.values())


def extract_headers(
headers: dict[str, str] | Headers, pass_through_headers: list[str] | None
) -> dict[str, str]:
"""
Extract headers specified in pass_through_headers from input headers.
Handles both dict and FastAPI Headers objects, accounting for lowercase keys.

Args:
headers: Input headers as dict or Headers object.

Returns:
dict: Filtered headers based on pass_through_headers.
"""
if not pass_through_headers:
return {}

extracted_headers: dict[str, str] = {}
for key in pass_through_headers:
if key in headers:
extracted_headers[key] = headers[key]
else:
# fastapi makes all header keys lowercase, handling that here
lowercase_key = key.lower()
if lowercase_key in headers:
extracted_headers[lowercase_key] = headers[lowercase_key]
return extracted_headers
3 changes: 3 additions & 0 deletions backend/danswer/chat/process_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def stream_chat_message_objects(
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
Comment on lines +279 to 280
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: New parameter tool_additional_headers added but not used in the function body

enforce_chat_session_id_for_search_docs: bool = True,
) -> ChatPacketStream:
Expand Down Expand Up @@ -862,6 +863,7 @@ def stream_chat_message(
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
) -> Iterator[str]:
with get_session_context_manager() as db_session:
Expand All @@ -871,6 +873,7 @@ def stream_chat_message(
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
tool_additional_headers=tool_additional_headers,
is_connected=is_connected,
)
for obj in objects:
Expand Down
16 changes: 16 additions & 0 deletions backend/danswer/configs/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,19 @@
logger.error(
"Failed to parse LITELLM_PASS_THROUGH_HEADERS, must be a valid JSON object"
)


# List of headers to pass through to tool calls (e.g., API requests made by tools)
# This allows for dynamic configuration of tool behavior based on incoming request headers
TOOL_PASS_THROUGH_HEADERS: list[str] | None = None
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Consider adding a comment explaining the purpose of TOOL_PASS_THROUGH_HEADERS

_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get("TOOL_PASS_THROUGH_HEADERS")
if _TOOL_PASS_THROUGH_HEADERS_RAW:
try:
TOOL_PASS_THROUGH_HEADERS = json.loads(_TOOL_PASS_THROUGH_HEADERS_RAW)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: Ensure TOOL_PASS_THROUGH_HEADERS is a list of strings after parsing

except Exception:
from danswer.utils.logger import setup_logger
pablonyx marked this conversation as resolved.
Show resolved Hide resolved

logger = setup_logger()
logger.error(
"Failed to parse TOOL_PASS_THROUGH_HEADERS, must be a valid JSON object"
)
22 changes: 0 additions & 22 deletions backend/danswer/llm/headers.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,4 @@
from fastapi.datastructures import Headers

from danswer.configs.model_configs import LITELLM_EXTRA_HEADERS
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS


def get_litellm_additional_request_headers(
headers: dict[str, str] | Headers
) -> dict[str, str]:
if not LITELLM_PASS_THROUGH_HEADERS:
return {}

pass_through_headers: dict[str, str] = {}
for key in LITELLM_PASS_THROUGH_HEADERS:
if key in headers:
pass_through_headers[key] = headers[key]
else:
# fastapi makes all header keys lowercase, handling that here
lowercase_key = key.lower()
if lowercase_key in headers:
pass_through_headers[lowercase_key] = headers[lowercase_key]

return pass_through_headers


def build_llm_extra_headers(
Expand Down
15 changes: 11 additions & 4 deletions backend/danswer/server/query_and_chat/chat_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@

from danswer.auth.users import current_user
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.chat_utils import extract_headers
from danswer.chat.process_message import stream_chat_message
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
from danswer.configs.model_configs import TOOL_PASS_THROUGH_HEADERS
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import delete_chat_session
Expand Down Expand Up @@ -50,7 +53,6 @@
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llms
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.headers import get_litellm_additional_request_headers
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.secondary_llm_flows.chat_session_naming import (
get_renamed_conversation_name,
Expand Down Expand Up @@ -229,7 +231,9 @@ def rename_chat_session(

try:
llm, _ = get_default_llms(
additional_headers=get_litellm_additional_request_headers(request.headers)
additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
)
)
except GenAIDisabledException:
# This may be longer than what the LLM tends to produce but is the most
Expand Down Expand Up @@ -330,8 +334,11 @@ def stream_generator() -> Generator[str, None, None]:
new_msg_req=chat_message_req,
user=user,
use_existing_user_message=chat_message_req.use_existing_user_message,
litellm_additional_headers=get_litellm_additional_request_headers(
request.headers
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
tool_additional_headers=extract_headers(
request.headers, TOOL_PASS_THROUGH_HEADERS
),
is_connected=is_disconnected_func,
):
Expand Down
13 changes: 7 additions & 6 deletions backend/danswer/tools/custom/custom_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,17 @@ def __init__(
method_spec: MethodSpec,
base_url: str,
custom_headers: list[dict[str, str]] | None = [],
tool_additional_headers: dict[str, str] | None = None,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Consider adding type hints for the tool_additional_headers parameter

) -> None:
self._base_url = base_url
self._method_spec = method_spec
self._tool_definition = self._method_spec.to_tool_definition()

self._name = self._method_spec.name
self._description = self._method_spec.summary
self.headers = (
{header["key"]: header["value"] for header in custom_headers}
if custom_headers
else {}
)
self.headers = {
header["key"]: header["value"] for header in (custom_headers or [])
} | (tool_additional_headers or {})

@property
def name(self) -> str:
Expand Down Expand Up @@ -185,6 +184,7 @@ def final_result(self, *args: ToolResponse) -> JSON_ro:

def build_custom_tools_from_openapi_schema_and_headers(
openapi_schema: dict[str, Any],
tool_additional_headers: dict[str, str] | None = None,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Add a docstring explaining the purpose of tool_additional_headers

custom_headers: list[dict[str, str]] | None = [],
dynamic_schema_info: DynamicSchemaInfo | None = None,
) -> list[CustomTool]:
Expand All @@ -205,7 +205,8 @@ def build_custom_tools_from_openapi_schema_and_headers(
url = openapi_to_url(openapi_schema)
method_specs = openapi_to_method_specs(openapi_schema)
return [
CustomTool(method_spec, url, custom_headers) for method_spec in method_specs
CustomTool(method_spec, url, custom_headers, tool_additional_headers)
for method_spec in method_specs
]


Expand Down