Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -201,4 +201,165 @@ app = gr.mount_gradio_app(app, demo, path="/")
|
|
201 |
@app.get("/")
|
202 |
def home():
|
203 |
return RedirectResponse(url="/")
|
204 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
@app.get("/")
|
202 |
def home():
|
203 |
return RedirectResponse(url="/")
|
204 |
+
"""
|
205 |
+
from fastapi import FastAPI
|
206 |
+
from fastapi.responses import RedirectResponse
|
207 |
+
import gradio as gr
|
208 |
+
from transformers import pipeline, ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM
|
209 |
+
from PIL import Image
|
210 |
+
import torch
|
211 |
+
import fitz # PyMuPDF for PDF
|
212 |
+
import easyocr # OCR for images
|
213 |
+
import openpyxl # XLSX processing
|
214 |
+
import pptx # PPTX processing
|
215 |
+
import docx # DOCX processing
|
216 |
+
|
217 |
+
# Initialize FastAPI app
# Top-level ASGI application; the Gradio UI is mounted onto it further down
# via gr.mount_gradio_app.
app = FastAPI()
|
219 |
+
|
220 |
+
# ========== Document QA Setup ==========
# Generative chat model used to answer questions about extracted document
# text. Loaded eagerly at import time; no device placement is done here,
# so it runs wherever transformers defaults to (CPU unless configured).
doc_tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
doc_model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
|
223 |
+
|
224 |
+
def read_pdf(file):
    """Read an uploaded PDF file object and return its full text.

    The file-like object is consumed (``file.read()``) and parsed in memory
    with PyMuPDF; the text of every page is concatenated in order.
    """
    document = fitz.open(stream=file.read(), filetype="pdf")
    return "".join(page.get_text() for page in document)
|
230 |
+
|
231 |
+
def answer_question_from_doc(file, question):
    """Answer a question about an uploaded PDF using the TinyLlama model.

    NOTE(review): this definition is shadowed by a later function of the
    same name in this file, so it is effectively dead code at runtime.

    Args:
        file: uploaded file-like object, read as a PDF byte stream.
        question: the user's question about the document.

    Returns:
        str: the generated answer, or a prompt-the-user message when either
        input is missing.
    """
    if file is None or not question.strip():
        return "Please upload a document and ask a question."
    text = read_pdf(file)
    # The whole document text is stuffed into the prompt; the tokenizer call
    # below truncates it to a 2048-token window, so long documents are cut.
    prompt = f"Context: {text}\nQuestion: {question}\nAnswer:"
    inputs = doc_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    with torch.no_grad():
        outputs = doc_model.generate(**inputs, max_new_tokens=100)
    answer = doc_tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the text after the final "Answer:" marker.
    return answer.split("Answer:")[-1].strip()
|
241 |
+
|
242 |
+
# ========== Image QA Setup ==========
# ViLT visual-question-answering model: classification over a fixed set of
# answer labels (see id2label on the model config). Loaded at import time.
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
245 |
+
|
246 |
+
def answer_question_from_image(image, question):
    """Answer a visual question about an image with the ViLT VQA model.

    NOTE(review): this definition is shadowed by a later function of the
    same name in this file, so it is effectively dead code at runtime.

    Args:
        image: image input accepted by the ViLT processor.
        question: the user's question about the image.

    Returns:
        str: the predicted answer label, or a prompt-the-user message when
        either input is missing.
    """
    if image is None or not question.strip():
        return "Please upload an image and ask a question."
    inputs = vqa_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        outputs = vqa_model(**inputs)
    # VQA here is classification: pick the highest-scoring answer label.
    predicted_id = outputs.logits.argmax(-1).item()
    return vqa_model.config.id2label[predicted_id]
|
254 |
+
|
255 |
+
# ========== Text Extraction Functions ==========
# Shared OCR reader, created once at import time and reused by
# extract_text_from_image below.
reader = easyocr.Reader(['en', 'fr'])  # OCR for English & French
|
257 |
+
|
258 |
+
def extract_text_from_pdf(pdf_file):
    """Extract plain text from every page of a PDF file.

    Args:
        pdf_file: filesystem path to the PDF.

    Returns:
        str: page texts joined with newlines, or a human-readable error
        message if the file cannot be opened or read.
    """
    pages = []
    try:
        with fitz.open(pdf_file) as document:
            pages.extend(page.get_text("text") for page in document)
    except Exception as e:
        return f"Error reading PDF: {e}"
    return "\n".join(pages)
|
268 |
+
|
269 |
+
def extract_text_from_docx(docx_file):
    """Extract text from a DOCX file, one line per non-blank paragraph.

    Args:
        docx_file: filesystem path to the DOCX document.

    Returns:
        str: the joined paragraph text, or a human-readable error message
        on failure.
    """
    # Consistency/robustness fix: the PDF, PPTX, and XLSX extractors all
    # return an "Error reading ..." string on failure; the original DOCX
    # version let a corrupt/unreadable file raise instead.
    try:
        doc = docx.Document(docx_file)
    except Exception as e:
        return f"Error reading DOCX: {e}"
    return "\n".join([p.text for p in doc.paragraphs if p.text.strip()])
|
273 |
+
|
274 |
+
def extract_text_from_pptx(pptx_file):
    """Extract text from every text-bearing shape on every PPTX slide.

    Args:
        pptx_file: filesystem path to the PPTX presentation.

    Returns:
        str: shape texts joined with newlines (slide order), or a
        human-readable error message on failure.
    """
    collected = []
    try:
        deck = pptx.Presentation(pptx_file)
        for slide in deck.slides:
            # Not every shape carries text (pictures, connectors, ...);
            # only take those that expose a .text attribute.
            collected.extend(
                shape.text for shape in slide.shapes if hasattr(shape, "text")
            )
    except Exception as e:
        return f"Error reading PPTX: {e}"
    return "\n".join(collected)
|
286 |
+
|
287 |
+
def extract_text_from_xlsx(xlsx_file):
    """Extract text from an XLSX workbook, one line per row.

    Iterates every sheet in the workbook and joins each row's cell values
    with spaces.

    Args:
        xlsx_file: filesystem path to the XLSX workbook.

    Returns:
        str: the joined cell text, or a human-readable error message on
        failure.
    """
    text = []
    try:
        wb = openpyxl.load_workbook(xlsx_file)
        for sheet in wb.sheetnames:
            ws = wb[sheet]
            for row in ws.iter_rows(values_only=True):
                # BUG FIX: the original filtered with `if cell`, which
                # silently dropped legitimate falsy values such as 0 and
                # False; only skip truly empty (None) cells.
                text.append(" ".join(str(cell) for cell in row if cell is not None))
    except Exception as e:
        return f"Error reading XLSX: {e}"
    return "\n".join(text)
|
299 |
+
|
300 |
+
def extract_text_from_image(image_path):
    """Run EasyOCR over an image and return all detected text.

    Uses the module-level `reader`; with detail=0 EasyOCR returns bare
    strings, which are joined into a single space-separated string.
    """
    fragments = reader.readtext(image_path, detail=0)
    return " ".join(fragments)
|
304 |
+
|
305 |
+
# ========== Main Processing Functions ==========
def answer_question_from_doc(file, question):
    """Process a document and answer a question based on its content.

    Extracts text from the file (dispatching on extension: PDF/DOCX/PPTX/
    XLSX), then prompts the TinyLlama chat model loaded at module level to
    answer from that text.

    Args:
        file: uploaded file object exposing a ``.name`` path (Gradio File).
        question: the user's question.

    Returns:
        str: the model's answer, or a human-readable error/status message.
    """
    ext = file.name.split(".")[-1].lower()

    # Each extractor returns either the document text or an error string.
    if ext == "pdf":
        context = extract_text_from_pdf(file.name)
    elif ext == "docx":
        context = extract_text_from_docx(file.name)
    elif ext == "pptx":
        context = extract_text_from_pptx(file.name)
    elif ext == "xlsx":
        context = extract_text_from_xlsx(file.name)
    else:
        return "Unsupported file format."

    if not context.strip():
        return "No text found in the document."

    # BUG FIX: the original called `qa_model`, which is never defined
    # anywhere in this file, so every request failed with a NameError that
    # the broad except turned into "Error generating answer: name
    # 'qa_model' is not defined". Use the TinyLlama tokenizer/model that
    # ARE loaded at module level (the same generate-and-split approach the
    # earlier PDF-only QA function uses).
    try:
        prompt = f"Context: {context}\nQuestion: {question}\nAnswer:"
        inputs = doc_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        with torch.no_grad():
            outputs = doc_model.generate(**inputs, max_new_tokens=100)
        answer = doc_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return answer.split("Answer:")[-1].strip()
    except Exception as e:
        return f"Error generating answer: {e}"
|
330 |
+
|
331 |
+
def answer_question_from_image(image, question):
    """OCR an image and answer a question about the extracted text.

    Args:
        image: image path/array accepted by EasyOCR's ``readtext``.
        question: the user's question.

    Returns:
        str: the model's answer, or a human-readable error/status message.
    """
    img_text = extract_text_from_image(image)
    if not img_text.strip():
        return "No readable text found in the image."
    # BUG FIX: the original called `qa_model`, which is never defined
    # anywhere in this file (guaranteed NameError, masked by the broad
    # except). Answer with the module-level TinyLlama tokenizer/model
    # instead, mirroring the document-QA path.
    try:
        prompt = f"Context: {img_text}\nQuestion: {question}\nAnswer:"
        inputs = doc_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        with torch.no_grad():
            outputs = doc_model.generate(**inputs, max_new_tokens=100)
        answer = doc_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return answer.split("Answer:")[-1].strip()
    except Exception as e:
        return f"Error generating answer: {e}"
|
341 |
+
|
342 |
+
# ========== Gradio Interfaces ==========
# Document QA tab: upload a file, type a question, click the button to run
# the module-level answer_question_from_doc handler.
with gr.Blocks() as doc_interface:
    gr.Markdown("## 📄 Document Question Answering")
    file_input = gr.File(label="Upload DOCX, PPTX, XLSX, or PDF")
    question_input = gr.Textbox(label="Ask a question")
    answer_output = gr.Textbox(label="Answer")
    file_submit = gr.Button("Get Answer")
    file_submit.click(answer_question_from_doc, inputs=[file_input, question_input], outputs=answer_output)
|
350 |
+
|
351 |
+
# Image QA tab: upload an image, type a question, click the button to run
# the module-level answer_question_from_image handler.
with gr.Blocks() as img_interface:
    gr.Markdown("## 🖼️ Image Question Answering")
    image_input = gr.Image(label="Upload an Image")
    img_question_input = gr.Textbox(label="Ask a question")
    img_answer_output = gr.Textbox(label="Answer")
    image_submit = gr.Button("Get Answer")
    image_submit.click(answer_question_from_image, inputs=[image_input, img_question_input], outputs=img_answer_output)
|
358 |
+
|
359 |
+
# ========== Mount Gradio App ==========
# Combine both tabs into one UI and mount it at the site root.
demo = gr.TabbedInterface([doc_interface, img_interface], ["Document QA", "Image QA"])
app = gr.mount_gradio_app(app, demo, path="/")

# NOTE(review): the Gradio app is already mounted at "/", so this route is
# shadowed in practice; and if it were ever reached, redirecting "/" to "/"
# would be an infinite redirect loop. Consider mounting Gradio at a subpath
# (e.g. "/ui") and redirecting there instead — confirm intended routing.
@app.get("/")
def home():
    return RedirectResponse(url="/")
|