# qtAnswering / app.py
# Author: ikraamkb — commit 8e24199 (verified), "Update app.py", 5.37 kB
import os

import docx
import easyocr
import gradio as gr
import numpy as np
import openpyxl
import pdfplumber
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import RedirectResponse
from PIL import Image
from pptx import Presentation
from torchvision import transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from transformers import pipeline
# Initialize FastAPI
# NOTE: ``app`` is rebound further down by gr.mount_gradio_app.
app = FastAPI()
# Load AI Model for Question Answering (Proper Extractive QA Model)
# Used by both answer_question_* handlers; called with {"question", "context"}.
qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")
# Initialize Translator for Multilingual Support
# Used by translate_text(); loaded eagerly at import time.
translator = pipeline("translation", model="facebook/m2m100_418M")
# Load Pretrained Object Detection Model (if needed)
# NOTE(review): ``model`` (and ``transform`` below) are not referenced anywhere
# else in the code visible here, yet the Faster R-CNN weights are downloaded
# and loaded at import time. Remove if nothing outside this file uses them.
# NOTE(review): newer torchvision releases may warn that ``pretrained=True``
# is deprecated in favour of ``weights=...`` — confirm against pinned version.
model = fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()
# Initialize OCR Model
# Built eagerly at import time (this is not a lazy load), English only.
# gpu=True requests GPU execution — TODO confirm the deployment target has one.
reader = easyocr.Reader(["en"], gpu=True)
# Image Transformations
transform = transforms.Compose([
    transforms.ToTensor()
])
# Document extensions accepted by the document-QA tab
# (lowercase, without the leading dot).
ALLOWED_EXTENSIONS = {
    "pdf",
    "docx",
    "pptx",
    "xlsx",
}
def validate_file_type(file, allowed_extensions=None):
    """Check that an uploaded file has a supported document extension.

    Args:
        file: The uploaded file. Either a filesystem path (``str``) or an
            object exposing its path through a ``.name`` attribute (e.g. a
            temporary-file wrapper); both forms are accepted because Gradio
            versions differ in which one they pass.
        allowed_extensions: Optional collection of lowercase extensions
            without the leading dot. Defaults to ``ALLOWED_EXTENSIONS``.

    Returns:
        None when the extension is allowed, otherwise a human-readable error
        message string.
    """
    if file is None:
        return "No file uploaded."
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_EXTENSIONS
    path = str(getattr(file, "name", file))
    # splitext only looks at the final path component, so "README" yields ""
    # instead of being misreported as a "readme" extension.
    ext = os.path.splitext(path)[1][1:].lower()
    if ext not in allowed_extensions:
        return f"Unsupported file format: {ext or '(no extension)'}"
    return None
# Shorten QA context to a fixed number of whitespace-separated words.
def truncate_text(text, max_tokens=450):
    """Return the first ``max_tokens`` whitespace-separated words of ``text``.

    "Tokens" here are the words produced by ``str.split()``, not model
    tokenizer tokens. Any run of whitespace (spaces, tabs, newlines) in the
    kept portion is collapsed to a single space.
    """
    kept = text.split()[:max_tokens]
    return " ".join(kept)
# Text Extraction Functions
def extract_text_from_pdf(pdf_file):
    """Extract the text layer of every page of a PDF.

    Args:
        pdf_file: Anything ``pdfplumber.open`` accepts (path or binary file).

    Returns:
        Page texts, each followed by a newline, with the whole result
        stripped; "No text found." if no page produced text; or an
        "Error reading PDF: ..." message if opening/parsing raised.
    """
    try:
        with pdfplumber.open(pdf_file) as document:
            page_texts = [page.extract_text() for page in document.pages]
    except Exception as exc:
        return f"Error reading PDF: {str(exc)}"
    # Pages without a text layer come back empty/None and are skipped.
    combined = "".join(f"{chunk}\n" for chunk in page_texts if chunk)
    if not combined:
        return "No text found."
    return combined.strip()
def extract_text_from_docx(docx_file):
    """Return the body paragraphs of a .docx document, one per line.

    Args:
        docx_file: Anything ``docx.Document`` accepts (path or file object).

    Returns:
        Paragraph texts joined with newlines (empty string if the document
        has no paragraphs), or an "Error reading DOCX: ..." message if
        loading raised.
    """
    try:
        document = docx.Document(docx_file)
        paragraph_texts = (paragraph.text for paragraph in document.paragraphs)
        return "\n".join(paragraph_texts)
    except Exception as exc:
        return f"Error reading DOCX: {str(exc)}"
