Spaces:
Running
Running
"""import gradio as gr | |
import numpy as np | |
import fitz # PyMuPDF | |
import torch | |
import asyncio | |
from fastapi import FastAPI | |
from transformers import pipeline | |
from PIL import Image | |
from starlette.responses import RedirectResponse | |
from openpyxl import load_workbook | |
from docx import Document | |
from pptx import Presentation | |
# Initialize FastAPI | |
app = FastAPI() | |
# Use GPU if available | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
print(f"β Using device: {device}") | |
# Function to load models lazily | |
def get_qa_pipeline(): | |
print("π Loading QA pipeline model...") | |
return pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device=device, torch_dtype=torch.float16) | |
def get_image_captioning_pipeline(): | |
print("π Loading Image Captioning model...") | |
return pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning") | |
ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"} | |
MAX_INPUT_LENGTH = 1024 # Limit input length for faster processing | |
# β Validate File Type | |
def validate_file_type(file): | |
if hasattr(file, "name"): | |
ext = file.name.split(".")[-1].lower() | |
print(f"π File extension detected: {ext}") | |
if ext not in ALLOWED_EXTENSIONS: | |
print(f"β Unsupported file format: {ext}") | |
return f"β Unsupported file format: {ext}" | |
return None | |
print("β Invalid file format!") | |
return "β Invalid file format!" | |
# β Extract Text from PDF | |
async def extract_text_from_pdf(file): | |
print(f"π Extracting text from PDF: {file.name}") | |
loop = asyncio.get_event_loop() | |
text = await loop.run_in_executor(None, lambda: "\n".join([page.get_text() for page in fitz.open(file.name)])) | |
print(f"β Extracted {len(text)} characters from PDF") | |
return text | |
# β Extract Text from DOCX | |
async def extract_text_from_docx(file): | |
print(f"π Extracting text from DOCX: {file.name}") | |
loop = asyncio.get_event_loop() | |
text = await loop.run_in_executor(None, lambda: "\n".join([p.text for p in Document(file).paragraphs])) | |
print(f"β Extracted {len(text)} characters from DOCX") | |
return text | |
# β Extract Text from PPTX | |
async def extract_text_from_pptx(file): | |
print(f"π Extracting text from PPTX: {file.name}") | |
loop = asyncio.get_event_loop() | |
text = await loop.run_in_executor(None, lambda: "\n".join([shape.text for slide in Presentation(file).slides for shape in slide.shapes if hasattr(shape, "text")])) | |
print(f"β Extracted {len(text)} characters from PPTX") | |
return text | |
# β Extract Text from Excel | |
async def extract_text_from_excel(file): | |
print(f"π Extracting text from Excel: {file.name}") | |
loop = asyncio.get_event_loop() | |
text = await loop.run_in_executor(None, lambda: "\n".join([" ".join(str(cell) for cell in row if cell) for sheet in load_workbook(file.name, data_only=True).worksheets for row in sheet.iter_rows(values_only=True)])) | |
print(f"β Extracted {len(text)} characters from Excel") | |
return text | |
# β Truncate Long Text | |
def truncate_text(text): | |
print(f"βοΈ Truncating text to {MAX_INPUT_LENGTH} characters (if needed)...") | |
return text[:MAX_INPUT_LENGTH] if len(text) > MAX_INPUT_LENGTH else text | |
# β Answer Questions from Image or Document | |
async def answer_question(file, question: str): | |
print(f"β Question received: {question}") | |
if isinstance(file, np.ndarray): # Image Processing | |
print("πΌοΈ Processing image for captioning...") | |
image = Image.fromarray(file) | |
image_captioning = get_image_captioning_pipeline() | |
caption = image_captioning(image)[0]['generated_text'] | |
print(f"π Generated caption: {caption}") | |
qa = get_qa_pipeline() | |
print("π€ Running QA model...") | |
response = qa(f"Question: {question}\nContext: {caption}") | |
print(f"β Model response: {response[0]['generated_text']}") | |
return response[0]["generated_text"] | |
validation_error = validate_file_type(file) | |
if validation_error: | |
return validation_error | |
file_ext = file.name.split(".")[-1].lower() | |
# Extract text asynchronously | |
if file_ext == "pdf": | |
text = await extract_text_from_pdf(file) | |
elif file_ext == "docx": | |
text = await extract_text_from_docx(file) | |
elif file_ext == "pptx": | |
text = await extract_text_from_pptx(file) | |
elif file_ext == "xlsx": | |
text = await extract_text_from_excel(file) | |
else: | |
print("β Unsupported file format!") | |
return "β Unsupported file format!" | |
if not text: | |
print("β οΈ No text extracted from the document.") | |
return "β οΈ No text extracted from the document." | |
truncated_text = truncate_text(text) | |
# Run QA model asynchronously | |
print("π€ Running QA model...") | |
loop = asyncio.get_event_loop() | |
qa = get_qa_pipeline() | |
response = await loop.run_in_executor(None, qa, f"Question: {question}\nContext: {truncated_text}") | |
print(f"β Model response: {response[0]['generated_text']}") | |
return response[0]["generated_text"] | |
# β Gradio Interface (Separate File & Image Inputs) | |
with gr.Blocks() as demo: | |
gr.Markdown("## π AI-Powered Document & Image QA") | |
with gr.Row(): | |
file_input = gr.File(label="Upload Document") | |
image_input = gr.Image(label="Upload Image") | |
question_input = gr.Textbox(label="Ask a Question", placeholder="What is this document about?") | |
answer_output = gr.Textbox(label="Answer") | |
submit_btn = gr.Button("Get Answer") | |
submit_btn.click(answer_question, inputs=[file_input, question_input], outputs=answer_output) | |
# β Mount Gradio with FastAPI | |
app = gr.mount_gradio_app(app, demo, path="/") | |
@app.get("/") | |
def home(): | |
return RedirectResponse(url="/") | |
""" | |
from fastapi import FastAPI, Form, File, UploadFile | |
from fastapi.responses import RedirectResponse | |
from fastapi.staticfiles import StaticFiles | |
from pydantic import BaseModel | |
from transformers import pipeline | |
import os | |
from PIL import Image | |
import io | |
import pdfplumber | |
import docx | |
import openpyxl | |
import pytesseract | |
from io import BytesIO | |
import fitz # PyMuPDF | |
import easyocr | |
from fastapi.templating import Jinja2Templates | |
from starlette.requests import Request | |
# Initialize the app
app = FastAPI()

