diff --git a/ChatQnA/README.md b/ChatQnA/README.md
index 7490b715d..e7fbe1feb 100644
--- a/ChatQnA/README.md
+++ b/ChatQnA/README.md
@@ -166,6 +166,8 @@ Make sure TGI-Gaudi service is running and also make sure data is populated into

 ```bash
 docker exec -it qna-rag-redis-server bash
+# export TGI_LLM_ENDPOINT="http://xxx.xxx.xxx.xxx:8080" - can be omitted if already set in docker-compose.yml
+# export TEI_ENDPOINT="http://xxx.xxx.xxx.xxx:9090" - needed only if TEI is used; can be omitted if already set in docker-compose.yml
 nohup python app/server.py &
 ```

diff --git a/ChatQnA/langchain/docker/qna-app/app/server.py b/ChatQnA/langchain/docker/qna-app/app/server.py
index 0e282d2dc..fff024077 100644
--- a/ChatQnA/langchain/docker/qna-app/app/server.py
+++ b/ChatQnA/langchain/docker/qna-app/app/server.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import argparse
 import os

 from fastapi import APIRouter, FastAPI, File, Request, UploadFile
@@ -27,7 +28,7 @@
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.runnables import RunnablePassthrough
 from langserve import add_routes
-from prompts import contextualize_q_prompt, qa_prompt
+from prompts import contextualize_q_prompt, prompt, qa_prompt
 from rag_redis.config import EMBED_MODEL, INDEX_NAME, INDEX_SCHEMA, REDIS_URL
 from starlette.middleware.cors import CORSMiddleware
 from utils import (
@@ -39,6 +40,10 @@
     reload_retriever,
 )