def extract_text_from_pptx(pptx_file):
    """Collect the text of every text-bearing shape on every slide.

    Args:
        pptx_file: Anything ``pptx.Presentation`` accepts (path or file).

    Returns:
        Shape texts joined with newlines in slide/shape order; "No text
        found." if no shape exposes a ``text`` attribute; or an
        "Error reading PPTX: ..." message if loading raised.
    """
    try:
        deck = Presentation(pptx_file)
        chunks = [
            shape.text
            for slide in deck.slides
            for shape in slide.shapes
            if hasattr(shape, "text")
        ]
        if not chunks:
            return "No text found."
        return "\n".join(chunks)
    except Exception as exc:
        return f"Error reading PPTX: {str(exc)}"
def extract_text_from_excel(excel_file):
    """Extract cell values from every worksheet of an .xlsx workbook.

    Each row that has at least one non-empty cell becomes one line, with its
    non-empty cell values converted via ``str`` and separated by single
    spaces. Empty cells are skipped rather than rendered as the literal
    text "None", which would otherwise pollute the QA context.

    Args:
        excel_file: Anything ``openpyxl.load_workbook`` accepts (path or
            binary file object).

    Returns:
        The extracted lines joined with newlines; "No text found." if no
        row had any value; or an "Error reading Excel: ..." message if
        loading or reading raised.
    """
    workbook = None
    try:
        workbook = openpyxl.load_workbook(excel_file, read_only=True)
        lines = []
        for sheet in workbook.worksheets:
            for row in sheet.iter_rows(values_only=True):
                cells = [str(value) for value in row if value is not None]
                if cells:
                    lines.append(" ".join(cells))
        return "\n".join(lines) if lines else "No text found."
    except Exception as e:
        return f"Error reading Excel: {str(e)}"
    finally:
        # Read-only workbooks keep the source file open until closed
        # explicitly; release it on both the success and error paths.
        if workbook is not None:
            workbook.close()
def extract_text_from_image(image_file):
    """Run OCR on an image and return the recognised text.

    Args:
        image_file: The image to read. Accepts a NumPy array (what
            ``gr.Image`` passes with its default ``type="numpy"``), a
            ``PIL.Image.Image``, or anything ``PIL.Image.open`` accepts
            (path or binary file object).

    Returns:
        Recognised text fragments joined with spaces; "No text found." if
        OCR returned nothing; "No meaningful content detected in the image."
        if the image is nearly uniform; "No image provided." for ``None``.
    """
    if image_file is None:
        return "No image provided."
    # Image.open only understands paths/file objects, so arrays and
    # already-decoded PIL images must be handled separately.
    if isinstance(image_file, np.ndarray):
        image = Image.fromarray(image_file)
    elif isinstance(image_file, Image.Image):
        image = image_file
    else:
        image = Image.open(image_file)
    pixels = np.array(image.convert("RGB"))
    # Very low pixel standard deviation means an (almost) blank image;
    # skip the OCR pass entirely.
    if pixels.std() < 10:
        return "No meaningful content detected in the image."
    result = reader.readtext(pixels)
    # readtext yields (bbox, text, confidence) entries; keep only the text.
    return " ".join(res[1] for res in result) if result else "No text found."
def translate_text(text, target_lang="en", src_lang="en"):
    """Translate ``text`` into ``target_lang`` with the M2M100 pipeline.

    M2M100 requires an explicit source-language code. "auto" is not a valid
    code, so the previous ``src_lang="auto"`` call was rejected by the
    tokenizer on every input. No language detector is available in this
    file, so the caller must supply ``src_lang``. It defaults to English, in
    which case (source == target) the text is returned unchanged and the
    model is not invoked.

    Args:
        text: Text to translate.
        target_lang: M2M100 language code of the desired output (e.g. "en").
        src_lang: M2M100 language code of ``text``. Defaults to "en".

    Returns:
        The translated text, or ``text`` unchanged when it is empty or
        already in ``target_lang``.
    """
    if not text or src_lang == target_lang:
        return text
    result = translator(text, src_lang=src_lang, tgt_lang=target_lang)
    return result[0]["translation_text"]