# Mount the static directory to serve HTML, CSS, JS files
app.mount("/static", StaticFiles(directory="static"), name="static")

# Initialize transformers pipelines. They are built once at import time and
# reused by the request handlers.
# NOTE(review): the "question-answering" task expects a checkpoint with an
# extractive-QA head; microsoft/phi-2 is published as a causal language
# model - confirm this combination loads, or swap in a QA checkpoint.
qa_pipeline = pipeline("question-answering", model="microsoft/phi-2", tokenizer="microsoft/phi-2")
# The transformers task identifier for answering questions about an image is
# "visual-question-answering"; "image-question-answering" is not a registered
# task name and makes pipeline() fail at import time.
image_qa_pipeline = pipeline("visual-question-answering", model="Salesforce/blip-vqa-base", tokenizer="Salesforce/blip-vqa-base")

# Initialize EasyOCR for image-based text extraction
# NOTE(review): `reader` is not referenced by any function visible in this
# file (extract_text_from_image calls pytesseract) - confirm it is still
# needed, since constructing it loads the OCR models.
reader = easyocr.Reader(['en'])

# Define a template for rendering HTML
templates = Jinja2Templates(directory="templates")
# Function to process PDFs | |
def extract_pdf_text(file_path: str) -> str:
    """Return the text of every page of the PDF at *file_path*.

    Page texts are joined with a newline so the last word of one page does
    not fuse with the first word of the next. Pages whose ``extract_text()``
    result is falsy (``None`` or an empty string) are skipped; concatenating
    ``None`` onto a ``str`` would raise ``TypeError``.

    Args:
        file_path: Path of the PDF file to read.

    Returns:
        The concatenated page text, or an empty string when no page yields
        any text.
    """
    # Extract while the document is still open; the context manager closes
    # it as soon as the texts have been collected.
    with pdfplumber.open(file_path) as pdf:
        page_texts = [page.extract_text() for page in pdf.pages]
    return "\n".join(page_text for page_text in page_texts if page_text)
# Function to process DOCX files | |
def extract_docx_text(file_path: str) -> str:
    """Return the paragraph text of the DOCX file at *file_path*.

    Paragraphs are joined with a newline so that adjacent paragraphs stay
    separated instead of running into each other.

    Args:
        file_path: Path of the ``.docx`` file to read.

    Returns:
        The text of all paragraphs in document order, one per line.
    """
    doc = docx.Document(file_path)
    return "\n".join(para.text for para in doc.paragraphs)
# Function to process PPTX files | |
def extract_pptx_text(file_path: str) -> str:
    """Return the text of every text-bearing shape in a PPTX file.

    Slides are visited in order, and within each slide every shape exposing
    a ``text`` attribute contributes its text. The pieces are joined with a
    newline so titles and body text from neighbouring shapes do not fuse
    into a single word.

    Args:
        file_path: Path of the ``.pptx`` file to read.

    Returns:
        The collected shape texts, one per line; empty when no shape has
        text.
    """
    # Imported here, as in the original, so python-pptx is only loaded when
    # a presentation is actually processed.
    from pptx import Presentation

    prs = Presentation(file_path)
    pieces = []
    for slide in prs.slides:
        # Shapes without a `text` attribute are skipped.
        pieces.extend(shape.text for shape in slide.shapes if hasattr(shape, "text"))
    return "\n".join(pieces)
