From e02165f9366be5ca9a42a9dd1457975ca78e672f Mon Sep 17 00:00:00 2001 From: Yassine Souissi <74144843+yassinsws@users.noreply.github.com> Date: Fri, 21 Jun 2024 12:05:58 +0200 Subject: [PATCH] Feature/basic retrieval (#122) --- app/pipeline/chat/tutor_chat_pipeline.py | 11 +++++++++ app/retrieval/lecture_retrieval.py | 29 ++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/app/pipeline/chat/tutor_chat_pipeline.py b/app/pipeline/chat/tutor_chat_pipeline.py index 2e4863fc..e7c943ee 100644 --- a/app/pipeline/chat/tutor_chat_pipeline.py +++ b/app/pipeline/chat/tutor_chat_pipeline.py @@ -233,6 +233,17 @@ def _run_tutor_chat_pipeline( ) self._add_relevant_chunks_to_prompt(retrieved_lecture_chunks) + retrieved_lecture_chunks = self.retriever( + chat_history=history, + student_query=query.contents[0].text_content, + result_limit=10, + course_name=dto.course.name, + course_id=dto.course.id, + problem_statement=problem_statement, + exercise_title=exercise_title, + ) + self._add_relevant_chunks_to_prompt(retrieved_lecture_chunks) + self.callback.in_progress("Generating response...") # Add the final message to the prompt and run the pipeline diff --git a/app/retrieval/lecture_retrieval.py b/app/retrieval/lecture_retrieval.py index c68c5442..b64a746d 100644 --- a/app/retrieval/lecture_retrieval.py +++ b/app/retrieval/lecture_retrieval.py @@ -143,6 +143,35 @@ def __call__( return [merged_chunks[int(i)] for i in selected_chunks_index] return [] + def basic_lecture_retrieval( + self, + chat_history: list[PyrisMessage], + student_query: str, + result_limit: int, + course_name: str = None, + course_id: int = None, + base_url: str = None, + ) -> list[dict[str, dict]]: + """ + Basic retrieval for pipelines that need performance and fast answers. 
+ """ + rewritten_query = self.rewrite_student_query( + chat_history, student_query, "course_language", course_name + ) + response = self.search_in_db( + query=rewritten_query, + hybrid_factor=0.9, + result_limit=result_limit, + course_id=course_id, + base_url=base_url, + ) + + basic_retrieved_lecture_chunks: list[dict[str, dict]] = [ + {"id": obj.uuid.int, "properties": obj.properties} + for obj in response.objects + ] + return basic_retrieved_lecture_chunks + def rewrite_student_query( self, chat_history: list[PyrisMessage],