"""FastAPI + Gradio app answering questions about uploaded documents and images.

Documents (PDF, DOCX, PPTX, XLSX) are converted to plain text, truncated and
passed together with the user's question to a text-generation model. Images
are first captioned, and the caption is used as the context instead.
"""

import os

from fastapi import FastAPI, File, UploadFile
import fitz  # PyMuPDF for PDF parsing
from tika import parser  # Apache Tika for document parsing
import openpyxl
from pptx import Presentation
import torch
from torchvision import transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from PIL import Image
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import gradio as gr
from fastapi.responses import RedirectResponse
import numpy as np
import easyocr

# Initialize FastAPI
app = FastAPI()

print("🔄 Loading models")

# NOTE(review): the Hugging Face Hub ids for this model family normally carry
# a version suffix (e.g. "-v0.2") -- confirm that this id resolves.
model_name = "mistralai/Mistral-7B-Instruct"
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")  # Fetch token from environment variable

tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token)
doc_qa_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)

# Load Image Captioning Model (nlpconnect/vit-gpt2-image-captioning)
image_captioning_pipeline = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")

# Initialize OCR Model (loaded eagerly at import time; not used by the
# functions below)
reader = easyocr.Reader(["en"], gpu=True)

# Allowed File Extensions
ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}

# Status strings returned by the extractors in place of document text.
NO_TEXT_MESSAGE = "⚠️ No text found."
_ERROR_PREFIX = "❌ Error reading"

# Upper bound on the number of tokens generated per answer.
MAX_NEW_TOKENS = 256


def _resolve_source(file):
    """Return the filesystem path behind *file* when it has one.

    Accepts a path string, an ``os.PathLike`` or an object exposing a path in
    its ``name`` attribute (such as a temporary-file wrapper). Anything else
    (bytes, nameless streams) is returned unchanged.
    """
    candidate = getattr(file, "name", file)
    if isinstance(candidate, (str, os.PathLike)):
        return os.fspath(candidate)
    return file


def _file_extension(file):
    """Return the lower-cased extension of *file* without the dot, or ``""``."""
    source = _resolve_source(file)
    if not isinstance(source, str):
        return ""
    return os.path.splitext(source)[1].lstrip(".").lower()


def validate_file_type(file):
    """Check that *file* is an upload with a supported extension.

    Returns:
        ``None`` when the file is acceptable, otherwise an error message.
    """
    if file is None:
        # Submitting the form without a file used to raise AttributeError.
        return "❌ No file uploaded."
    ext = _file_extension(file)
    print(f"🔍 Validating file type: {ext}")
    if ext not in ALLOWED_EXTENSIONS:
        return f"❌ Unsupported file format: {ext}"
    return None


# Function to truncate text to 450 tokens
def truncate_text(text, max_tokens=450):
    """Return the first *max_tokens* whitespace-separated words of *text*.

    The limit counts words, not tokenizer tokens, so it only approximates the
    model's real token budget.
    """
    words = text.split()
    truncated = " ".join(words[:max_tokens])
    print(f"✂️ Truncated text to {max_tokens} tokens.")
    return truncated


# Document Text Extraction Functions
def extract_text_from_pdf(pdf_file):
    """Extract the text of every page of a PDF.

    Returns:
        The page texts joined by newlines, ``NO_TEXT_MESSAGE`` when the PDF
        has no text, or an error message when it cannot be read.
    """
    try:
        print("📄 Extracting text from PDF...")
        # Open by path (not by upload wrapper) and close the document after.
        with fitz.open(_resolve_source(pdf_file)) as doc:
            text = "\n".join(page.get_text("text") for page in doc)
        return text if text.strip() else NO_TEXT_MESSAGE
    except Exception as e:
        return f"❌ Error reading PDF: {str(e)}"


def extract_text_with_tika(file):
    """Extract text from a document with Apache Tika.

    *file* may be a path, an object exposing a path in ``name``, or raw
    content (bytes / text), which is sent to Tika as a buffer.

    Returns:
        The stripped text, ``NO_TEXT_MESSAGE`` when Tika finds none, or an
        error message when parsing fails.
    """
    try:
        print("📝 Extracting text with Tika...")
        source = _resolve_source(file)
        if isinstance(source, str) and os.path.isfile(source):
            parsed = parser.from_file(source)
        else:
            parsed = parser.from_buffer(file)
        # "content" is present but None when the document has no text.
        content = (parsed.get("content") or "").strip()
        return content or NO_TEXT_MESSAGE
    except Exception as e:
        return f"❌ Error reading document: {str(e)}"


