ikraamkb committed on
Commit
5976e32
·
verified ·
1 Parent(s): 6218efd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -1
app.py CHANGED
@@ -201,4 +201,165 @@ app = gr.mount_gradio_app(app, demo, path="/")
201
  @app.get("/")
202
  def home():
203
  return RedirectResponse(url="/")
204
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  @app.get("/")
202
  def home():
203
  return RedirectResponse(url="/")
204
+ """
205
+ from fastapi import FastAPI
206
+ from fastapi.responses import RedirectResponse
207
+ import gradio as gr
208
+ from transformers import pipeline, ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM
209
+ from PIL import Image
210
+ import torch
211
+ import fitz # PyMuPDF for PDF
212
+ import easyocr # OCR for images
213
+ import openpyxl # XLSX processing
214
+ import pptx # PPTX processing
215
+ import docx # DOCX processing
216
+
217
# Checkpoint id shared by the document-QA tokenizer and model, kept in one
# place so the two always load matching weights and vocabulary.
DOC_QA_CHECKPOINT = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Initialize FastAPI app
app = FastAPI()

# ========== Document QA Setup ==========
# Both objects are loaded once, at import time, and reused by every request.
doc_tokenizer = AutoTokenizer.from_pretrained(DOC_QA_CHECKPOINT)
doc_model = AutoModelForCausalLM.from_pretrained(DOC_QA_CHECKPOINT)
223
+
224
def read_pdf(file):
    """Return the concatenated text of every page of an in-memory PDF.

    Args:
        file: Binary file-like object; the bytes returned by its ``read()``
            are handed to PyMuPDF as a PDF stream.

    Returns:
        The text of all pages joined in page order, with no separator
        inserted between pages.
    """
    # The context manager releases the document even when a page fails to
    # parse; join() avoids rebuilding the string once per page.
    with fitz.open(stream=file.read(), filetype="pdf") as doc:
        return "".join(page.get_text() for page in doc)
230
+
231
def answer_question_from_doc(file, question):
    """Answer a question about an uploaded PDF with the causal language model.

    NOTE: a second function with this same name is defined further down in
    this file; after import, that later definition is the one bound to the
    name.

    Args:
        file: Binary file-like object holding a PDF, or None.
        question: The user's question.

    Returns:
        The generated answer text, or a message asking for both inputs when
        the file is missing or the question is missing/blank.
    """
    if file is None or not question or not question.strip():
        return "Please upload a document and ask a question."

    max_prompt_tokens = 2048

    # Tokenize the three prompt pieces separately. Truncating the assembled
    # prompt from the right would cut off the question and the "Answer:" cue
    # whenever the document is long, so only the context is shortened.
    prefix_ids = doc_tokenizer("Context: ")["input_ids"]
    suffix_ids = doc_tokenizer(
        f"\nQuestion: {question}\nAnswer:", add_special_tokens=False
    )["input_ids"]
    room = max(max_prompt_tokens - len(prefix_ids) - len(suffix_ids), 0)
    context_ids = doc_tokenizer(read_pdf(file), add_special_tokens=False)["input_ids"][:room]

    input_ids = torch.tensor([prefix_ids + context_ids + suffix_ids])
    with torch.no_grad():
        outputs = doc_model.generate(
            input_ids=input_ids,
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=100,
        )

    # generate() returns prompt + continuation. Decoding only the continuation
    # means an "Answer:" occurring inside the document or the question cannot
    # confuse the extraction.
    new_tokens = outputs[0][input_ids.shape[1]:]
    return doc_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
241
+
242
# ========== Image QA Setup ==========
# One checkpoint id for both the processor and the model so they stay paired.
VQA_CHECKPOINT = "dandelin/vilt-b32-finetuned-vqa"

vqa_processor = ViltProcessor.from_pretrained(VQA_CHECKPOINT)
vqa_model = ViltForQuestionAnswering.from_pretrained(VQA_CHECKPOINT)
245
+
246
def answer_question_from_image(image, question):
    """Answer a question about an image with the visual-QA model.

    NOTE: a second function with this same name is defined further down in
    this file; after import, that later definition is the one bound to the
    name.

    Args:
        image: The image to inspect, or None when nothing was uploaded.
        question: The user's question about the image.

    Returns:
        The label whose logit is highest, or a message asking for both
        inputs when the image is missing or the question is blank.
    """
    nothing_to_answer = image is None or not question.strip()
    if nothing_to_answer:
        return "Please upload an image and ask a question."

    encoded = vqa_processor(image, question, return_tensors="pt")
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        logits = vqa_model(**encoded).logits

    best_label_id = logits.argmax(-1).item()
    return vqa_model.config.id2label[best_label_id]
254
+
255
# ========== Text Extraction Functions ==========
# Languages the shared OCR reader is configured for: English and French.
OCR_LANGUAGES = ['en', 'fr']
reader = easyocr.Reader(OCR_LANGUAGES)
257
+
258
def extract_text_from_pdf(pdf_file):
    """Extract text from a PDF file.

    Args:
        pdf_file: Value passed straight to ``fitz.open`` to open the PDF.

    Returns:
        The plain text of every page joined with newlines, or a string of
        the form ``"Error reading PDF: ..."`` when opening or reading fails.
    """
    try:
        with fitz.open(pdf_file) as pdf:
            page_texts = [page.get_text("text") for page in pdf]
    except Exception as e:
        # Failures are reported as text rather than raised.
        return f"Error reading PDF: {e}"
    return "\n".join(page_texts)
268
+
269
def extract_text_from_docx(docx_file):
    """Extract text from a DOCX file.

    Args:
        docx_file: Value passed straight to ``docx.Document`` to open the
            document.

    Returns:
        The text of every non-blank paragraph joined with newlines, or a
        string of the form ``"Error reading DOCX: ..."`` when the document
        cannot be opened or read.
    """
    # Report failures as text, matching the PDF/PPTX/XLSX extractors in this
    # file, so a corrupt upload does not surface as an unhandled exception.
    try:
        doc = docx.Document(docx_file)
        return "\n".join(p.text for p in doc.paragraphs if p.text.strip())
    except Exception as e:
        return f"Error reading DOCX: {e}"
273
+
274
def extract_text_from_pptx(pptx_file):
    """Extract text from a PPTX file.

    Args:
        pptx_file: Value passed straight to ``pptx.Presentation`` to open
            the deck.

    Returns:
        The ``text`` of every shape that has one, in slide order, joined
        with newlines; or a string of the form ``"Error reading PPTX: ..."``
        when opening or reading fails.
    """
    try:
        deck = pptx.Presentation(pptx_file)
        shape_texts = [
            shape.text
            for slide in deck.slides
            for shape in slide.shapes
            # Not every shape exposes text, so those without it are skipped.
            if hasattr(shape, "text")
        ]
    except Exception as e:
        return f"Error reading PPTX: {e}"
    return "\n".join(shape_texts)
286
+
287
def extract_text_from_xlsx(xlsx_file):
    """Extract text from an XLSX file.

    Args:
        xlsx_file: Value passed straight to ``openpyxl.load_workbook`` to
            open the workbook.

    Returns:
        One line per row across all sheets, each line holding the row's
        non-empty cell values separated by single spaces; or a string of the
        form ``"Error reading XLSX: ..."`` when opening or reading fails.
    """
    text = []
    try:
        wb = openpyxl.load_workbook(xlsx_file)
        for sheet in wb.sheetnames:
            ws = wb[sheet]
            for row in ws.iter_rows(values_only=True):
                # Compare against None rather than testing truthiness:
                # a plain `if cell` would also discard real values such as
                # 0, 0.0 and False.
                text.append(" ".join(str(cell) for cell in row if cell is not None))
    except Exception as e:
        return f"Error reading XLSX: {e}"
    return "\n".join(text)
299
+
300
def extract_text_from_image(image_path):
    """Extract text from an image using EasyOCR.

    Args:
        image_path: Image input handed unchanged to ``reader.readtext``.

    Returns:
        All recognized text fragments joined into one space-separated string.
    """
    # detail=0 is requested so the result can be joined directly as strings.
    fragments = reader.readtext(image_path, detail=0)
    return " ".join(fragments)
304
+
305
+ # ========== Main Processing Functions ==========
306
def answer_question_from_doc(file, question):
    """Process document and answer a question based on its content.

    The extractor is chosen from the file extension (pdf, docx, pptx, xlsx).

    Args:
        file: Uploaded file object whose ``name`` attribute is the path of
            the document, or None when nothing was uploaded.
        question: The question to answer.

    Returns:
        The ``"answer"`` field of the QA result, or a human-readable message
        when input is missing, the extension is unsupported, extraction
        fails, no text is found, or answering fails.
    """
    # Validate first: with no upload `file` is None and `file.name` would
    # raise AttributeError before any message could be returned.
    if file is None or not question or not question.strip():
        return "Please upload a document and ask a question."

    ext = file.name.split(".")[-1].lower()

    # Extraction is guarded so a corrupt upload yields a message instead of
    # an unhandled exception.
    try:
        if ext == "pdf":
            context = extract_text_from_pdf(file.name)
        elif ext == "docx":
            context = extract_text_from_docx(file.name)
        elif ext == "pptx":
            context = extract_text_from_pptx(file.name)
        elif ext == "xlsx":
            context = extract_text_from_xlsx(file.name)
        else:
            return "Unsupported file format."
    except Exception as e:
        return f"Error reading document: {e}"

    if not context.strip():
        return "No text found in the document."

    # NOTE(review): `qa_model` is not defined in the visible part of this
    # file; unless it is created elsewhere, the call below raises NameError
    # and the user only ever sees the error message. Confirm.
    try:
        result = qa_model({"question": question, "context": context})
        return result["answer"]
    except Exception as e:
        return f"Error generating answer: {e}"
330
+
331
def answer_question_from_image(image, question):
    """Process an image, extract text, and answer a question.

    Args:
        image: Image input forwarded to ``extract_text_from_image``, or None
            when nothing was uploaded.
        question: The question to answer from the text found in the image.

    Returns:
        The ``"answer"`` field of the QA result, or a human-readable message
        when input is missing, no text is found, or answering fails.
    """
    # Guard before OCR so a missing upload or blank question yields a clear
    # message instead of being passed on to the OCR reader and QA model.
    if image is None or not question or not question.strip():
        return "Please upload an image and ask a question."

    img_text = extract_text_from_image(image)
    if not img_text.strip():
        return "No readable text found in the image."

    # NOTE(review): `qa_model` is not defined in the visible part of this
    # file; unless it is created elsewhere, the call below raises NameError
    # and the user only ever sees the error message. Confirm.
    try:
        result = qa_model({"question": question, "context": img_text})
        return result["answer"]
    except Exception as e:
        return f"Error generating answer: {e}"
341
+
342
# ========== Gradio Interfaces ==========
# Document tab: a file upload plus a question box feeding answer_question_from_doc.
with gr.Blocks() as doc_interface:
    gr.Markdown("## 📄 Document Question Answering")
    file_input = gr.File(label="Upload DOCX, PPTX, XLSX, or PDF")
    question_input = gr.Textbox(label="Ask a question")
    answer_output = gr.Textbox(label="Answer")
    file_submit = gr.Button("Get Answer")
    file_submit.click(
        fn=answer_question_from_doc,
        inputs=[file_input, question_input],
        outputs=answer_output,
    )

# Image tab: an image upload plus a question box feeding answer_question_from_image.
with gr.Blocks() as img_interface:
    gr.Markdown("## 🖼️ Image Question Answering")
    image_input = gr.Image(label="Upload an Image")
    img_question_input = gr.Textbox(label="Ask a question")
    img_answer_output = gr.Textbox(label="Answer")
    image_submit = gr.Button("Get Answer")
    image_submit.click(
        fn=answer_question_from_image,
        inputs=[image_input, img_question_input],
        outputs=img_answer_output,
    )

# ========== Mount Gradio App ==========
# Both tabs are combined into one UI, which is then attached to the FastAPI
# app at the root path.
demo = gr.TabbedInterface(
    interface_list=[doc_interface, img_interface],
    tab_names=["Document QA", "Image QA"],
)
app = gr.mount_gradio_app(app, demo, path="/")
362
+
363
@app.get("/")
def home():
    """Respond to GET "/" with a redirect to "/"."""
    # NOTE(review): the redirect target is the same path this route serves,
    # so if this handler ever runs the client is sent back here again.
    # NOTE(review): the Gradio UI is mounted at "/" before this route is
    # registered, so this handler is presumably never reached — confirm and
    # consider removing it or mounting the UI on a sub-path.
    return RedirectResponse(url="/")