import gradio as gr
import numpy as np
import fitz  # PyMuPDF
import tika
import torch
from fastapi import FastAPI
from transformers import pipeline
from PIL import Image
from io import BytesIO
from starlette.responses import RedirectResponse
from tika import parser
from openpyxl import load_workbook
# Initialize Tika for DOCX & PPTX parsing (ensure Java is installed)
tika.initVM()

# Initialize FastAPI
app = FastAPI()

# Load models
device = "cuda" if torch.cuda.is_available() else "cpu"
qa_pipeline = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device=device)
image_captioning_pipeline = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")

ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}
# ✅ Function to Validate File Type
def validate_file_type(file):
    if hasattr(file, "name"):
        ext = file.name.split(".")[-1].lower()
        if ext not in ALLOWED_EXTENSIONS:
            return f"❌ Unsupported file format: {ext}"
        return None
    return "❌ Invalid file format!"
# ✅ Extract Text from PDF
def extract_text_from_pdf(file):
    with fitz.open(file.name) as doc:
        return "\n".join([page.get_text() for page in doc])
# ✅ Extract Text from DOCX & PPTX using Tika
def extract_text_with_tika(file):
    return parser.from_file(file.name)["content"]
# ✅ Extract Text from Excel
def extract_text_from_excel(file):
    wb = load_workbook(file.name, data_only=True)
    text = []
    for sheet in wb.worksheets:
        for row in sheet.iter_rows(values_only=True):
            # Keep numeric zeros; only skip empty cells
            text.append(" ".join(str(cell) for cell in row if cell is not None))
    return "\n".join(text)
# ✅ Truncate Long Text for Model
def truncate_text(text, max_length=2048):
    return text[:max_length] if len(text) > max_length else text
# ✅ Answer Questions from Image or Document
def answer_question(file, image, question: str):
    # Image branch: caption the uploaded image, then answer from the caption
    if image is not None:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        caption = image_captioning_pipeline(image)[0]["generated_text"]
        response = qa_pipeline(f"Question: {question}\nContext: {caption}")
        return response[0]["generated_text"]

    if file is None:
        return "❌ Please upload a document or an image."
    validation_error = validate_file_type(file)
    if validation_error:
        return validation_error

    file_ext = file.name.split(".")[-1].lower()

    # Extract Text from Supported Documents
    if file_ext == "pdf":
        text = extract_text_from_pdf(file)
    elif file_ext in ["docx", "pptx"]:
        text = extract_text_with_tika(file)
    elif file_ext == "xlsx":
        text = extract_text_from_excel(file)
    else:
        return "❌ Unsupported file format!"

    if not text:
        return "⚠️ No text extracted from the document."

    truncated_text = truncate_text(text)
    response = qa_pipeline(f"Question: {question}\nContext: {truncated_text}")
    return response[0]["generated_text"]
# ✅ Gradio Interface (Separate File & Image Inputs)
with gr.Blocks() as demo:
    gr.Markdown("## 📄 AI-Powered Document & Image QA")
    with gr.Row():
        file_input = gr.File(label="Upload Document")
        image_input = gr.Image(label="Upload Image")
    question_input = gr.Textbox(label="Ask a Question", placeholder="What is this document about?")
    answer_output = gr.Textbox(label="Answer")
    submit_btn = gr.Button("Get Answer")
    # Pass the image input as well so the image branch of answer_question is reachable
    submit_btn.click(answer_question, inputs=[file_input, image_input, question_input], outputs=answer_output)
# ✅ Mount Gradio with FastAPI
app = gr.mount_gradio_app(app, demo, path="/")
@app.get("/")
def home():
    return RedirectResponse(url="/")
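
# Not part of the original Space listing: a minimal local entry point, sketched on the
# assumption that this file is saved as app.py and uvicorn is installed. On Hugging Face
# Spaces the server is normally started for you, so this guard only matters when running
# the app locally.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)  # 7860 is the conventional Gradio/Spaces port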