ayyuce committed (verified)
Commit ae8cf80 · Parent(s): b701d2c

Update app.py

Files changed (1): app.py (+53 −137)
app.py CHANGED
@@ -1,184 +1,100 @@
-import gradio as gr
 import os
 import torch
+from PIL import Image
+import gradio as gr
 from transformers import AutoProcessor, AutoModel
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.vectorstores import FAISS
 from langchain.chains import RetrievalQA
 from langchain_community.llms import HuggingFacePipeline
-from PIL import Image
 from langchain_community.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
 class MultimodalRAG:
-    def __init__(self, pdf_path=None):
+    def __init__(self, pdf_path):
         self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
         self.vision_model = AutoModel.from_pretrained("openai/clip-vit-base-patch32")
         self.text_embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
-        self.pdf_path = pdf_path
-        self.documents = []
-        self.vector_store = None
-        self.retriever = None
-        self.qa_chain = None
+
+        self.documents = self._load_and_split(pdf_path)
+        self.vector_store = FAISS.from_documents(self.documents, self.text_embeddings)
 
         try:
             self.llm = HuggingFacePipeline.from_model_id(
                 model_id="google/flan-t5-large",
                 task="text2text-generation",
-                model_kwargs={"temperature": 0.7, "max_length": 512}
+                model_kwargs={"temperature": 0.7, "max_length": 512, "device": -1}
             )
-        except Exception as e:
-            print(f"Error loading flan-t5 model: {e}")
+        except Exception:
             from langchain.llms import OpenAI
             self.llm = OpenAI(temperature=0.7)
 
-        if pdf_path and os.path.exists(pdf_path):
-            self.load_pdf(pdf_path)
-
-    def load_pdf(self, pdf_path):
-        if not os.path.exists(pdf_path):
-            raise FileNotFoundError(f"PDF file not found: {pdf_path}")
-
-        loader = PyPDFLoader(pdf_path)
-        self.documents = loader.load()
-
-        text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=1000,
-            chunk_overlap=200
-        )
-        self.documents = text_splitter.split_documents(self.documents)
-
-        self.vector_store = FAISS.from_documents(self.documents, self.text_embeddings)
-
         self.retriever = self.vector_store.as_retriever(search_kwargs={"k": 2})
-
         self.qa_chain = RetrievalQA.from_chain_type(
             llm=self.llm,
             chain_type="stuff",
             retriever=self.retriever,
             return_source_documents=True
         )
 
-        return f"Successfully loaded and processed PDF: {pdf_path}"
-
-    def process_image(self, image_path):
-        if not os.path.exists(image_path):
-            print(f"Warning: Image path {image_path} does not exist")
-            return None
-
-        image = Image.open(image_path)
+    def _load_and_split(self, pdf_path):
+        loader = PyPDFLoader(pdf_path)
+        docs = loader.load()
+        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+        return splitter.split_documents(docs)
+
+    def _get_image_features(self, image_path):
+        image = Image.open(image_path).convert("RGB")
         inputs = self.processor(images=image, return_tensors="pt")
         with torch.no_grad():
-            image_features = self.vision_model.get_image_features(**inputs)
-        return image_features
-
-    def generate_image_description(self, image_features):
-        return "a photo"
-
-    def retrieve_related_documents(self, query_text, image_path=None):
-        if image_path:
-            image_features = self.process_image(image_path)
-
-            if image_features is not None:
-                image_query = self.generate_image_description(image_features)
-
-                enhanced_query = f"{query_text} {image_query}"
-            else:
-                enhanced_query = query_text
-        else:
-            enhanced_query = query_text
-
-        docs = self.retriever.get_relevant_documents(enhanced_query)
-        return docs
+            return self.vision_model.get_image_features(**inputs)
+
+    def _generate_image_description(self, image_features):
+        return "an image"
 
     def answer_query(self, query_text, image_path=None):
-        if not self.vector_store or not self.qa_chain:
-            return "Please upload a PDF document first."
-
         if image_path:
-            docs = self.retrieve_related_documents(query_text, image_path)
+            feats = self._get_image_features(image_path)
+            img_desc = self._generate_image_description(feats)
+            full_query = f"{query_text} {img_desc}"
         else:
-            docs = self.retrieve_related_documents(query_text)
-
-        result = self.qa_chain({"query": query_text})
+            full_query = query_text
 
+        result = self.qa_chain({"query": full_query})
         answer = result["result"]
-        sources = [doc.page_content[:1000] + "..." for doc in result["source_documents"]]
-
+        sources = [doc.metadata for doc in result.get("source_documents", [])]
         return answer, sources
 
-rag_system = MultimodalRAG()
 
-def upload_pdf(pdf_file):
-    if pdf_file is None:
-        return "No file uploaded"
-
-    file_path = pdf_file.name
-    try:
-        result = rag_system.load_pdf(file_path)
-        return result
-    except Exception as e:
-        return f"Error processing PDF: {str(e)}"
-
-def save_image(image):
-    if image is None:
-        return None
-
-    temp_path = "temp_image.jpg"
-    image.save(temp_path)
-    return temp_path
-
-def process_query(query, pdf_file, image=None):
-    if not query.strip():
-        return "Please enter a question", []
-
+def run_rag(pdf_file, query, image_file=None):
     if pdf_file is None:
-        return "Please upload a PDF document first", []
+        return "Please upload a PDF.", []
 
+    pdf_path = pdf_file.name
     image_path = None
-    if image is not None:
-        image_path = save_image(image)
-
-    try:
-        answer, sources = rag_system.answer_query(query, image_path)
-        if image_path and os.path.exists(image_path):
-            os.remove(image_path)
-        return answer, sources
-    except Exception as e:
-        if image_path and os.path.exists(image_path):
-            os.remove(image_path)
-        return f"Error processing query: {str(e)}", []
-
-# Create Gradio interface
-with gr.Blocks(title="Multimodal RAG System") as demo:
-    gr.Markdown("# Multimodal RAG System")
-    gr.Markdown("Upload a PDF document and ask questions about it. You can also add an image for multimodal context.")
-
-    with gr.Row():
-        with gr.Column(scale=1):
-            pdf_input = gr.File(label="Upload PDF Document")
-            upload_button = gr.Button("Process PDF")
-            status_output = gr.Textbox(label="Status")
-
-            upload_button.click(
-                fn=upload_pdf,
-                inputs=[pdf_input],
-                outputs=[status_output]
-            )
-
-        with gr.Column(scale=2):
-            image_input = gr.Image(label="Optional: Upload an Image", type="pil")
-            query_input = gr.Textbox(label="Ask a question")
-            submit_button = gr.Button("Submit Question")
-
-            answer_output = gr.Textbox(label="Answer")
-            sources_output = gr.JSON(label="Sources")
-
-            submit_button.click(
-                fn=process_query,
-                inputs=[query_input, pdf_input, image_input],
-                outputs=[answer_output, sources_output]
-            )
+    if image_file:
+        image_path = image_file.name
+
+    rag = MultimodalRAG(pdf_path)
+    answer, sources = rag.answer_query(query, image_path)
+    return answer, sources
+
+iface = gr.Interface(
+    fn=run_rag,
+    inputs=[
+        gr.File(label="PDF Document", file_types=[".pdf"]),
+        gr.Textbox(label="Query", placeholder="Enter your question here..."),
+        gr.File(label="Optional Image", file_types=[".png", ".jpg", ".jpeg"], optional=True)
+    ],
+    outputs=[
+        gr.Textbox(label="Answer"),
+        gr.JSON(label="Source Documents")
+    ],
+    title="Multimodal RAG QA",
+    description="Upload a PDF, ask a question, optionally provide an image."
+)
 
 if __name__ == "__main__":
-    demo.launch(share=True, server_name="0.0.0.0")
+    iface.launch(server_name="0.0.0.0", server_port=7860)
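
A caveat on the new interface wiring: Gradio 3.x and later removed the optional keyword from input components, so gr.File(..., optional=True) raises a TypeError on recent versions. A minimal sketch of the adjustment, assuming Gradio >= 3:

    # The image input stays effectively optional without the keyword:
    # run_rag's image_file parameter defaults to None and is checked before use.
    gr.File(label="Optional Image", file_types=[".png", ".jpg", ".jpeg"])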
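
Similarly, HuggingFacePipeline.from_model_id in langchain_community accepts device as a top-level argument rather than a model_kwargs entry; a device key inside model_kwargs is forwarded to the model constructor and can fail there. A sketch of the more conventional CPU-only construction, an assumption about the installed LangChain version:

    from langchain_community.llms import HuggingFacePipeline

    # device=-1 selects CPU in the transformers pipeline convention.
    llm = HuggingFacePipeline.from_model_id(
        model_id="google/flan-t5-large",
        task="text2text-generation",
        device=-1,
        model_kwargs={"temperature": 0.7, "max_length": 512},
    )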
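
Note also that _generate_image_description still returns a constant string, so an uploaded image only appends boilerplate to the query. One way to make it informative is CLIP zero-shot matching against a small candidate-label list; a hedged sketch, with an illustrative label set that is not part of the commit:

    import torch

    # Hypothetical label set; tune to the document domain.
    CANDIDATE_LABELS = ["a chart", "a diagram", "a table", "a photograph", "a screenshot"]

    def _generate_image_description(self, image_features):
        # Embed the candidate labels with the same CLIP model.
        text_inputs = self.processor(text=CANDIDATE_LABELS, return_tensors="pt", padding=True)
        with torch.no_grad():
            text_features = self.vision_model.get_text_features(**text_inputs)
        # Cosine similarity between the image embedding and each label embedding.
        img = image_features / image_features.norm(dim=-1, keepdim=True)
        txt = text_features / text_features.norm(dim=-1, keepdim=True)
        best = (img @ txt.T).argmax(dim=-1).item()
        return CANDIDATE_LABELS[best]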
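
Design note: run_rag builds a fresh MultimodalRAG on every request, so the PDF is re-loaded and re-embedded for each question. A small cache keyed on the file path avoids that; a sketch, not part of the commit:

    # Hypothetical per-path cache so repeated questions against the same
    # PDF reuse the already-built FAISS index instead of re-embedding it.
    _rag_cache = {}

    def get_rag(pdf_path):
        if pdf_path not in _rag_cache:
            _rag_cache[pdf_path] = MultimodalRAG(pdf_path)
        return _rag_cache[pdf_path]

    # Inside run_rag, rag = MultimodalRAG(pdf_path) would become:
    # rag = get_rag(pdf_path)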
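
For reference, the refactored class can also be exercised without the Gradio front end. A minimal smoke test, assuming placeholder files sample.pdf and photo.jpg exist locally:

    # Text-only question against the indexed PDF.
    rag = MultimodalRAG("sample.pdf")
    answer, sources = rag.answer_query("What is the main conclusion?")
    print(answer, sources)

    # Same question with an image description appended to the query.
    answer, sources = rag.answer_query("What is the main conclusion?", image_path="photo.jpg")
    print(answer, sources)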