Spaces:
Running
Running
from fastapi import FastAPI, File, UploadFile | |
import fitz # PyMuPDF for PDF parsing | |
from tika import parser # Apache Tika for document parsing | |
import openpyxl | |
from pptx import Presentation | |
import torch | |
from torchvision import transforms | |
from torchvision.models.detection import fasterrcnn_resnet50_fpn | |
from PIL import Image | |
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer | |
import gradio as gr | |
from fastapi.responses import RedirectResponse | |
import numpy as np | |
import easyocr | |
import os | |
# ---------------------------------------------------------------------------
# Application and model setup (executed once, at import time)
# ---------------------------------------------------------------------------

# FastAPI application object; the Gradio UI is mounted onto it further down.
app = FastAPI()

print("π Loading models")

# Causal language model used to generate answers.
model_name = "mistralai/Mistral-7B-v0.1"
# Hugging Face access token taken from the environment (None when unset).
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token)
# device=-1 asks the pipeline to run on CPU.
doc_qa_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=-1,
)

# Image captioning model (nlpconnect/vit-gpt2-image-captioning).
image_captioning_pipeline = pipeline(
    "image-to-text",
    model="nlpconnect/vit-gpt2-image-captioning",
)

# English OCR reader. It is constructed eagerly here (not lazily), and no
# function in the visible part of this file uses it.
reader = easyocr.Reader(["en"], gpu=True)

# File extensions (lower-case, without the dot) accepted for document QA.
ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}
def validate_file_type(file):
    """Check that an uploaded file has a supported extension.

    Args:
        file: Upload object exposing its path or file name as ``file.name``,
            or ``None`` when nothing was uploaded.

    Returns:
        ``None`` when the extension is listed in ``ALLOWED_EXTENSIONS``;
        otherwise an error message string suitable for showing to the user.
    """
    # A missing upload would otherwise crash on ``file.name`` below.
    if file is None:
        return "No file uploaded."
    # splitext() only inspects the last path component, so a dot in a
    # directory name or a name without any extension yields "" instead of
    # a chunk of the path.
    ext = os.path.splitext(file.name)[1].lstrip(".").lower()
    print(f"π Validating file type: {ext}")
    if ext not in ALLOWED_EXTENSIONS:
        return f"β Unsupported file format: {ext}"
    return None
def truncate_text(text, max_tokens=450):
    """Return at most ``max_tokens`` whitespace-separated words of ``text``.

    The limit counts words produced by ``str.split()``, not model tokens.
    Any run of whitespace in the input becomes a single space in the result.

    Args:
        text: Text to shorten.
        max_tokens: Maximum number of words to keep.

    Returns:
        The leading words of ``text`` joined by single spaces.
    """
    kept_words = text.split()[:max_tokens]
    # The message is printed on every call, whether or not words were dropped.
    print(f"βοΈ Truncated text to {max_tokens} tokens.")
    return " ".join(kept_words)
# Document Text Extraction Functions
def extract_text_from_pdf(pdf_file):
    """Extract the plain text of every page of a PDF.

    Args:
        pdf_file: Value accepted by ``fitz.open`` (e.g. a path to the PDF).

    Returns:
        The page texts joined by newlines, a "No text found" notice when the
        document contains no text, or an error message string if reading
        fails. Failures are reported in the return value, never raised.
    """
    try:
        print("π Extracting text from PDF...")
        # The context manager closes the document even if a page fails.
        with fitz.open(pdf_file) as doc:
            text = "\n".join(page.get_text("text") for page in doc)
        # strip(): text-less pages still contribute "\n" separators, which
        # would make a plain truthiness check pass on an empty document.
        return text if text.strip() else "β οΈ No text found."
    except Exception as e:
        return f"β Error reading PDF: {str(e)}"
def extract_text_with_tika(file):
    """Extract text from a document using Apache Tika.

    Args:
        file: Either a path to an existing file (``str`` or ``os.PathLike``)
            or data to hand to ``parser.from_buffer`` unchanged.

    Returns:
        The extracted text with surrounding whitespace removed, a "No text
        found" notice when Tika returns no content, or an error message
        string if parsing fails. Failures are reported in the return value,
        never raised.
    """
    try:
        print("π Extracting text with Tika...")
        # from_buffer() parses the data it is given, so a path string would
        # be parsed as literal text; existing paths go through from_file().
        if isinstance(file, (str, os.PathLike)) and os.path.isfile(file):
            parsed = parser.from_file(os.fspath(file))
        else:
            parsed = parser.from_buffer(file)
        # The "content" key can be present with a None value, in which case
        # a dict.get() default would not apply; normalise it explicitly.
        content = (parsed.get("content") or "").strip()
        return content or "β οΈ No text found."
    except Exception as e:
        return f"β Error reading document: {str(e)}"
def extract_text_from_pptx(pptx_file):
    """Extract the text of every shape on every slide of a PowerPoint file.

    Args:
        pptx_file: Value accepted by ``Presentation`` (e.g. a path to the
            .pptx file).

    Returns:
        The non-blank shape texts joined by newlines, a "No text found"
        notice when no shape carries text, or an error message string if
        reading fails. Failures are reported in the return value, never
        raised.
    """
    try:
        print("π Extracting text from PPTX...")
        ppt = Presentation(pptx_file)
        chunks = []
        for slide in ppt.slides:
            for shape in slide.shapes:
                # Shapes without a ``text`` attribute are skipped.
                shape_text = getattr(shape, "text", "")
                # Blank shapes are dropped so that a deck with only empty
                # placeholders still reaches the "No text found" fallback.
                if shape_text.strip():
                    chunks.append(shape_text)
        return "\n".join(chunks) if chunks else "β οΈ No text found."
    except Exception as e:
        return f"β Error reading PPTX: {str(e)}"
