-
Notifications
You must be signed in to change notification settings - Fork 40
/
Copy pathhybridColpaliDemo.py
160 lines (133 loc) · 5.14 KB
/
hybridColpaliDemo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import gradio as gr
import os
import lancedb
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv
from typing import List
from PIL import Image
import base64
import io
import argparse
# Import the HybridColpaliRAG class and OpenAI VLM
from varag.rag import HybridColpaliRAG
from varag.vlms import OpenAI
from varag.utils import get_model_colpali
# Load the ColPali vision-retrieval model and its processor once at module import.
model, processor = get_model_colpali("vidore/colpali-v1.2")

# Pull environment variables (e.g. the OpenAI API key used by the VLM) from .env.
load_dotenv()

# Initialize shared database
shared_db = lancedb.connect("~/shared_rag_db")

# Initialize HybridColpaliRAG
# NOTE(review): presumably pairs jina-clip image embeddings for retrieval with
# ColPali re-ranking over the same LanceDB table — confirm against varag docs.
embedding_model = SentenceTransformer("jinaai/jina-clip-v1", trust_remote_code=True)
hybrid_rag = HybridColpaliRAG(
    colpali_model=model,
    colpali_processor=processor,
    image_embedding_model=embedding_model,
    db=shared_db,
    table_name="hybridColpaliDemo",
)

# Initialize VLM
vlm = OpenAI()
def ingest_pdfs(pdf_files, table_name, recursive, verbose):
    """Index uploaded PDF files into the hybrid ColPali LanceDB table.

    Args:
        pdf_files: Uploaded file objects from Gradio (each exposes a ``.name``
            path on disk); may be ``None`` when nothing was uploaded.
        table_name: Optional table name; switches tables when non-empty.
        recursive: Passed through to ``HybridColpaliRAG.index``.
        verbose: Passed through to ``HybridColpaliRAG.index``.

    Returns:
        A human-readable status string (success or error) for the UI.
    """
    # Guard: clicking "Ingest PDFs" with no upload gives pdf_files=None, which
    # previously surfaced as a confusing TypeError message from the except arm.
    if not pdf_files:
        return "No PDF files provided."
    try:
        if table_name:
            hybrid_rag.change_table(table_name)
        file_paths = [pdf_file.name for pdf_file in pdf_files]
        # overwrite=False appends to the existing table rather than rebuilding it.
        hybrid_rag.index(
            file_paths, overwrite=False, recursive=recursive, verbose=verbose
        )
        return f"PDFs ingested successfully into table '{hybrid_rag.table_name}'."
    except Exception as e:
        # Surface the failure in the status textbox instead of crashing the app.
        return f"Error ingesting PDFs: {str(e)}"
def search_and_analyze(query, table_name, use_image_search, top_k):
    """Search the hybrid index for ``query`` and have the VLM analyze the hits.

    Args:
        query: Free-text search query.
        table_name: Optional table name; switches tables when non-empty.
        use_image_search: Passed through to ``HybridColpaliRAG.search``.
        top_k: Number of results to retrieve.

    Returns:
        Tuple ``(response_text, pil_images)``. On any failure the first element
        is an error message and the second an empty list, so the Gradio outputs
        (Textbox, Gallery) always receive valid values.
    """
    try:
        if table_name:
            hybrid_rag.change_table(table_name)
        results = hybrid_rag.search(
            query,
            k=top_k,
            use_image_search=use_image_search,
        )
        pil_images = [_to_pil_image(result["image"]) for result in results]
        context = _build_context(query, results)
        # Generate response using VLM
        vlm_response = vlm.query(context, pil_images, max_tokens=500)
        return vlm_response, pil_images
    except Exception as e:
        return f"Error generating response: {str(e)}", []


def _to_pil_image(image_data):
    """Normalize a search-result image payload into a PIL Image.

    Accepts an already-built PIL image, a base64-encoded string, or raw bytes;
    raises ValueError for anything else.
    """
    if isinstance(image_data, Image.Image):
        return image_data
    if isinstance(image_data, str):
        # Assume it's base64 encoded
        return Image.open(io.BytesIO(base64.b64decode(image_data)))
    if isinstance(image_data, bytes):
        return Image.open(io.BytesIO(image_data))
    raise ValueError(f"Unexpected image type: {type(image_data)}")


def _build_context(query, results):
    """Build the textual context handed to the VLM alongside the images."""
    context = f"Query: {query}\n\nRelevant image information:\n"
    for i, result in enumerate(results, 1):
        context += f"Image {i}: From document '{result['name']}', page {result['page_number']}\n"
        if "metadata" in result:
            context += f"Metadata: {result['metadata']}\n"
        if "page_text" in result:
            # Truncate long page text to keep the VLM prompt small.
            context += f"Page text: {result['page_text'][:500]}...\n\n"
    return context
def create_gradio_interface():
    """Assemble the two-tab Gradio UI: PDF ingestion and search/analysis."""
    with gr.Blocks() as demo:
        gr.Markdown("# HybridColpaliRAG Image Search and Analysis with VLM")

        with gr.Tab("Ingest PDFs"):
            uploaded_pdfs = gr.File(
                label="Upload PDF(s)", file_count="multiple", file_types=["pdf"]
            )
            ingest_table_box = gr.Textbox(
                label="Table Name (optional)", placeholder="default_table"
            )
            recursive_flag = gr.Checkbox(label="Recursive Indexing", value=False)
            verbose_flag = gr.Checkbox(label="Verbose Output", value=True)
            run_ingest_btn = gr.Button("Ingest PDFs")
            ingest_status_box = gr.Textbox(label="Ingestion Status")

        with gr.Tab("Search and Analyze"):
            query_box = gr.Textbox(label="Enter your query")
            search_table_box = gr.Textbox(
                label="Table Name (optional)", placeholder="default_table"
            )
            image_search_flag = gr.Checkbox(label="Use Image Search", value=True)
            top_k_input = gr.Slider(
                minimum=1, maximum=10, value=3, step=1, label="Top K Results"
            )
            run_search_btn = gr.Button("Search and Analyze")
            vlm_answer_box = gr.Textbox(label="VLM Response")
            retrieved_gallery = gr.Gallery(label="Retrieved Images")

        # Wire each button to its handler.
        run_ingest_btn.click(
            ingest_pdfs,
            inputs=[
                uploaded_pdfs,
                ingest_table_box,
                recursive_flag,
                verbose_flag,
            ],
            outputs=ingest_status_box,
        )
        run_search_btn.click(
            search_and_analyze,
            inputs=[
                query_box,
                search_table_box,
                image_search_flag,
                top_k_input,
            ],
            outputs=[vlm_answer_box, retrieved_gallery],
        )
    return demo
# Parse command-line arguments
def parse_args():
    """Parse CLI options for this demo.

    Returns:
        argparse.Namespace with a boolean ``share`` attribute (default False).
    """
    # Description fixed: it previously said "TextRAG Gradio App", a copy-paste
    # leftover from a sibling demo — this script is the HybridColpaliRAG demo.
    parser = argparse.ArgumentParser(description="HybridColpaliRAG Gradio App")
    parser.add_argument(
        "--share", action="store_true", help="Enable Gradio share feature"
    )
    return parser.parse_args()
# Main execution: build the UI and serve it, optionally with a public share link.
if __name__ == "__main__":
    cli_args = parse_args()
    app = create_gradio_interface()
    app.launch(share=cli_args.share)