Added the logic in ingest python files to use the TEI_ENDPOINT to ingest the data into the redis vector DB (#84)

Signed-off-by: Pallavi Jaini <[email protected]>
pallavijaini0525 authored Apr 18, 2024
1 parent f0b73ef commit 2ada2c8
Showing 5 changed files with 125 additions and 51 deletions.
2 changes: 2 additions & 0 deletions ChatQnA/README.md
@@ -166,6 +166,8 @@ Make sure TGI-Gaudi service is running and also make sure data is populated into

```bash
docker exec -it qna-rag-redis-server bash
# export TGI_LLM_ENDPOINT="http://xxx.xxx.xxx.xxx:8080" - can be omitted if already set in docker-compose.yml
# export TEI_ENDPOINT="http://xxx.xxx.xxx.xxx:9090" - needed only if TEI is used; can be omitted if already set in docker-compose.yml
nohup python app/server.py &
```
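
TEI_ENDPOINT is the switch this commit threads through the stack: when it is set, embeddings come from a remote TEI (text-embeddings-inference) service; when it is unset, a local model is used. Below is a minimal, self-contained sketch of that selection pattern, mirroring the ingest-script changes further down in this diff (the model id here is a placeholder rather than the repo's configured value):

```python
import os

from langchain_community.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceHubEmbeddings

# Placeholder model id for illustration; the repo reads EMBED_MODEL from rag_redis.config.
EMBED_MODEL = "BAAI/bge-base-en-v1.5"

tei_embedding_endpoint = os.getenv("TEI_ENDPOINT")

if tei_embedding_endpoint:
    # Remote embeddings served by a TEI endpoint, e.g. http://xxx.xxx.xxx.xxx:9090
    embedder = HuggingFaceHubEmbeddings(model=tei_embedding_endpoint)
else:
    # Local embeddings computed in-process
    embedder = HuggingFaceBgeEmbeddings(model_name=EMBED_MODEL)

print("Using embedder:", type(embedder).__name__)
```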

101 changes: 73 additions & 28 deletions ChatQnA/langchain/docker/qna-app/app/server.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

from fastapi import APIRouter, FastAPI, File, Request, UploadFile
@@ -27,7 +28,7 @@
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langserve import add_routes
from prompts import contextualize_q_prompt, qa_prompt
from prompts import contextualize_q_prompt, prompt, qa_prompt
from rag_redis.config import EMBED_MODEL, INDEX_NAME, INDEX_SCHEMA, REDIS_URL
from starlette.middleware.cors import CORSMiddleware
from utils import (
@@ -39,6 +40,10 @@
reload_retriever,
)

parser = argparse.ArgumentParser(description="Server Configuration")
parser.add_argument("--chathistory", action="store_true", help="Enable chat history")
args = parser.parse_args()

app = FastAPI()

app.add_middleware(
Expand Down Expand Up @@ -102,9 +107,15 @@ def __init__(self, upload_dir, entrypoint, safety_guard_endpoint, tei_endpoint=N
self.contextualize_q_chain = contextualize_q_prompt | self.llm | StrOutputParser()

# Define LLM chain
self.llm_chain = (
RunnablePassthrough.assign(context=self.contextualized_question | retriever) | qa_prompt | self.llm
)
if args.chathistory:
self.llm_chain = (
RunnablePassthrough.assign(context=self.contextualized_question | retriever) | qa_prompt | self.llm
)
else:
self.llm_chain = (
RunnablePassthrough.assign(context=self.contextualized_question | retriever) | prompt | self.llm
)

print("[rag - router] LLM chain initialized.")

# Define chat history
@@ -117,9 +128,12 @@ def contextualized_question(self, input: dict):
return input["question"]

def handle_rag_chat(self, query: str):
response = self.llm_chain.invoke({"question": query, "chat_history": self.chat_history})
response = self.llm_chain.invoke(
{"question": query, "chat_history": self.chat_history} if args.chathistory else {"question": query}
)
result = response.split("</s>")[0]
self.chat_history.extend([HumanMessage(content=query), response])
if args.chathistory:
self.chat_history.extend([HumanMessage(content=query), response])
# output guardrails
if self.safety_guard_endpoint:
response_output_guard = self.llm_guard(
@@ -148,7 +162,6 @@ async def rag_chat(request: Request):
print(f"[rag - chat] POST request: /v1/rag/chat, params:{params}")
query = params["query"]
kb_id = params.get("knowledge_base_id", "default")
print(f"[rag - chat] history: {router.chat_history}")

# prompt guardrails
if router.safety_guard_endpoint:
@@ -162,16 +175,26 @@ async def rag_chat(request: Request):
if kb_id == "default":
print("[rag - chat] use default knowledge base")
retriever = reload_retriever(router.embeddings, INDEX_NAME)
router.llm_chain = (
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
)
if args.chathistory:
router.llm_chain = (
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
)
else:
router.llm_chain = (
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | prompt | router.llm
)
elif kb_id.startswith("kb"):
new_index_name = INDEX_NAME + kb_id
print(f"[rag - chat] use knowledge base {kb_id}, index name is {new_index_name}")
retriever = reload_retriever(router.embeddings, new_index_name)
router.llm_chain = (
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
)
if args.chathistory:
router.llm_chain = (
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
)
else:
router.llm_chain = (
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | prompt | router.llm
)
else:
return JSONResponse(status_code=400, content={"message": "Wrong knowledge base id."})
return router.handle_rag_chat(query=query)
@@ -183,7 +206,6 @@ async def rag_chat_stream(request: Request):
print(f"[rag - chat_stream] POST request: /v1/rag/chat_stream, params:{params}")
query = params["query"]
kb_id = params.get("knowledge_base_id", "default")
print(f"[rag - chat_stream] history: {router.chat_history}")

# prompt guardrails
if router.safety_guard_endpoint:
@@ -202,28 +224,41 @@ def generate_content():

if kb_id == "default":
retriever = reload_retriever(router.embeddings, INDEX_NAME)
router.llm_chain = (
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
)
if args.chathistory:
router.llm_chain = (
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
)
else:
router.llm_chain = (
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | prompt | router.llm
)
elif kb_id.startswith("kb"):
new_index_name = INDEX_NAME + kb_id
retriever = reload_retriever(router.embeddings, new_index_name)
router.llm_chain = (
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
)
if args.chathistory:
router.llm_chain = (
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
)
else:
router.llm_chain = (
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | prompt | router.llm
)
else:
return JSONResponse(status_code=400, content={"message": "Wrong knowledge base id."})

def stream_generator():
chat_response = ""
for text in router.llm_chain.stream({"question": query, "chat_history": router.chat_history}):
for text in router.llm_chain.stream(
{"question": query, "chat_history": router.chat_history} if args.chathistory else {"question": query}
):
chat_response += text
processed_text = post_process_text(text)
if text and processed_text:
yield processed_text
chat_response = chat_response.split("</s>")[0]
print(f"[rag - chat_stream] stream response: {chat_response}")
router.chat_history.extend([HumanMessage(content=query), chat_response])
if args.chathistory:
router.chat_history.extend([HumanMessage(content=query), chat_response])
yield "data: [DONE]\n\n"

return StreamingResponse(stream_generator(), media_type="text/event-stream")
@@ -251,9 +286,14 @@ async def rag_create(file: UploadFile = File(...)):
print("[rag - create] starting to create local db...")
index_name = INDEX_NAME + kb_id
retriever = create_retriever_from_files(save_file_name, router.embeddings, index_name)
router.llm_chain = (
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
)
if args.chathistory:
router.llm_chain = (
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
)
else:
router.llm_chain = (
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | prompt | router.llm
)
print("[rag - create] kb created successfully")
except Exception as e:
print(f"[rag - create] create knowledge base failed! {e}")
@@ -274,9 +314,14 @@ async def rag_upload_link(request: Request):
print("[rag - upload_link] starting to create local db...")
index_name = INDEX_NAME + kb_id
retriever = create_retriever_from_links(router.embeddings, link_list, index_name)
router.llm_chain = (
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
)
if args.chathistory:
router.llm_chain = (
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | qa_prompt | router.llm
)
else:
router.llm_chain = (
RunnablePassthrough.assign(context=router.contextualized_question | retriever) | prompt | router.llm
)
print("[rag - upload_link] kb created successfully")
except Exception as e:
print(f"[rag - upload_link] create knowledge base failed! {e}")
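Once the updated server is running (python app/server.py, optionally with the new --chathistory flag to keep multi-turn history), the chat route can be smoke-tested with a small client. The sketch below is hedged: the host and port are assumptions taken from the UI's DOC_BASE_URL default further down in this diff, and the query text is purely illustrative.

```python
# Minimal client sketch for the /v1/rag/chat route shown above.
# Host/port are assumed from the UI's DOC_BASE_URL default; adjust to your deployment.
import requests

resp = requests.post(
    "http://localhost:8000/v1/rag/chat",
    json={"query": "What is the company's revenue?", "knowledge_base_id": "default"},
)
print(resp.status_code, resp.text)
```

With --chathistory, each request also appends to router.chat_history so follow-up questions are rewritten by contextualize_q_prompt; without the flag, every request is answered statelessly through the plain prompt chain.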
35 changes: 24 additions & 11 deletions ChatQnA/langchain/redis/ingest.py
@@ -20,11 +20,13 @@

import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceEmbeddings, HuggingFaceHubEmbeddings
from langchain_community.vectorstores import Redis
from PIL import Image
from rag_redis.config import EMBED_MODEL, INDEX_NAME, INDEX_SCHEMA, REDIS_URL

tei_embedding_endpoint = os.getenv("TEI_ENDPOINT")


def pdf_loader(file_path):
try:
@@ -79,17 +81,28 @@ def ingest_documents():

print("Done preprocessing. Created ", len(chunks), " chunks of the original pdf")
# Create vectorstore
embedder = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
if tei_embedding_endpoint:
# create embeddings using TEI endpoint service
embedder = HuggingFaceHubEmbeddings(model=tei_embedding_endpoint)
else:
# create embeddings using local embedding model
embedder = HuggingFaceBgeEmbeddings(model_name=EMBED_MODEL)

# Batch size
batch_size = 32
num_chunks = len(chunks)
for i in range(0, num_chunks, batch_size):
batch_chunks = chunks[i : i + batch_size]
batch_texts = [f"Company: {company_name}. " + chunk for chunk in batch_chunks]

_ = Redis.from_texts(
# appending this little bit can sometimes help with semantic retrieval
# especially with multiple companies
texts=[f"Company: {company_name}. " + chunk for chunk in chunks],
embedding=embedder,
index_name=INDEX_NAME,
index_schema=INDEX_SCHEMA,
redis_url=REDIS_URL,
)
_ = Redis.from_texts(
texts=batch_texts,
embedding=embedder,
index_name=INDEX_NAME,
index_schema=INDEX_SCHEMA,
redis_url=REDIS_URL,
)
print(f"Processed batch {i//batch_size + 1}/{(num_chunks-1)//batch_size + 1}")


if __name__ == "__main__":
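Note that ingest.py reads TEI_ENDPOINT at module level, so the variable must be set before the module is imported (or exported in the shell before the script starts). A small driver sketch under that assumption; the endpoint address is a placeholder and the import assumes the script's directory is on PYTHONPATH:

```python
import os

# Set before importing ingest.py, because the module captures TEI_ENDPOINT at import time.
# Drop this line to fall back to the local HuggingFaceBgeEmbeddings model.
os.environ["TEI_ENDPOINT"] = "http://xxx.xxx.xxx.xxx:9090"  # placeholder TEI address

from ingest import ingest_documents  # assumes ChatQnA/langchain/redis is on PYTHONPATH

ingest_documents()
```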
36 changes: 25 additions & 11 deletions ChatQnA/langchain/redis/ingest_intel.py
@@ -20,11 +20,13 @@

import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceEmbeddings, HuggingFaceHubEmbeddings
from langchain_community.vectorstores import Redis
from PIL import Image
from rag_redis.config import EMBED_MODEL, INDEX_NAME, INDEX_SCHEMA, REDIS_URL

tei_embedding_endpoint = os.getenv("TEI_ENDPOINT")


def pdf_loader(file_path):
try:
@@ -79,17 +81,29 @@ def ingest_documents():

print("Done preprocessing. Created", len(chunks), "chunks of the original pdf")
# Create vectorstore
embedder = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
# Create vectorstore
if tei_embedding_endpoint:
# create embeddings using TEI endpoint service
embedder = HuggingFaceHubEmbeddings(model=tei_embedding_endpoint)
else:
# create embeddings using local embedding model
embedder = HuggingFaceBgeEmbeddings(model_name=EMBED_MODEL)

_ = Redis.from_texts(
# appending this little bit can sometimes help with semantic retrieval
# especially with multiple companies
texts=[f"Company: {company_name}. " + chunk for chunk in chunks],
embedding=embedder,
index_name=INDEX_NAME,
index_schema=INDEX_SCHEMA,
redis_url=REDIS_URL,
)
# Batch size
batch_size = 32
num_chunks = len(chunks)
for i in range(0, num_chunks, batch_size):
batch_chunks = chunks[i : i + batch_size]
batch_texts = [f"Company: {company_name}. " + chunk for chunk in batch_chunks]

_ = Redis.from_texts(
texts=batch_texts,
embedding=embedder,
index_name=INDEX_NAME,
index_schema=INDEX_SCHEMA,
redis_url=REDIS_URL,
)
print(f"Processed batch {i//batch_size + 1}/{(num_chunks-1)//batch_size + 1}")


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion ChatQnA/ui/.env
@@ -1 +1 @@
DOC_BASE_URL = 'http://xxx.xxx.xxx.xxx:8000/v1/rag'
DOC_BASE_URL = 'http://localhost:8000/v1/rag'