def extract_text_from_excel(excel_file):
    """Extract cell values from every worksheet of an Excel workbook.

    Each non-empty row becomes one line made of its non-empty cell values
    separated by single spaces.

    Args:
        excel_file: Value accepted by ``openpyxl.load_workbook`` (e.g. a path
            to the .xlsx file).

    Returns:
        The row lines joined by newlines, a "No text found" notice when the
        workbook has no values, or an error message string if reading fails.
        Failures are reported in the return value, never raised.
    """
    try:
        print("π Extracting text from Excel...")
        wb = openpyxl.load_workbook(excel_file, read_only=True)
        try:
            lines = []
            for sheet in wb.worksheets:
                for row in sheet.iter_rows(values_only=True):
                    # Empty cells come back as None; skip them so the text
                    # is not littered with the literal word "None".
                    cells = [str(value) for value in row if value is not None]
                    if cells:
                        lines.append(" ".join(cells))
        finally:
            # Read-only workbooks keep the file open until closed explicitly.
            wb.close()
        return "\n".join(lines) if lines else "β οΈ No text found."
    except Exception as e:
        return f"β Error reading Excel: {str(e)}"
def answer_question_from_document(file, question):
    """Answer a question using the text of an uploaded document.

    The document text is extracted according to the file extension, shortened
    with ``truncate_text`` and passed to the text-generation pipeline together
    with the question.

    Args:
        file: Upload object exposing its path as ``file.name`` (this function
            is wired to a ``gr.File`` input), or ``None`` when nothing was
            uploaded.
        question: The user's question.

    Returns:
        The generated answer, or a message string describing why no answer
        could be produced.
    """
    print("π Processing document for QA...")
    # A missing upload would otherwise crash on ``file.name``.
    if file is None:
        return "No file uploaded."
    validation_error = validate_file_type(file)
    if validation_error:
        return validation_error

    path = file.name
    file_ext = path.rsplit(".", 1)[-1].lower()
    # NOTE: the extractors report failures as message strings, which are
    # passed on as context like any other extracted text.
    if file_ext == "pdf":
        text = extract_text_from_pdf(path)
    elif file_ext == "pptx":
        # Use the dedicated PPTX extractor defined in this file.
        text = extract_text_from_pptx(path)
    elif file_ext == "docx":
        text = extract_text_with_tika(file)
    elif file_ext == "xlsx":
        text = extract_text_from_excel(path)
    else:
        return "β Unsupported file format!"

    if not text:
        return "β οΈ No text extracted from the document."

    truncated_text = truncate_text(text)
    print("π€ Generating response...")
    # The trailing "Answer:" cue tells the model where its answer starts.
    prompt = f"Question: {question}\nContext: {truncated_text}\nAnswer:"
    # return_full_text=False keeps the prompt (and the whole document context)
    # out of the returned answer; max_new_tokens bounds the generation length
    # explicitly instead of relying on a library default.
    response = doc_qa_pipeline(prompt, max_new_tokens=256, return_full_text=False)
    return response[0]["generated_text"].strip()
def answer_question_from_image(image, question):
    """Answer a question about an image via an automatically generated caption.

    The image is captioned by the image-to-text pipeline and the caption is
    then used as context for the text-generation pipeline.

    Args:
        image: The uploaded image (this function is wired to a ``gr.Image``
            input), or ``None`` when nothing was uploaded. NumPy arrays are
            converted to PIL images; other values are passed to the
            captioning pipeline unchanged.
        question: The user's question.

    Returns:
        The generated answer, or a message string when no image was given.
    """
    print("πΌοΈ Generating caption for image...")
    if image is None:
        return "No image uploaded."
    # The captioning pipeline takes PIL images or paths, not raw arrays.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    caption = image_captioning_pipeline(image)[0]["generated_text"]
    print("π€ Answering question based on caption...")
    # The trailing "Answer:" cue tells the model where its answer starts.
    prompt = f"Question: {question}\nContext: {caption}\nAnswer:"
    # return_full_text=False keeps the prompt out of the returned answer;
    # max_new_tokens bounds the generation length explicitly.
    response = doc_qa_pipeline(prompt, max_new_tokens=256, return_full_text=False)
    return response[0]["generated_text"].strip()
# ---------------------------------------------------------------------------
# Gradio user interface: one tab for documents, one tab for images
# ---------------------------------------------------------------------------

# Document tab: a file upload plus a question box, answered as plain text.
doc_interface = gr.Interface(
    fn=answer_question_from_document,
    inputs=[
        gr.File(label="π Upload Document"),
        gr.Textbox(label="π¬ Ask a Question"),
    ],
    outputs="text",
    title="π AI Document Question Answering",
)

# Image tab: an image upload plus a question box, answered as plain text.
img_interface = gr.Interface(
    fn=answer_question_from_image,
    inputs=[
        gr.Image(label="πΌοΈ Upload Image"),
        gr.Textbox(label="π¬ Ask a Question"),
    ],
    outputs="text",
    title="πΌοΈ AI Image Question Answering",
)

# Combine both interfaces into tabs and serve them from the FastAPI app root.
demo = gr.TabbedInterface(
    [doc_interface, img_interface],
    ["π Document QA", "πΌοΈ Image QA"],
)
app = gr.mount_gradio_app(app, demo, path="/")
def home():
    """Build a redirect response pointing at the application root ("/").

    No route decorator is attached to this function in the visible source,
    so nothing shown here registers it with the FastAPI app.
    """
    redirect = RedirectResponse(url="/")
    return redirect