def extract_text_from_pptx(pptx_file):
    """Extract the text of every shape on every slide of a presentation.

    Returns:
        The shape texts joined by newlines, ``NO_TEXT_MESSAGE`` when there
        are none, or an error message when the file cannot be read.
    """
    try:
        print("📊 Extracting text from PPTX...")
        ppt = Presentation(_resolve_source(pptx_file))
        text = [
            shape.text
            for slide in ppt.slides
            for shape in slide.shapes
            if hasattr(shape, "text")
        ]
        return "\n".join(text) if text else NO_TEXT_MESSAGE
    except Exception as e:
        return f"❌ Error reading PPTX: {str(e)}"


def extract_text_from_excel(excel_file):
    """Extract the cell values of every sheet of a workbook as text.

    Each non-empty row becomes one line with its non-empty cells separated
    by spaces.

    Returns:
        The rows joined by newlines, ``NO_TEXT_MESSAGE`` when the workbook is
        empty, or an error message when it cannot be read.
    """
    try:
        print("📊 Extracting text from Excel...")
        wb = openpyxl.load_workbook(_resolve_source(excel_file), read_only=True)
        try:
            text = []
            for sheet in wb.worksheets:
                for row in sheet.iter_rows(values_only=True):
                    # Skip empty cells so they do not show up as "None".
                    cells = [str(cell) for cell in row if cell is not None]
                    if cells:
                        text.append(" ".join(cells))
        finally:
            # Read-only workbooks keep the file open until closed explicitly.
            wb.close()
        return "\n".join(text) if text else NO_TEXT_MESSAGE
    except Exception as e:
        return f"❌ Error reading Excel: {str(e)}"


# Extension -> extractor. PPTX goes to the python-pptx based extractor.
_EXTRACTORS = {
    "pdf": extract_text_from_pdf,
    "docx": extract_text_with_tika,
    "pptx": extract_text_from_pptx,
    "xlsx": extract_text_from_excel,
}


def _generate_answer(question, context):
    """Run the text-generation pipeline on *question* and *context*."""
    prompt = f"Question: {question}\nContext: {context}"
    response = doc_qa_pipeline(
        prompt,
        max_new_tokens=MAX_NEW_TOKENS,  # explicit budget for the answer
        return_full_text=False,  # do not echo the prompt back to the user
    )
    return response[0]["generated_text"].strip()


def answer_question_from_document(file, question):
    """Answer *question* using the text of the uploaded document *file*.

    Returns:
        The generated answer, or a message describing why no answer could be
        produced (missing/unsupported file, unreadable or empty document).
    """
    print("📂 Processing document for QA...")
    validation_error = validate_file_type(file)
    if validation_error:
        return validation_error

    extractor = _EXTRACTORS.get(_file_extension(file))
    if extractor is None:
        return "❌ Unsupported file format!"

    text = extractor(file)
    if not text or not text.strip() or text == NO_TEXT_MESSAGE:
        return "⚠️ No text extracted from the document."
    if text.startswith(_ERROR_PREFIX):
        # Report extraction failures instead of using them as model context.
        return text

    truncated_text = truncate_text(text)
    print("🤖 Generating response...")
    return _generate_answer(question, truncated_text)


def answer_question_from_image(image, question):
    """Answer *question* using an automatically generated caption of *image*.

    *image* may be a NumPy array or anything the captioning pipeline accepts
    directly.

    Returns:
        The generated answer, or an error message when no image was given.
    """
    if image is None:
        return "❌ No image uploaded."
    if isinstance(image, np.ndarray):
        # The captioning pipeline does not accept raw arrays.
        image = Image.fromarray(image)

    print("🖼️ Generating caption for image...")
    caption = image_captioning_pipeline(image)[0]["generated_text"]
    print("🤖 Answering question based on caption...")
    return _generate_answer(question, caption)


# Gradio UI for Document & Image QA
doc_interface = gr.Interface(
    fn=answer_question_from_document,
    inputs=[gr.File(label="📂 Upload Document"), gr.Textbox(label="💬 Ask a Question")],
    outputs="text",
    title="📄 AI Document Question Answering",
)

img_interface = gr.Interface(
    fn=answer_question_from_image,
    inputs=[gr.Image(label="🖼️ Upload Image"), gr.Textbox(label="💬 Ask a Question")],
    outputs="text",
    title="🖼️ AI Image Question Answering",
)

# Mount Gradio Interfaces
demo = gr.TabbedInterface([doc_interface, img_interface], ["📄 Document QA", "🖼️ Image QA"])
app = gr.mount_gradio_app(app, demo, path="/")


# NOTE(review): the Gradio app mounted at "/" above is registered first and
# appears to shadow this route; if it were ever reached it would redirect to
# itself. Confirm whether the route is still needed.
@app.get("/")
def home():
    """Redirect to the root path."""
    return RedirectResponse(url="/")