# Function to answer questions based on document content
def answer_question_from_document(file, question):
    """Answer ``question`` using the text of an uploaded document.

    Args:
        file: The uploaded document from the Gradio ``File`` input. Either a
            filesystem path string or an object exposing the path as ``.name``.
        question: The user's question.

    Returns:
        The extractive answer span, or a human-readable message explaining
        why no answer was produced (no file, empty question, unsupported
        format, unreadable or empty document).
    """
    if file is None:
        return "Please upload a document."
    if not question or not question.strip():
        return "Please enter a question."
    validation_error = validate_file_type(file)
    if validation_error:
        return validation_error
    # Gradio may pass a plain path string or a wrapper with ``.name``; every
    # extractor below accepts a path.
    file_path = str(getattr(file, "name", file))
    file_ext = os.path.splitext(file_path)[1][1:].lower()
    extractors = {
        "pdf": extract_text_from_pdf,
        "docx": extract_text_from_docx,
        "pptx": extract_text_from_pptx,
        "xlsx": extract_text_from_excel,
    }
    extractor = extractors.get(file_ext)
    if extractor is None:
        return "Unsupported file format!"
    text = extractor(file_path)
    # The extractors report failures in-band as strings. Don't hand those
    # messages to the QA model as if they were document content.
    if not text or not text.strip() or text == "No text found.":
        return "No text extracted from the document."
    if text.startswith((
        "Error reading PDF:",
        "Error reading DOCX:",
        "Error reading PPTX:",
        "Error reading Excel:",
    )):
        return text
    text = translate_text(text)  # Translate non-English text to English
    truncated_text = truncate_text(text)
    response = qa_pipeline({"question": question, "context": truncated_text})
    return response["answer"]
def answer_question_from_image(image, question):
    """Answer ``question`` using text recognised in an uploaded image.

    Args:
        image: The image from the Gradio ``Image`` input, passed through to
            ``extract_text_from_image`` unchanged.
        question: The user's question.

    Returns:
        The extractive answer span, or a human-readable message explaining
        why no answer was produced (no image, empty question, no readable
        text in the image).
    """
    if image is None:
        return "Please upload an image."
    if not question or not question.strip():
        return "Please enter a question."
    image_text = extract_text_from_image(image)
    if not image_text or not image_text.strip():
        return "No meaningful content detected in the image."
    # extract_text_from_image reports "nothing found" in-band as a message.
    # Surface it directly instead of asking the QA model about it.
    if image_text in ("No text found.", "No meaningful content detected in the image."):
        return image_text
    image_text = translate_text(image_text)  # Translate non-English text to English
    truncated_text = truncate_text(image_text)
    response = qa_pipeline({"question": question, "context": truncated_text})
    return response["answer"]
# Gradio UI for Document & Image QA
doc_interface = gr.Interface(
    fn=answer_question_from_document,
    inputs=[gr.File(label="Upload Document"), gr.Textbox(label="Ask a Question")],
    outputs="text",
    title="AI Document Question Answering"
)
# type="filepath" makes Gradio pass the image as a path string, which
# PIL.Image.open (used by extract_text_from_image) can read. The default
# type="numpy" passes an ndarray, which Image.open cannot open.
img_interface = gr.Interface(
    fn=answer_question_from_image,
    inputs=[gr.Image(type="filepath", label="Upload Image"), gr.Textbox(label="Ask a Question")],
    outputs="text",
    title="AI Image Question Answering"
)
# Mount Gradio Interfaces
# Combine both interfaces into one tabbed UI; tab titles are given in the
# same order as the interfaces.
demo = gr.TabbedInterface([doc_interface, img_interface], ["Document QA", "Image QA"])
# Serve the Gradio UI from the FastAPI app at the root path. ``app`` is
# rebound to the value mount_gradio_app returns.
app = gr.mount_gradio_app(app, demo, path="/")
# NOTE(review): this handler redirects "/" to "/" itself. If it were ever
# reached, clients would be sent into a redirect loop. It is registered after
# the Gradio app was mounted at "/" above, so the mount probably matches first
# and this route is likely never hit. Confirm, then either remove it or
# redirect to a different path.
@app.get("/")
def home():
    """Respond to GET "/" with a redirect back to "/" (see NOTE above)."""
    return RedirectResponse(url="/")