ayyuce commited on
Commit
eb63e15
·
verified ·
1 Parent(s): ae8cf80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -52
app.py CHANGED
@@ -1,100 +1,184 @@
 
1
  import os
2
  import torch
3
- from PIL import Image
4
- import gradio as gr
5
  from transformers import AutoProcessor, AutoModel
6
  from langchain_community.embeddings import HuggingFaceEmbeddings
7
  from langchain_community.vectorstores import FAISS
8
  from langchain.chains import RetrievalQA
9
  from langchain_community.llms import HuggingFacePipeline
 
10
  from langchain_community.document_loaders import PyPDFLoader
11
  from langchain.text_splitter import RecursiveCharacterTextSplitter
12
 
13
- os.environ["CUDA_VISIBLE_DEVICES"] = ""
14
-
15
  class MultimodalRAG:
16
- def __init__(self, pdf_path):
17
  self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
18
  self.vision_model = AutoModel.from_pretrained("openai/clip-vit-base-patch32")
19
  self.text_embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
20
-
21
- self.documents = self._load_and_split(pdf_path)
22
- self.vector_store = FAISS.from_documents(self.documents, self.text_embeddings)
23
-
 
 
24
  try:
25
  self.llm = HuggingFacePipeline.from_model_id(
26
  model_id="google/flan-t5-large",
27
  task="text2text-generation",
28
- model_kwargs={"temperature": 0.7, "max_length": 512, "device": -1}
29
  )
30
- except Exception:
 
31
  from langchain.llms import OpenAI
32
  self.llm = OpenAI(temperature=0.7)
 
 
 
 
 
 
 
 
 
 
33
 
 
 
 
 
 
 
 
 
34
  self.retriever = self.vector_store.as_retriever(search_kwargs={"k": 2})
 
35
  self.qa_chain = RetrievalQA.from_chain_type(
36
  llm=self.llm,
37
  chain_type="stuff",
38
  retriever=self.retriever,
39
  return_source_documents=True
40
  )
 
 
41
 
42
- def _load_and_split(self, pdf_path):
43
- loader = PyPDFLoader(pdf_path)
44
- docs = loader.load()
45
- splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
46
- return splitter.split_documents(docs)
47
 
48
- def _get_image_features(self, image_path):
49
- image = Image.open(image_path).convert("RGB")
50
  inputs = self.processor(images=image, return_tensors="pt")
51
  with torch.no_grad():
52
- return self.vision_model.get_image_features(**inputs)
 
53
 
54
- def _generate_image_description(self, image_features):
55
  return "an image"
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def answer_query(self, query_text, image_path=None):
 
 
 
58
  if image_path:
59
- feats = self._get_image_features(image_path)
60
- img_desc = self._generate_image_description(feats)
61
- full_query = f"{query_text} {img_desc}"
62
  else:
63
- full_query = query_text
64
 
65
- result = self.qa_chain({"query": full_query})
 
66
  answer = result["result"]
67
- sources = [doc.metadata for doc in result.get("source_documents", [])]
 
68
  return answer, sources
69
 
 
70
 
71
- def run_rag(pdf_file, query, image_file=None):
72
  if pdf_file is None:
