Skip to content

Commit

Permalink
Refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
yuhongsun96 committed Nov 29, 2024
1 parent 0553062 commit fd2d607
Show file tree
Hide file tree
Showing 37 changed files with 698 additions and 1,111 deletions.
29 changes: 29 additions & 0 deletions backend/alembic/versions/9f696734098f_combine_search_and_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Combine Search and Chat
Revision ID: 9f696734098f
Revises: 6d562f86c78b
Create Date: 2024-11-27 15:32:19.694972
"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "9f696734098f"
down_revision = "6d562f86c78b"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.alter_column("chat_session", "description", nullable=True)
op.drop_column("chat_session", "one_shot")


def downgrade() -> None:
op.execute("UPDATE chat_session SET description = '' WHERE description IS NULL")
op.alter_column("chat_session", "description", nullable=False)
op.add_column(
"chat_session",
sa.Column("one_shot", sa.Boolean(), nullable=False, server_default=sa.false()),
)
166 changes: 166 additions & 0 deletions backend/danswer/chat/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,79 @@
from typing import cast
from uuid import UUID

from fastapi import HTTPException
from fastapi.datastructures import Headers
from sqlalchemy.orm import Session

from danswer.auth.users import is_user_admin
from danswer.chat.models import CitationInfo
from danswer.chat.models import LlmDoc
from danswer.chat.models import PersonaOverrideConfig
from danswer.chat.models import ThreadMessage
from danswer.configs.constants import DEFAULT_PERSONA_ID
from danswer.configs.constants import MessageType
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
from danswer.db.chat import create_chat_session
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.llm import fetch_existing_doc_sets
from danswer.db.llm import fetch_existing_tools
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.db.models import Tool
from danswer.db.models import User
from danswer.db.persona import get_prompts_by_ids
from danswer.llm.answering.models import PreviousMessage
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.utils.logger import setup_logger

logger = setup_logger()


def prepare_chat_message_request(
message_text: str,
user: User | None,
persona_id: int | None,
# Does the question need to have a persona override
persona_override_config: PersonaOverrideConfig | None,
prompt: Prompt | None,
message_ts_to_respond_to: str | None,
retrieval_details: RetrievalDetails | None,
rerank_settings: RerankingDetails | None,
db_session: Session,
) -> CreateChatMessageRequest:
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
new_chat_session = create_chat_session(
db_session=db_session,
description=None,
user_id=user.id if user else None,
# If using an override, this id will be ignored later on
persona_id=persona_id or DEFAULT_PERSONA_ID,
danswerbot_flow=True,
slack_thread_id=message_ts_to_respond_to,
)

return CreateChatMessageRequest(
chat_session_id=new_chat_session.id,
parent_message_id=None, # It's a standalone chat session each time
message=message_text,
file_descriptors=[], # Currently SlackBot/answer api do not support files in the context
prompt_id=prompt.id if prompt else None,
# Can always override the persona for the single query, if it's a normal persona
# then it will be treated the same
persona_override_config=persona_override_config,
search_doc_ids=None,
retrieval_options=retrieval_details,
rerank_settings=rerank_settings,
)


def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
return LlmDoc(
document_id=inference_section.center_chunk.document_id,
Expand All @@ -34,6 +93,45 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
)


def combine_message_thread(
messages: list[ThreadMessage],
max_tokens: int | None,
llm_tokenizer: BaseTokenizer,
) -> str:
"""Used to create a single combined message context from threads"""
if not messages:
return ""

message_strs: list[str] = []
total_token_count = 0

for message in reversed(messages):
if message.role == MessageType.USER:
role_str = message.role.value.upper()
if message.sender:
role_str += " " + message.sender
else:
# Since other messages might have the user identifying information
# better to use Unknown for symmetry
role_str += " Unknown"
else:
role_str = message.role.value.upper()

msg_str = f"{role_str}:\n{message.message}"
message_token_count = len(llm_tokenizer.encode(msg_str))

if (
max_tokens is not None
and total_token_count + message_token_count > max_tokens
):
break

message_strs.insert(0, msg_str)
total_token_count += message_token_count

return "\n\n".join(message_strs)