+parser = argparse.ArgumentParser(description="Server Configuration")
+parser.add_argument("--chathistory", action="store_true", help="Enable chat history")
+args = parser.parse_args()
+
 app = FastAPI()

 app.add_middleware(
@@ -102,9 +107,15 @@ def __init__(self, upload_dir, entrypoint, safety_guard_endpoint, tei_endpoint=N
         self.contextualize_q_chain = contextualize_q_prompt | self.llm | StrOutputParser()

         # Define LLM chain
-        self.llm_chain = (
-            RunnablePassthrough.assign(context=self.contextualized_question | retriever) | qa_prompt | self.llm
-        )
+        if args.chathistory:
+            self.llm_chain = (
+                RunnablePassthrough.assign(context=self.contextualized_question | retriever) | qa_prompt | self.llm
+            )
+        else:
+            self.llm_chain = (
+                RunnablePassthrough.assign(context=self.contextualized_question | retriever) | prompt | self.llm
+            )
+
         print("[rag - router] LLM chain initialized.")

         # Define chat history
@@ -117,9 +128,12 @@ def contextualized_question(self, input: dict):
         return input["question"]

     def handle_rag_chat(self, query: str):
-        response = self.llm_chain.invoke({"question": query, "chat_history": self.chat_history})
+        response = self.llm_chain.invoke(
+            {"question": query, "chat_history": self.chat_history} if args.chathistory else {"question": query}
+        )
         result = response.split("</s>")[0]
-        self.chat_history.extend([HumanMessage(content=query), response])
+        if args.chathistory:
+            self.chat_history.extend([HumanMessage(content=query), response])
         # output guardrails
         if self.safety_guard_endpoint:
             response_output_guard = self.llm_guard(
@@ -148,7 +162,6 @@ async def rag_chat(request: Request):
     print(f"[rag - chat] POST request: /v1/rag/chat, params:{params}")
     query = params["query"]
     kb_id = params.get("knowledge_base_id", "default")
-    print(f"[rag - chat] history: {router.chat_history}")

     # prompt guardrails
     if router.safety_guard_endpoint:
@@ -162,16 +175,26 @@ async def rag_chat(request: Request):
     if kb_id == "default":
         print("[rag - chat] use default knowledge base")
         retriever = reload_retriever(router.embeddings, INDEX_NAME)
-        router.llm_chain = (
-            RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
-        )
+        if args.chathistory:
+            router.llm_chain = (
+                RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
+            )
+        else:
+            router.llm_chain = (
+                RunnablePassthrough.assign(context=router.contextualized_question | retriever) | prompt | router.llm
+            )
     elif kb_id.startswith("kb"):
         new_index_name = INDEX_NAME + kb_id
         print(f"[rag - chat] use knowledge base {kb_id}, index name is {new_index_name}")
         retriever = reload_retriever(router.embeddings, new_index_name)
-        router.llm_chain = (
-            RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
-        )
+        if args.chathistory:
+            router.llm_chain = (
+                RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
+            )
+        else:
+            router.llm_chain = (
+                RunnablePassthrough.assign(context=router.contextualized_question | retriever) | prompt | router.llm
+            )
     else:
         return JSONResponse(status_code=400, content={"message": "Wrong knowledge base id."})
     return router.handle_rag_chat(query=query)
@@ -183,7 +206,6 @@ async def rag_chat_stream(request: Request):
     print(f"[rag - chat_stream] POST request: /v1/rag/chat_stream, params:{params}")
     query = params["query"]
     kb_id = params.get("knowledge_base_id", "default")
-    print(f"[rag - chat_stream] history: {router.chat_history}")

     # prompt guardrails
     if router.safety_guard_endpoint:
@@ -202,28 +224,41 @@ def generate_content():

     if kb_id == "default":
         retriever = reload_retriever(router.embeddings, INDEX_NAME)
-        router.llm_chain = (
-            RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
-        )
+        if args.chathistory:
+            router.llm_chain = (
+                RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
+            )
+        else:
+            router.llm_chain = (
+                RunnablePassthrough.assign(context=router.contextualized_question | retriever) | prompt | router.llm
+            )
     elif kb_id.startswith("kb"):
         new_index_name = INDEX_NAME + kb_id
         retriever = reload_retriever(router.embeddings, new_index_name)
-        router.llm_chain = (
-            RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
-        )
+        if args.chathistory:
+            router.llm_chain = (
+                RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
+            )
+        else:
+            router.llm_chain = (
+                RunnablePassthrough.assign(context=router.contextualized_question | retriever) | prompt | router.llm
+            )
     else:
         return JSONResponse(status_code=400, content={"message": "Wrong knowledge base id."})

     def stream_generator():
         chat_response = ""
-        for text in router.llm_chain.stream({"question": query, "chat_history": router.chat_history}):
+        for text in router.llm_chain.stream(
+            {"question": query, "chat_history": router.chat_history} if args.chathistory else {"question": query}
+        ):
             chat_response += text
             processed_text = post_process_text(text)
             if text and processed_text:
                 yield processed_text
         chat_response = chat_response.split("</s>")[0]
         print(f"[rag - chat_stream] stream response: {chat_response}")
-        router.chat_history.extend([HumanMessage(content=query), chat_response])
+        if args.chathistory:
+            router.chat_history.extend([HumanMessage(content=query), chat_response])
         yield "data: [DONE]\n\n"

     return StreamingResponse(stream_generator(), media_type="text/event-stream")
@@ -251,9 +286,14 @@ async def rag_create(file: UploadFile = File(...)):
         print("[rag - create] starting to create local db...")
         index_name = INDEX_NAME + kb_id
         retriever = create_retriever_from_files(save_file_name, router.embeddings, index_name)
-        router.llm_chain = (
-            RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
-        )
+        if args.chathistory:
+            router.llm_chain = (
+                RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
+            )
+        else:
+            router.llm_chain = (
+                RunnablePassthrough.assign(context=router.contextualized_question | retriever) | prompt | router.llm
+            )
         print("[rag - create] kb created successfully")
     except Exception as e:
         print(f"[rag - create] create knowledge base failed! {e}")
@@ -274,9 +314,14 @@ async def rag_upload_link(request: Request):
         print("[rag - upload_link] starting to create local db...")
         index_name = INDEX_NAME + kb_id
         retriever = create_retriever_from_links(router.embeddings, link_list, index_name)
-        router.llm_chain = (
-            RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
-        )
+        if args.chathistory:
+            router.llm_chain = (
+                RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
+            )
+        else:
+            router.llm_chain = (
+                RunnablePassthrough.assign(context=router.contextualized_question | retriever) | prompt | router.llm
+            )
         print("[rag - upload_link] kb created successfully")
     except Exception as e:
         print(f"[rag - upload_link] create knowledge base failed! {e}")
diff --git a/ChatQnA/langchain/redis/ingest.py b/ChatQnA/langchain/redis/ingest.py
index 6d0ad0141..2377d59be 100644
--- a/ChatQnA/langchain/redis/ingest.py
+++ b/ChatQnA/langchain/redis/ingest.py
@@ -20,11 +20,13 @@

 import numpy as np
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceEmbeddings, HuggingFaceHubEmbeddings
 from langchain_community.vectorstores import Redis
 from PIL import Image
 from rag_redis.config import EMBED_MODEL, INDEX_NAME, INDEX_SCHEMA, REDIS_URL

+tei_embedding_endpoint = os.getenv("TEI_ENDPOINT")
+

 def pdf_loader(file_path):
     try:
@@ -79,17 +81,28 @@ def ingest_documents():
     print("Done preprocessing. Created ", len(chunks), " chunks of the original pdf")

     # Create vectorstore
-    embedder = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
+    if tei_embedding_endpoint:
+        # create embeddings using TEI endpoint service
+        embedder = HuggingFaceHubEmbeddings(model=tei_embedding_endpoint)
+    else:
+        # create embeddings using local embedding model
+        embedder = HuggingFaceBgeEmbeddings(model_name=EMBED_MODEL)
+
+    # Batch size
+    batch_size = 32
+    num_chunks = len(chunks)
+    for i in range(0, num_chunks, batch_size):
+        batch_chunks = chunks[i : i + batch_size]
+        batch_texts = [f"Company: {company_name}. " + chunk for chunk in batch_chunks]

-    _ = Redis.from_texts(
-        # appending this little bit can sometimes help with semantic retrieval
-        # especially with multiple companies
-        texts=[f"Company: {company_name}. " + chunk for chunk in chunks],
-        embedding=embedder,
-        index_name=INDEX_NAME,
-        index_schema=INDEX_SCHEMA,
-        redis_url=REDIS_URL,
-    )
+        _ = Redis.from_texts(
+            texts=batch_texts,
+            embedding=embedder,
+            index_name=INDEX_NAME,
+            index_schema=INDEX_SCHEMA,
+            redis_url=REDIS_URL,
+        )
+        print(f"Processed batch {i//batch_size + 1}/{(num_chunks-1)//batch_size + 1}")


 if __name__ == "__main__":
diff --git a/ChatQnA/langchain/redis/ingest_intel.py b/ChatQnA/langchain/redis/ingest_intel.py
index 5d266307f..e486e277f 100644
--- a/ChatQnA/langchain/redis/ingest_intel.py
+++ b/ChatQnA/langchain/redis/ingest_intel.py
@@ -20,11 +20,13 @@

 import numpy as np
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceEmbeddings, HuggingFaceHubEmbeddings
 from langchain_community.vectorstores import Redis
 from PIL import Image
 from rag_redis.config import EMBED_MODEL, INDEX_NAME, INDEX_SCHEMA, REDIS_URL

+tei_embedding_endpoint = os.getenv("TEI_ENDPOINT")
+

 def pdf_loader(file_path):
     try:
@@ -79,17 +81,29 @@ def ingest_documents():
     print("Done preprocessing. Created", len(chunks), "chunks of the original pdf")

     # Create vectorstore
-    embedder = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
+    # Create vectorstore
+    if tei_embedding_endpoint:
+        # create embeddings using TEI endpoint service
+        embedder = HuggingFaceHubEmbeddings(model=tei_embedding_endpoint)
+    else:
+        # create embeddings using local embedding model
+        embedder = HuggingFaceBgeEmbeddings(model_name=EMBED_MODEL)

-    _ = Redis.from_texts(
-        # appending this little bit can sometimes help with semantic retrieval
-        # especially with multiple companies
-        texts=[f"Company: {company_name}. " + chunk for chunk in chunks],
-        embedding=embedder,
-        index_name=INDEX_NAME,
-        index_schema=INDEX_SCHEMA,
-        redis_url=REDIS_URL,
-    )
+    # Batch size
+    batch_size = 32
+    num_chunks = len(chunks)
+    for i in range(0, num_chunks, batch_size):
+        batch_chunks = chunks[i : i + batch_size]
+        batch_texts = [f"Company: {company_name}. " + chunk for chunk in batch_chunks]
+
+        _ = Redis.from_texts(
+            texts=batch_texts,
+            embedding=embedder,
+            index_name=INDEX_NAME,
+            index_schema=INDEX_SCHEMA,
+            redis_url=REDIS_URL,
+        )
+        print(f"Processed batch {i//batch_size + 1}/{(num_chunks-1)//batch_size + 1}")


 if __name__ == "__main__":
diff --git a/ChatQnA/ui/.env b/ChatQnA/ui/.env
index 3ed60bae1..48bf77994 100644
--- a/ChatQnA/ui/.env
+++ b/ChatQnA/ui/.env
@@ -1 +1 @@
-DOC_BASE_URL = 'http://xxx.xxx.xxx.xxx:8000/v1/rag'
\ No newline at end of file
+DOC_BASE_URL = 'http://localhost:8000/v1/rag'
\ No newline at end of file