73
- return "Please upload a PDF.", []
74
-
75
- pdf_path = pdf_file.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  image_path = None
77
- if image_file:
78
- image_path = image_file.name
79
-
80
- rag = MultimodalRAG(pdf_path)
81
- answer, sources = rag.answer_query(query, image_path)
82
- return answer, sources
83
-
84
- iface = gr.Interface(
85
- fn=run_rag,
86
- inputs=[
87
- gr.File(label="PDF Document", file_types=[".pdf"]),
88
- gr.Textbox(label="Query", placeholder="Enter your question here..."),
89
- gr.File(label="Optional Image", file_types=[".png", ".jpg", ".jpeg"], optional=True)
90
- ],
91
- outputs=[
92
- gr.Textbox(label="Answer"),
93
- gr.JSON(label="Source Documents")
94
- ],
95
- title="Multimodal RAG QA",
96
- description="Upload a PDF, ask a question, optionally provide an image."
97
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  if __name__ == "__main__":
100
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ import gradio as gr
2
  import os
3
  import torch
 
 
4
  from transformers import AutoProcessor, AutoModel
5
  from langchain_community.embeddings import HuggingFaceEmbeddings
6
  from langchain_community.vectorstores import FAISS
7
  from langchain.chains import RetrievalQA
8
  from langchain_community.llms import HuggingFacePipeline
9
+ from PIL import Image
10
  from langchain_community.document_loaders import PyPDFLoader
11
  from langchain.text_splitter import RecursiveCharacterTextSplitter
12
 
 
 
13
  class MultimodalRAG:
14
+ def __init__(self, pdf_path=None):
15
  self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
16
  self.vision_model = AutoModel.from_pretrained("openai/clip-vit-base-patch32")
17
  self.text_embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
18
+ self.pdf_path = pdf_path
19
+ self.documents = []
20
+ self.vector_store = None
21
+ self.retriever = None
22
+ self.qa_chain = None
23
+
24
  try:
25
  self.llm = HuggingFacePipeline.from_model_id(
26
  model_id="google/flan-t5-large",
27
  task="text2text-generation",
28
+ model_kwargs={"temperature": 0.7, "max_length": 512}
29
  )
30
+ except Exception as e:
31
+ print(f"Error loading flan-t5 model: {e}")
32
  from langchain.llms import OpenAI
33
  self.llm = OpenAI(temperature=0.7)
34
+
35
+ if pdf_path and os.path.exists(pdf_path):
36
+ self.load_pdf(pdf_path)
37
+
38
+ def load_pdf(self, pdf_path):
39
+ if not os.path.exists(pdf_path):
40
+ raise FileNotFoundError(f"PDF file not found: {pdf_path}")
41
+
42
+ loader = PyPDFLoader(pdf_path)
43
+ self.documents = loader.load()
44
 
45
+ text_splitter = RecursiveCharacterTextSplitter(
46
+ chunk_size=1000,
47
+ chunk_overlap=200
48
+ )
49
+ self.documents = text_splitter.split_documents(self.documents)
50
+
51
+ self.vector_store = FAISS.from_documents(self.documents, self.text_embeddings)
52
+
53
  self.retriever = self.vector_store.as_retriever(search_kwargs={"k": 2})
54
+
55
  self.qa_chain = RetrievalQA.from_chain_type(
56
  llm=self.llm,
57
  chain_type="stuff",
58
  retriever=self.retriever,
59
  return_source_documents=True
60
  )
61
+
62
+ return f"Successfully loaded and processed PDF: {pdf_path}"
63
 
64
+ def process_image(self, image_path):
65
+ if not os.path.exists(image_path):
66
+ print(f"Warning: Image path {image_path} does not exist")
67
+ return None
 
68
 
69
+ image = Image.open(image_path)
 
70
  inputs = self.processor(images=image, return_tensors="pt")
71
  with torch.no_grad():
72
+ image_features = self.vision_model.get_image_features(**inputs)
73
+ return image_features
74
 
75
+ def generate_image_description(self, image_features):
76
  return "an image"
77
 
78
+ def retrieve_related_documents(self, query_text, image_path=None):
79
+ if image_path:
80
+ image_features = self.process_image(image_path)
81
+
82
+ if image_features is not None:
83
+ image_query = self.generate_image_description(image_features)
84
+
85
+ enhanced_query = f"{query_text} {image_query}"
86
+ else:
87
+ enhanced_query = query_text
88
+ else:
89
+ enhanced_query = query_text
90
+
91
+ docs = self.retriever.get_relevant_documents(enhanced_query)
92
+ return docs
93
+
94
  def answer_query(self, query_text, image_path=None):
95
+ if not self.vector_store or not self.qa_chain:
96
+ return "Please upload a PDF document first.", []
97
+
98
  if image_path:
99
+ docs = self.retrieve_related_documents(query_text, image_path)
 
 
100
  else:
101
+ docs = self.retrieve_related_documents(query_text)
102
 
103
+ result = self.qa_chain({"query": query_text})
104
+
105
  answer = result["result"]
106
+ sources = [doc.page_content[:1000] + "..." for doc in result["source_documents"]]
107
+
108
  return answer, sources
109
 
110
+ rag_system = MultimodalRAG()
111
 
112
+ def upload_pdf(pdf_file):
113
  if pdf_file is None:
114
+ return "No file uploaded"
115
+
116
+ file_path = pdf_file.name
117
+ try:
118
+ result = rag_system.load_pdf(file_path)
119
+ return result
120
+ except Exception as e:
121
+ return f"Error processing PDF: {str(e)}"
122
+
123
+ def save_image(image):
124
+ if image is None:
125
+ return None
126
+
127
+ temp_path = "temp_image.jpg"
128
+ image.save(temp_path)
129
+ return temp_path
130
+
131
+ def process_query(query, pdf_file, image=None):
132
+ if not query.strip():
133
+ return "Please enter a question", []
134
+
135
+ if pdf_file is None:
136
+ return "Please upload a PDF document first", []
137
+
138
  image_path = None
139
+ if image is not None:
140
+ image_path = save_image(image)
141
+
142
+ try:
143
+ answer, sources = rag_system.answer_query(query, image_path)
144
+ if image_path and os.path.exists(image_path):
145
+ os.remove(image_path)
146
+ return answer, sources
147
+ except Exception as e:
148
+ if image_path and os.path.exists(image_path):
149
+ os.remove(image_path)
150
+ return f"Error processing query: {str(e)}", []
151
+
152
+
153
+ with gr.Blocks(title="Multimodal RAG System") as demo:
154
+ gr.Markdown("# Multimodal RAG System")
155
+ gr.Markdown("Upload a PDF document and ask questions about it. You can also add an image for multimodal context.")
156
+
157
+ with gr.Row():
158
+ with gr.Column(scale=1):
159
+ pdf_input = gr.File(label="Upload PDF Document")
160
+ upload_button = gr.Button("Process PDF")
161
+ status_output = gr.Textbox(label="Status")
162
+
163
+ upload_button.click(
164
+ fn=upload_pdf,
165
+ inputs=[pdf_input],
166
+ outputs=[status_output]
167
+ )
168
+
169
+ with gr.Column(scale=2):
170
+ image_input = gr.Image(label="Optional: Upload an Image", type="pil")
171
+ query_input = gr.Textbox(label="Ask a question")
172
+ submit_button = gr.Button("Submit Question")
173
+
174
+ answer_output = gr.Textbox(label="Answer")
175
+ sources_output = gr.JSON(label="Sources")
176
+
177
+ submit_button.click(
178
+ fn=process_query,
179
+ inputs=[query_input, pdf_input, image_input],
180
+ outputs=[answer_output, sources_output]
181
+ )
182
 
183
  if __name__ == "__main__":
184
+ demo.launch(share=True, server_name="0.0.0.0")