Skip to content

Commit

Permalink
k
Browse files Browse the repository at this point in the history
  • Loading branch information
yuhongsun96 committed Nov 29, 2024
1 parent 09dfe3e commit c92081f
Show file tree
Hide file tree
Showing 12 changed files with 29 additions and 5 deletions.
3 changes: 3 additions & 0 deletions backend/danswer/chat/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from danswer.configs.constants import DEFAULT_PERSONA_ID
from danswer.configs.constants import MessageType
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
from danswer.db.chat import create_chat_session
from danswer.db.chat import get_chat_messages_by_session
Expand Down Expand Up @@ -45,6 +46,7 @@ def prepare_chat_message_request(
prompt: Prompt | None,
message_ts_to_respond_to: str | None,
retrieval_details: RetrievalDetails | None,
rerank_settings: RerankingDetails | None,
db_session: Session,
) -> CreateChatMessageRequest:
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
Expand All @@ -69,6 +71,7 @@ def prepare_chat_message_request(
persona_override_config=persona_override_config,
search_doc_ids=None,
retrieval_options=retrieval_details,
rerank_settings=rerank_settings,
)


Expand Down
4 changes: 3 additions & 1 deletion backend/danswer/chat/process_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ def stream_chat_message_objects(
is_connected: Callable[[], bool] | None = None,
enforce_chat_session_id_for_search_docs: bool = True,
bypass_acl: bool = False,
include_contexts: bool = False,
) -> ChatPacketStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
Expand Down Expand Up @@ -624,6 +625,7 @@ def stream_chat_message_objects(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
retrieval_options=retrieval_options or RetrievalDetails(),
rerank_settings=new_msg_req.rerank_settings,
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
Expand Down Expand Up @@ -774,7 +776,7 @@ def stream_chat_message_objects(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
elif packet.id == SEARCH_DOC_CONTENT_ID:
elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
yield cast(DanswerContexts, packet.response)

elif isinstance(packet, StreamStopInfo):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def _get_slack_answer(
prompt=prompt,
message_ts_to_respond_to=message_ts_to_respond_to,
retrieval_details=retrieval_details,
rerank_settings=None, # Rerank customization supported in Slack flow
db_session=db_session,
)

Expand Down
1 change: 1 addition & 0 deletions backend/danswer/server/openai_assistants_api/runs_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def process_run_in_background(
prompt_id=chat_session.persona.prompts[0].id,
search_doc_ids=None,
retrieval_options=search_tool_retrieval_details, # Adjust as needed
rerank_settings=None,
query_override=None,
regenerate=None,
llm_override=None,
Expand Down
2 changes: 2 additions & 0 deletions backend/danswer/server/query_and_chat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from danswer.configs.constants import SearchFeedbackType
from danswer.context.search.models import BaseFilters
from danswer.context.search.models import ChunkContext
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
from danswer.context.search.models import SearchDoc
from danswer.context.search.models import Tag
Expand Down Expand Up @@ -87,6 +88,7 @@ class CreateChatMessageRequest(ChunkContext):
# If search_doc_ids provided, then retrieval options are unused
search_doc_ids: list[int] | None
retrieval_options: RetrievalDetails | None
rerank_settings: RerankingDetails | None
# allows the caller to specify the exact search query they want to use
# will disable Query Rewording if specified
query_override: str | None = None
Expand Down
3 changes: 3 additions & 0 deletions backend/danswer/tools/tool_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import Persona
Expand Down Expand Up @@ -102,6 +103,7 @@ class SearchToolConfig(BaseModel):
default_factory=DocumentPruningConfig
)
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
rerank_settings: RerankingDetails | None = None
selected_sections: list[InferenceSection] | None = None
chunks_above: int = 0
chunks_below: int = 0
Expand Down Expand Up @@ -172,6 +174,7 @@ def construct_tools(
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
rerank_settings=search_tool_config.rerank_settings,
bypass_acl=search_tool_config.bypass_acl,
)
tool_dict[db_tool_model.id] = [search_tool]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from danswer.context.search.enums import SearchType
from danswer.context.search.models import IndexFilters
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
from danswer.context.search.models import SearchRequest
from danswer.context.search.pipeline import SearchPipeline
Expand Down Expand Up @@ -103,6 +104,7 @@ def __init__(
chunks_below: int | None = None,
full_doc: bool = False,
bypass_acl: bool = False,
rerank_settings: RerankingDetails | None = None,
) -> None:
self.user = user
self.persona = persona
Expand All @@ -118,6 +120,9 @@ def __init__(
self.bypass_acl = bypass_acl
self.db_session = db_session

# Only used via API
self.rerank_settings = rerank_settings

self.chunks_above = (
chunks_above
if chunks_above is not None
Expand Down Expand Up @@ -292,6 +297,7 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
self.retrieval_options.offset if self.retrieval_options else None
),
limit=self.retrieval_options.limit if self.retrieval_options else None,
rerank_settings=self.rerank_settings,
chunks_above=self.chunks_above,
chunks_below=self.chunks_below,
full_doc=self.full_doc,
Expand Down
1 change: 0 additions & 1 deletion backend/ee/danswer/chat/process_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ def gather_stream_for_answer_api(
response = OneShotQAResponse()

answer = ""
# TODO handle giving back context
for packet in packets:
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
Expand Down
4 changes: 4 additions & 0 deletions backend/ee/danswer/server/query_and_chat/chat_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ def handle_simplified_chat_message(
prompt_id=None,
search_doc_ids=chat_message_req.search_doc_ids,
retrieval_options=retrieval_options,
# Simple API does not support reranking, hide complexity from user
rerank_settings=None,
query_override=chat_message_req.query_override,
# Currently only applies to search flow not chat
chunks_above=0,
Expand Down Expand Up @@ -292,6 +294,8 @@ def handle_send_message_simple_with_history(
prompt_id=req.prompt_id,
search_doc_ids=req.search_doc_ids,
retrieval_options=retrieval_options,
# Simple API does not support reranking, hide complexity from user
rerank_settings=None,
query_override=rephrased_query,
chunks_above=0,
chunks_below=0,
Expand Down
2 changes: 0 additions & 2 deletions backend/ee/danswer/server/query_and_chat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,7 @@ class OneShotQARequest(ChunkContext):
messages: list[ThreadMessage]
prompt_id: int | None = None
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
# TODO add these options into Chat
rerank_settings: RerankingDetails | None = None
evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
return_contexts: bool = False

# allows the caller to specify the exact search query they want to use
Expand Down
6 changes: 5 additions & 1 deletion backend/ee/danswer/server/query_and_chat/query_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,15 @@ def get_answer_stream(
prompt=prompt,
message_ts_to_respond_to=None,
retrieval_details=query_request.retrieval_options,
rerank_settings=query_request.rerank_settings,
db_session=db_session,
)

packets = stream_chat_message_objects(
new_msg_req=request, user=user, db_session=db_session
new_msg_req=request,
user=user,
db_session=db_session,
include_contexts=query_request.return_contexts,
)

return packets
Expand Down
1 change: 1 addition & 0 deletions backend/tests/integration/common_utils/managers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def send_message(
prompt_id=prompt_id,
search_doc_ids=search_doc_ids or [],
retrieval_options=retrieval_options,
rerank_settings=None, # Can be added if needed
query_override=query_override,
regenerate=regenerate,
llm_override=llm_override,
Expand Down

0 comments on commit c92081f

Please sign in to comment.