ikraamkb committed
Commit 1cafb18 · verified · Parent: 5976e32

Update app.py

Files changed (1): app.py (+60 −79)
app.py CHANGED
@@ -205,74 +205,40 @@ def home():
 from fastapi import FastAPI
 from fastapi.responses import RedirectResponse
 import gradio as gr
-from transformers import pipeline, ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM
-from PIL import Image
-import torch
-import fitz  # PyMuPDF for PDF
+
+import fitz  # PyMuPDF for PDFs
 import easyocr  # OCR for images
 import openpyxl  # XLSX processing
 import pptx  # PPTX processing
 import docx  # DOCX processing
 
-# Initialize FastAPI app
-app = FastAPI()
-
-# ========== Document QA Setup ==========
-doc_tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
-doc_model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+from transformers import pipeline, ViltProcessor, ViltForQuestionAnswering
+from PIL import Image
+import torch
 
-def read_pdf(file):
-    doc = fitz.open(stream=file.read(), filetype="pdf")
-    text = ""
-    for page in doc:
-        text += page.get_text()
-    return text
+# === Initialize FastAPI App ===
+app = FastAPI()
 
-def answer_question_from_doc(file, question):
-    if file is None or not question.strip():
-        return "Please upload a document and ask a question."
-    text = read_pdf(file)
-    prompt = f"Context: {text}\nQuestion: {question}\nAnswer:"
-    inputs = doc_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
-    with torch.no_grad():
-        outputs = doc_model.generate(**inputs, max_new_tokens=100)
-    answer = doc_tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return answer.split("Answer:")[-1].strip()
+# === Initialize QA Model for Documents and OCR ===
+qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
 
-# ========== Image QA Setup ==========
+# === Initialize Image QA Model (VQA) ===
 vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
 vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
 
-def answer_question_from_image(image, question):
-    if image is None or not question.strip():
-        return "Please upload an image and ask a question."
-    inputs = vqa_processor(image, question, return_tensors="pt")
-    with torch.no_grad():
-        outputs = vqa_model(**inputs)
-    predicted_id = outputs.logits.argmax(-1).item()
-    return vqa_model.config.id2label[predicted_id]
+# === OCR Reader ===
+reader = easyocr.Reader(['en', 'fr'])
 
-# ========== Text Extraction Functions ==========
-reader = easyocr.Reader(['en', 'fr'])  # OCR for English & French
-
-def extract_text_from_pdf(pdf_file):
-    """Extract text from a PDF file."""
-    text = []
-    try:
-        with fitz.open(pdf_file) as doc:
-            for page in doc:
-                text.append(page.get_text("text"))
-    except Exception as e:
-        return f"Error reading PDF: {e}"
-    return "\n".join(text)
+# === Document Text Extraction Functions ===
+def extract_text_from_pdf(file_obj):
+    doc = fitz.open(stream=file_obj.read(), filetype="pdf")
+    return "\n".join([page.get_text() for page in doc])
 
 def extract_text_from_docx(docx_file):
-    """Extract text from a DOCX file."""
     doc = docx.Document(docx_file)
     return "\n".join([p.text for p in doc.paragraphs if p.text.strip()])
 
 def extract_text_from_pptx(pptx_file):
-    """Extract text from a PPTX file."""
     text = []
     try:
         presentation = pptx.Presentation(pptx_file)
@@ -285,7 +251,6 @@ def extract_text_from_pptx(pptx_file):
     return "\n".join(text)
 
 def extract_text_from_xlsx(xlsx_file):
-    """Extract text from an XLSX file."""
     text = []
     try:
         wb = openpyxl.load_workbook(xlsx_file)
@@ -297,39 +262,39 @@ def extract_text_from_xlsx(xlsx_file):
         return f"Error reading XLSX: {e}"
     return "\n".join(text)
 
+# === Image OCR ===
 def extract_text_from_image(image_path):
-    """Extract text from an image using EasyOCR."""
     result = reader.readtext(image_path, detail=0)
-    return " ".join(result)  # Return text as a single string
+    return " ".join(result)
 
-# ========== Main Processing Functions ==========
+# === QA for Document Files ===
 def answer_question_from_doc(file, question):
-    """Process document and answer a question based on its content."""
-    ext = file.name.split(".")[-1].lower()
+    if file is None or not question.strip():
+        return "Please upload a document and ask a question."
 
+    ext = file.name.split(".")[-1].lower()
     if ext == "pdf":
-        context = extract_text_from_pdf(file.name)
+        context = extract_text_from_pdf(file)
     elif ext == "docx":
-        context = extract_text_from_docx(file.name)
+        context = extract_text_from_docx(file)
     elif ext == "pptx":
-        context = extract_text_from_pptx(file.name)
+        context = extract_text_from_pptx(file)
    elif ext == "xlsx":
-        context = extract_text_from_xlsx(file.name)
+        context = extract_text_from_xlsx(file)
     else:
         return "Unsupported file format."
 
     if not context.strip():
         return "No text found in the document."
-
-    # Generate answer using QA pipeline correctly
+
     try:
         result = qa_model({"question": question, "context": context})
         return result["answer"]
     except Exception as e:
         return f"Error generating answer: {e}"
 
-def answer_question_from_image(image, question):
-    """Process an image, extract text, and answer a question."""
+# === QA for Images using EasyOCR and QA model ===
+def answer_question_from_image_text(image, question):
     img_text = extract_text_from_image(image)
     if not img_text.strip():
         return "No readable text found in the image."
@@ -339,27 +304,43 @@ def answer_question_from_image(image, question):
     except Exception as e:
         return f"Error generating answer: {e}"
 
-# ========== Gradio Interfaces ==========
+# === QA for Images using ViLT (Visual QA Model) ===
+def answer_question_from_image_visual(image, question):
+    if image is None or not question.strip():
+        return "Please upload an image and ask a question."
+    inputs = vqa_processor(image, question, return_tensors="pt")
+    with torch.no_grad():
+        outputs = vqa_model(**inputs)
+    predicted_id = outputs.logits.argmax(-1).item()
+    return vqa_model.config.id2label[predicted_id]
+
+# === Gradio Interfaces ===
 with gr.Blocks() as doc_interface:
     gr.Markdown("## 📄 Document Question Answering")
     file_input = gr.File(label="Upload DOCX, PPTX, XLSX, or PDF")
-    question_input = gr.Textbox(label="Ask a question")
+    question_input = gr.Textbox(label="Ask a Question")
     answer_output = gr.Textbox(label="Answer")
     file_submit = gr.Button("Get Answer")
-    file_submit.click(answer_question_from_doc, inputs=[file_input, question_input], outputs=answer_output)
-
-with gr.Blocks() as img_interface:
-    gr.Markdown("## 🖼️ Image Question Answering")
-    image_input = gr.Image(label="Upload an Image")
-    img_question_input = gr.Textbox(label="Ask a question")
-    img_answer_output = gr.Textbox(label="Answer")
-    image_submit = gr.Button("Get Answer")
-    image_submit.click(answer_question_from_image, inputs=[image_input, img_question_input], outputs=img_answer_output)
-
-# ========== Mount Gradio App ==========
-demo = gr.TabbedInterface([doc_interface, img_interface], ["Document QA", "Image QA"])
+    file_submit.click(fn=answer_question_from_doc, inputs=[file_input, question_input], outputs=answer_output)
+
+with gr.Blocks() as image_interface:
+    gr.Markdown("## 🖼️ Image Question Answering (OCR + VQA)")
+    with gr.Tabs():
+        with gr.TabItem("OCR-based Image QA"):
+            image_input = gr.Image(label="Upload Image")
+            img_question_input = gr.Textbox(label="Ask a Question")
+            img_answer_output = gr.Textbox(label="Answer")
+            gr.Button("Get Answer").click(fn=answer_question_from_image_text, inputs=[image_input, img_question_input], outputs=img_answer_output)
+        with gr.TabItem("Visual QA (ViLT)"):
+            image_input_vqa = gr.Image(label="Upload Image")
+            vqa_question_input = gr.Textbox(label="Ask a Question")
+            vqa_answer_output = gr.Textbox(label="Answer")
+            gr.Button("Get Answer").click(fn=answer_question_from_image_visual, inputs=[image_input_vqa, vqa_question_input], outputs=vqa_answer_output)
+
+# === Mount Gradio on FastAPI ===
+demo = gr.TabbedInterface([doc_interface, image_interface], ["Document QA", "Image QA"])
 app = gr.mount_gradio_app(app, demo, path="/")
 
 @app.get("/")
-def home():
+def root():
     return RedirectResponse(url="/")
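
The core change in this commit swaps TinyLlama free-form generation for extractive QA via the transformers question-answering pipeline. A minimal standalone sketch of that call, using the same model name as the commit; the question and context strings here are illustrative placeholders, not part of the app:

# sketch_qa.py — minimal sketch of the extractive QA call the new app.py relies on.
# Model name matches the commit; question/context are made-up placeholders.
from transformers import pipeline

qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")

result = qa_model({
    "question": "What library extracts the PDF text?",
    "context": "The app extracts PDF text with PyMuPDF before running extractive QA.",
})
print(result["answer"])  # a span selected from the context
print(result["score"])   # model confidence for that span

Unlike the removed TinyLlama path, the pipeline returns a span copied out of the supplied context together with a confidence score, rather than free-form generated text.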
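
To try the updated app locally, one way to serve the mounted FastAPI/Gradio instance (a sketch, assuming the merged file is saved as app.py and uvicorn is installed; host and port are arbitrary choices):

# run.py — a sketch; assumes the merged file is saved as app.py next to this
# script and that uvicorn is installed.
import uvicorn

if __name__ == "__main__":
    # "app:app" = module app.py, variable `app` (the FastAPI instance
    # returned by gr.mount_gradio_app).
    uvicorn.run("app:app", host="0.0.0.0", port=7860)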