def create_chat_chain(
chat_session_id: UUID,
db_session: Session,
Expand Down Expand Up @@ -196,3 +294,71 @@ def extract_headers(
if lowercase_key in headers:
extracted_headers[lowercase_key] = headers[lowercase_key]
return extracted_headers


def create_temporary_persona(
persona_config: PersonaOverrideConfig, db_session: Session, user: User | None = None
) -> Persona:
if not is_user_admin(user):
raise HTTPException(
status_code=403,
detail="User is not authorized to create a persona in one shot queries",
)

"""Create a temporary Persona object from the provided configuration."""
persona = Persona(
name=persona_config.name,
description=persona_config.description,
num_chunks=persona_config.num_chunks,
llm_relevance_filter=persona_config.llm_relevance_filter,
llm_filter_extraction=persona_config.llm_filter_extraction,
recency_bias=persona_config.recency_bias,
llm_model_provider_override=persona_config.llm_model_provider_override,
llm_model_version_override=persona_config.llm_model_version_override,
)

if persona_config.prompts:
persona.prompts = [
Prompt(
name=p.name,
description=p.description,
system_prompt=p.system_prompt,
task_prompt=p.task_prompt,
include_citations=p.include_citations,
datetime_aware=p.datetime_aware,
)
for p in persona_config.prompts
]
elif persona_config.prompt_ids:
persona.prompts = get_prompts_by_ids(
db_session=db_session, prompt_ids=persona_config.prompt_ids
)

persona.tools = []
if persona_config.custom_tools_openapi:
for schema in persona_config.custom_tools_openapi:
tools = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(schema),
)
persona.tools.extend(tools)

if persona_config.tools:
tool_ids = [tool.id for tool in persona_config.tools]
persona.tools.extend(
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
)

if persona_config.tool_ids:
persona.tools.extend(
fetch_existing_tools(
db_session=db_session, tool_ids=persona_config.tool_ids
)
)

fetched_docs = fetch_existing_doc_sets(
db_session=db_session, doc_ids=persona_config.document_set_ids
)
persona.document_sets = fetched_docs

return persona
52 changes: 52 additions & 0 deletions backend/danswer/chat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
from typing import Any

from pydantic import BaseModel
from pydantic import Field

from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.context.search.enums import QueryFlow
from danswer.context.search.enums import RecencyBiasSetting
from danswer.context.search.enums import SearchType
from danswer.context.search.models import RetrievalDocs
from danswer.context.search.models import SearchResponse
Expand Down Expand Up @@ -156,6 +159,22 @@ class QAResponse(SearchResponse, DanswerAnswer):
error_msg: str | None = None


class ThreadMessage(BaseModel):
message: str
sender: str | None = None
role: MessageType = MessageType.USER


class ChatDanswerBotResponse(BaseModel):
answer: str | None = None
citations: list[CitationInfo] | None = None
docs: QADocsResponse | None = None
llm_selected_doc_indices: list[int] | None = None
error_msg: str | None = None
chat_message_id: int | None = None
answer_valid: bool = True # Reflexion result, default True if Reflexion not run


class FileChatDisplay(BaseModel):
file_ids: list[str]

Expand All @@ -165,6 +184,39 @@ class CustomToolResponse(BaseModel):
tool_name: str


class ToolConfig(BaseModel):
id: int


class PromptOverrideConfig(BaseModel):
name: str
description: str = ""
system_prompt: str
task_prompt: str = ""
include_citations: bool = True
datetime_aware: bool = True


class PersonaOverrideConfig(BaseModel):
name: str
description: str
search_type: SearchType = SearchType.SEMANTIC
num_chunks: float | None = None
llm_relevance_filter: bool = False
llm_filter_extraction: bool = False
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None

prompts: list[PromptOverrideConfig] = Field(default_factory=list)
prompt_ids: list[int] = Field(default_factory=list)

document_set_ids: list[int] = Field(default_factory=list)
tools: list[ToolConfig] = Field(default_factory=list)
tool_ids: list[int] = Field(default_factory=list)
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)


AnswerQuestionPossibleReturn = (
DanswerAnswerPiece
| DanswerQuotes
Expand Down
Loading

0 comments on commit fd2d607

Please sign in to comment.