# Function to extract text from images using OCR | |
def extract_text_from_image(image: Image.Image) -> str:
    """Run OCR on *image* with pytesseract and return the recognised text.

    Args:
        image: A PIL image instance. The annotation names the
            ``Image.Image`` class; plain ``Image`` refers to the PIL
            module, not the image type.

    Returns:
        Whatever ``pytesseract.image_to_string`` produces for the image.
    """
    return pytesseract.image_to_string(image)
# Home route | |
def home():
    """Send the caller to the ``/docs`` URL via a redirect response."""
    # NOTE(review): no route decorator is visible on this function in this
    # view of the file - confirm it is registered on `app`.
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
# Function to answer questions based on document content | |
async def question_answering_doc(question: str = Form(...), file: UploadFile = File(...)):
    """Answer *question* using the text of an uploaded PDF, DOCX or PPTX.

    The upload is written to a uniquely named temporary file under
    ``temp_files/``, its text is extracted with the extractor matching the
    file extension, and the temporary file is removed again before the
    question-answering pipeline runs.

    Args:
        question: The question to answer (form field).
        file: The uploaded document (form file field).

    Returns:
        ``{"answer": ...}`` on success, or ``{"error": ...}`` when the
        format is unsupported or no text could be extracted.
    """
    # Local import keeps this block self-contained; the file's import
    # section does not bring in tempfile.
    import tempfile

    extractors = {
        ".pdf": extract_pdf_text,
        ".docx": extract_docx_text,
        ".pptx": extract_pptx_text,
    }

    # The client controls the filename, so it is only inspected for its
    # extension (case-insensitively) and never used to build a path.
    upload_name = (file.filename or "").lower()
    suffix = next((ext for ext in extractors if upload_name.endswith(ext)), None)
    if suffix is None:
        # Rejected before anything is written to disk.
        return {"error": "Unsupported file format"}

    payload = await file.read()

    # Save the uploaded file temporarily. mkstemp picks a unique name, so
    # concurrent uploads with the same filename cannot overwrite each other
    # and names such as "../../x.pdf" cannot escape temp_files/.
    os.makedirs("temp_files", exist_ok=True)
    fd, file_path = tempfile.mkstemp(suffix=suffix, dir="temp_files")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(payload)
        text = extractors[suffix](file_path)
    finally:
        # Always clean up, including when extraction raises.
        os.remove(file_path)

    # An empty context cannot be answered from; report it instead of
    # handing it to the pipeline.
    if not text or not text.strip():
        return {"error": "No text could be extracted from the document"}

    # Use the model for question answering
    qa_result = qa_pipeline(question=question, context=text)
    return {"answer": qa_result['answer']}
# Function to answer questions based on images | |
async def question_answering_image(question: str = Form(...), image_file: UploadFile = File(...)):
    """Answer *question* about an uploaded image and return its OCR text.

    Args:
        question: The question to answer (form field).
        image_file: The uploaded image (form file field).

    Returns:
        ``{"answer": ..., "image_text": ...}`` on success, where
        ``image_text`` is the OCR output for the image, or
        ``{"error": ...}`` when the upload cannot be decoded as an image.
    """
    raw_bytes = await image_file.read()
    try:
        image = Image.open(BytesIO(raw_bytes))
        # Image.open is lazy; force decoding now so corrupt or truncated
        # data is reported here rather than inside OCR or the model.
        image.load()
    except OSError:
        # PIL's UnidentifiedImageError is an OSError subclass.
        return {"error": "Uploaded file is not a readable image"}

    # Extract any text in the image via extract_text_from_image, which
    # uses pytesseract (not the EasyOCR reader).
    image_text = extract_text_from_image(image)

    # Use the VQA pipeline for question answering on the image
    image_qa_result = image_qa_pipeline(image=image, question=question)
    # The pipeline returns a list of answer dicts for a single image, so
    # indexing the result with 'answer' directly would raise TypeError.
    # Take the first entry; a bare dict is still accepted.
    if isinstance(image_qa_result, list):
        image_qa_result = image_qa_result[0] if image_qa_result else {}
    return {"answer": image_qa_result.get('answer', ""), "image_text": image_text}
# Serve the application in Hugging Face space | |
async def get_docs(request: Request):
    """Render the ``static/index.html`` template for the incoming request.

    The request object is passed to the template under the ``"request"``
    key.
    """
    # NOTE(review): `templates` is created with directory="templates", so
    # this name is presumably resolved as templates/static/index.html -
    # confirm that is where the page lives.
    context = {"request": request}
    return templates.TemplateResponse("static/index.html", context)