ayyuce committed
Commit aea1ea5 · verified · 1 parent(s): d96d4f9

Create app.py

Files changed (1): app.py (+184 −0)
app.py ADDED
import gradio as gr
import os
import torch
from transformers import AutoProcessor, AutoModel
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
from PIL import Image
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

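# NOTE: beyond the imports above, this app assumes the pre-0.2 LangChain package
# split (langchain + langchain-community) and the usual backing libraries:
# faiss-cpu for the vector store, pypdf for PyPDFLoader, and
# sentence-transformers for HuggingFaceEmbeddings.
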
class MultimodalRAG:
    def __init__(self, pdf_path=None):
        # CLIP encodes images; MiniLM sentence embeddings power text retrieval.
        self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.vision_model = AutoModel.from_pretrained("openai/clip-vit-base-patch32")
        self.text_embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        self.pdf_path = pdf_path
        self.documents = []
        self.vector_store = None
        self.retriever = None
        self.qa_chain = None

        try:
            self.llm = HuggingFacePipeline.from_model_id(
                model_id="google/flan-t5-large",
                task="text2text-generation",
                model_kwargs={"temperature": 0.7, "max_length": 512}
            )
        except Exception as e:
            # Fall back to OpenAI if the local model fails to load
            # (requires OPENAI_API_KEY in the environment).
            print(f"Error loading flan-t5 model: {e}")
            from langchain_community.llms import OpenAI
            self.llm = OpenAI(temperature=0.7)

        if pdf_path and os.path.exists(pdf_path):
            self.load_pdf(pdf_path)

    def load_pdf(self, pdf_path):
        if not os.path.exists(pdf_path):
            raise FileNotFoundError(f"PDF file not found: {pdf_path}")

        loader = PyPDFLoader(pdf_path)
        self.documents = loader.load()

        # Split pages into overlapping chunks so retrieval returns focused passages.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200
        )
        self.documents = text_splitter.split_documents(self.documents)

        # Embed the chunks, index them in FAISS, and wire up a stuff-style QA chain.
        self.vector_store = FAISS.from_documents(self.documents, self.text_embeddings)
        self.retriever = self.vector_store.as_retriever(search_kwargs={"k": 2})
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=True
        )

        return f"Successfully loaded and processed PDF: {pdf_path}"

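    # A possible extension, not part of the original commit: persist the FAISS
    # index so a document does not have to be re-embedded on every restart.
    # save_local/load_local are standard methods on the langchain_community
    # FAISS wrapper; the folder name "faiss_index" is an arbitrary choice.
    def save_index(self, path="faiss_index"):
        if self.vector_store is not None:
            self.vector_store.save_local(path)

    def load_index(self, path="faiss_index"):
        # Depending on the installed langchain-community version, load_local
        # may also require allow_dangerous_deserialization=True.
        self.vector_store = FAISS.load_local(path, self.text_embeddings)
        self.retriever = self.vector_store.as_retriever(search_kwargs={"k": 2})
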
    def process_image(self, image_path):
        if not os.path.exists(image_path):
            print(f"Warning: Image path {image_path} does not exist")
            return None

        # Encode the image with CLIP; no gradients are needed at inference time.
        image = Image.open(image_path)
        inputs = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            image_features = self.vision_model.get_image_features(**inputs)
        return image_features

    def generate_image_description(self, image_features):
        # Placeholder: a real implementation would turn CLIP features into a caption.
        return "a photo"

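    # Hedged sketch, not in the original commit: one way to make the description
    # less generic is CLIP zero-shot classification over a caller-supplied label
    # list (the default labels below are illustrative). Note this takes the PIL
    # image rather than precomputed features, since CLIP scores text and image jointly.
    def describe_image_zero_shot(self, image, candidate_labels=("a chart", "a diagram", "a photograph", "a page of text")):
        labels = list(candidate_labels)
        inputs = self.processor(text=labels, images=image, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = self.vision_model(**inputs)
        # logits_per_image has shape (1, num_labels); pick the best-matching label.
        best = outputs.logits_per_image.softmax(dim=-1).argmax(dim=-1).item()
        return labels[best]
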
    def build_enhanced_query(self, query_text, image_path=None):
        # Append a short CLIP-derived description of the optional image so both
        # retrieval and the QA chain see the extra context.
        if image_path:
            image_features = self.process_image(image_path)
            if image_features is not None:
                image_query = self.generate_image_description(image_features)
                return f"{query_text} {image_query}"
        return query_text

    def retrieve_related_documents(self, query_text, image_path=None):
        enhanced_query = self.build_enhanced_query(query_text, image_path)
        return self.retriever.get_relevant_documents(enhanced_query)

    def answer_query(self, query_text, image_path=None):
        if not self.vector_store or not self.qa_chain:
            # Keep the (answer, sources) shape the Gradio handler unpacks.
            return "Please upload a PDF document first.", []

        # Query the chain with the image-enhanced text; RetrievalQA performs its
        # own retrieval, so the enhanced query is what actually steers it.
        enhanced_query = self.build_enhanced_query(query_text, image_path)
        result = self.qa_chain({"query": enhanced_query})

        answer = result["result"]
        sources = [doc.page_content[:1000] + "..." for doc in result["source_documents"]]
        return answer, sources

# Single shared RAG instance used by the Gradio callbacks below.
rag_system = MultimodalRAG()

def upload_pdf(pdf_file):
    if pdf_file is None:
        return "No file uploaded"

    # gr.File hands back a temp file object; .name is its path on disk.
    file_path = pdf_file.name
    try:
        return rag_system.load_pdf(file_path)
    except Exception as e:
        return f"Error processing PDF: {str(e)}"

def save_image(image):
    if image is None:
        return None

    # JPEG cannot store an alpha channel, so normalize to RGB before saving.
    temp_path = "temp_image.jpg"
    image.convert("RGB").save(temp_path)
    return temp_path

def process_query(query, pdf_file, image=None):
    if not query.strip():
        return "Please enter a question", []

    if pdf_file is None:
        return "Please upload a PDF document first", []

    image_path = save_image(image) if image is not None else None

    try:
        answer, sources = rag_system.answer_query(query, image_path)
        return answer, sources
    except Exception as e:
        return f"Error processing query: {str(e)}", []
    finally:
        # Remove the temporary image on both the success and error paths.
        if image_path and os.path.exists(image_path):
            os.remove(image_path)

# Create the Gradio interface.
with gr.Blocks(title="Multimodal RAG System") as demo:
    gr.Markdown("# Multimodal RAG System")
    gr.Markdown("Upload a PDF document and ask questions about it. You can also add an image for multimodal context.")

    with gr.Row():
        with gr.Column(scale=1):
            pdf_input = gr.File(label="Upload PDF Document")
            upload_button = gr.Button("Process PDF")
            status_output = gr.Textbox(label="Status")

            upload_button.click(
                fn=upload_pdf,
                inputs=[pdf_input],
                outputs=[status_output]
            )

        with gr.Column(scale=2):
            image_input = gr.Image(label="Optional: Upload an Image", type="pil")
            query_input = gr.Textbox(label="Ask a question")
            submit_button = gr.Button("Submit Question")

            answer_output = gr.Textbox(label="Answer")
            sources_output = gr.JSON(label="Sources")

            submit_button.click(
                fn=process_query,
                inputs=[query_input, pdf_input, image_input],
                outputs=[answer_output, sources_output]
            )

if __name__ == "__main__":
    # share=True is ignored on Hugging Face Spaces but useful for local runs;
    # binding to 0.0.0.0 makes the server reachable from outside the container.
    demo.launch(share=True, server_name="0.0.0.0")
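
For reference, a minimal way to exercise the class without the Gradio UI (a sketch; "sample.pdf" and the question are placeholders):

    from app import MultimodalRAG

    rag = MultimodalRAG(pdf_path="sample.pdf")
    answer, sources = rag.answer_query("What is the main conclusion?")
    print(answer)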