textDemo.py

import argparse

import gradio as gr
import lancedb
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer

from varag.chunking import FixedTokenChunker
from varag.llms import OpenAI
from varag.rag import SimpleRAG

load_dotenv()
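# load_dotenv() reads variables from a local .env file; presumably this is
# where the OPENAI_API_KEY used by the OpenAI client below is expected to live.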

# Initialize the embedding model (alternative models left commented out)
embedding_model = SentenceTransformer("all-MiniLM-L6-v2", trust_remote_code=True)
# embedding_model = SentenceTransformer("BAAI/bge-base-en", trust_remote_code=True)
# embedding_model = SentenceTransformer("BAAI/bge-large-en-v1.5", trust_remote_code=True)
# embedding_model = SentenceTransformer("BAAI/bge-small-en-v1.5", trust_remote_code=True)

# Connect to the shared LanceDB database
shared_db = lancedb.connect("~/shared_rag_db")

# Initialize SimpleRAG on the shared database
text_rag = SimpleRAG(
    text_embedding_model=embedding_model,
    db=shared_db,
    table_name="textDemo",
)

# Initialize the OpenAI client
llm = OpenAI()


def ingest_documents(files, chunk_size, use_ocr):
    """Index the uploaded PDFs into the shared LanceDB table."""
    pdf_paths = [file.name for file in files]
    result = text_rag.index(
        pdf_paths,
        recursive=False,
        chunking_strategy=FixedTokenChunker(chunk_size=chunk_size),
        metadata={"source": "gradio_upload"},
        overwrite=True,
        verbose=True,
        ocr=use_ocr,
    )
    return f"Ingestion complete. {result}"


def query_and_answer(query, num_results):
    """Retrieve the top-k chunks for the query and generate an answer."""
    # Search for relevant chunks
    search_results = text_rag.search(query, k=num_results)

    # Generate a response with OpenAI, using the retrieved chunks as context
    context = "\n".join([r["text"] for r in search_results])
    response = llm.query(
        context=context,
        system_prompt="Given the information below, answer the question.",
        query=query,
    )

    # Format the retrieved chunks for display, labeling each chunk
    # with its index and source document
    formatted_results = "\n\n".join(
        [
            f"{'=' * 50}\n"
            f"Chunk {i + 1} (chunk index {r['chunk_index']}, from {r['document_name']}):\n"
            f"{r['text']}\n"
            f"{'=' * 50}"
            for i, r in enumerate(search_results)
        ]
    )

    return formatted_results, response
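
# Direct-call sketch (hypothetical query string; bypasses the Gradio UI):
#
#   chunks, answer = query_and_answer("What is the main finding?", num_results=5)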


# Define the Gradio interface
def create_gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# TextRAG Document Ingestion and Query System")

        with gr.Tab("Ingest Documents"):
            file_input = gr.File(
                file_count="multiple", label="Upload PDF Documents", file_types=["pdf"]
            )
            chunk_size = gr.Slider(50, 5000, value=200, step=10, label="Chunk Size")
            use_ocr = gr.Checkbox(label="Use OCR")
            ingest_button = gr.Button("Ingest Documents")
            ingest_output = gr.Textbox(label="Ingestion Result")

            ingest_button.click(
                ingest_documents,
                inputs=[file_input, chunk_size, use_ocr],
                outputs=ingest_output,
            )

        with gr.Tab("Query Documents"):
            query_input = gr.Textbox(label="Enter your query")
            num_results = gr.Slider(
                1, 10, value=5, step=1, label="Number of results to retrieve"
            )
            query_button = gr.Button("Search and Answer")
            retrieved_chunks = gr.Textbox(label="Retrieved Chunks")
            answer_output = gr.Textbox(label="Generated Answer")

            query_button.click(
                query_and_answer,
                inputs=[query_input, num_results],
                outputs=[retrieved_chunks, answer_output],
            )

    return demo


# Parse command-line arguments
def parse_args():
    parser = argparse.ArgumentParser(description="TextRAG Gradio App")
    parser.add_argument(
        "--share", action="store_true", help="Enable Gradio share feature"
    )
    return parser.parse_args()


# Main execution
if __name__ == "__main__":
    args = parse_args()
    demo = create_gradio_interface()
    demo.launch(share=args.share)
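
# Example invocations (assumes varag, gradio, lancedb, and
# sentence-transformers are installed, with an OpenAI key configured):
#   python textDemo.py            # serve the demo locally
#   python textDemo.py --share    # also create a public Gradio share link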