import os

from fastapi import FastAPI, File, UploadFile
from fastapi.responses import RedirectResponse
import fitz  # PyMuPDF for PDF parsing
from tika import parser  # Apache Tika for document parsing
import openpyxl
from pptx import Presentation
import torch
from torchvision import transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
from PIL import Image
from transformers import pipeline
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import numpy as np
import easyocr

# Initialize FastAPI
print("🚀 FastAPI server is starting...")
app = FastAPI()

# Load the text-generation model used for question answering (TinyLlama chat).
# device=-1 runs the pipeline on CPU.
print("🔄 Loading models")
qa_pipeline = pipeline(
    "text-generation",
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    device=-1,
)

# Load Pretrained Object Detection Model (Torchvision).
# NOTE(review): `model` and `transform` are not used by any QA code path in
# this file; they add startup time and memory. Confirm nothing else imports
# them before removing.
weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn(weights=weights)
model.eval()

# Initialize OCR Model (loaded eagerly at import time).
print("🔄 Initializing OCR Model...")
# Only request the GPU when CUDA is actually present, so CPU-only hosts work
# without relying on easyocr's fallback.
reader = easyocr.Reader(["en"], gpu=torch.cuda.is_available())

# Image Transformations
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Allowed File Extensions
ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}


def _file_path(file):
    """Return a filesystem path for *file* when one is available.

    Uploads may arrive either as a path string or as a tempfile-like wrapper
    whose ``.name`` is the path; both are mapped to the path. Objects without
    a ``.name`` (e.g. in-memory buffers) are returned unchanged.
    """
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    return getattr(file, "name", file)


def validate_file_type(file):
    """Check that *file* was supplied and has an allowed extension.

    Returns:
        An error message string when the file is missing or unsupported,
        otherwise ``None``.
    """
    if file is None:
        return "❌ No file uploaded."
    name = str(_file_path(file))
    # splitext only looks at the basename, so dots in directory names are ignored.
    ext = os.path.splitext(name)[1][1:].lower()
    print(f"🔍 Validating file type: {ext}")
    if ext not in ALLOWED_EXTENSIONS:
        return f"❌ Unsupported file format: {ext}"
    return None


def truncate_text(text, max_tokens=450):
    """Keep at most *max_tokens* whitespace-separated words of *text*.

    Words are a cheap stand-in for model tokens. The tokenizer's count for the
    result may be higher. Runs of whitespace are collapsed to single spaces.
    """
    words = text.split()
    truncated = " ".join(words[:max_tokens])
    print(f"✂️ Truncated text to {max_tokens} tokens.")
    return truncated


# Document Text Extraction Functions
# Each extractor returns the extracted text, a "⚠️ ..." message when nothing
# was found, or a "❌ ..." message when reading failed.

def extract_text_from_pdf(pdf_file):
    """Extract plain text from every page of a PDF (path or file wrapper)."""
    try:
        print("📄 Extracting text from PDF...")
        # Context manager closes the document and its file handle.
        with fitz.open(_file_path(pdf_file)) as doc:
            text = "\n".join(page.get_text("text") for page in doc)
        print("✅ PDF text extraction completed.")
        # Pages with no text still contribute newlines, so test stripped text.
        return text if text.strip() else "⚠️ No text found."
    except Exception as e:
        return f"❌ Error reading PDF: {str(e)}"


def extract_text_with_tika(file):
    """Extract text from an Office document using Apache Tika.

    Raw ``bytes`` are sent as a buffer. Anything else is treated as a file
    (path or wrapper with ``.name``) and parsed from disk.
    """
    try:
        print("📝 Extracting text with Tika...")
        if isinstance(file, (bytes, bytearray)):
            parsed = parser.from_buffer(bytes(file))
        else:
            parsed = parser.from_file(_file_path(file))
        print("✅ Tika text extraction completed.")
        # Tika may report "content": None, so guard before stripping.
        content = (parsed.get("content") or "").strip()
        return content or "⚠️ No text found."
    except Exception as e:
        return f"❌ Error reading document: {str(e)}"


def extract_text_from_pptx(pptx_file):
    """Extract the text of every text-bearing shape on every slide."""
    try:
        print("📊 Extracting text from PPTX...")
        ppt = Presentation(_file_path(pptx_file))
        texts = [
            shape.text
            for slide in ppt.slides
            for shape in slide.shapes
            if hasattr(shape, "text") and shape.text.strip()
        ]
        print("✅ PPTX text extraction completed.")
        return "\n".join(texts) if texts else "⚠️ No text found."
    except Exception as e:
        return f"❌ Error reading PPTX: {str(e)}"


def extract_text_from_excel(excel_file):
    """Extract cell values from every worksheet, one line per non-empty row."""
    try:
        print("📊 Extracting text from Excel...")
        wb = openpyxl.load_workbook(_file_path(excel_file), read_only=True)
        try:
            lines = []
            for sheet in wb.worksheets:
                for row in sheet.iter_rows(values_only=True):
                    # Empty cells come back as None; skip them instead of "None".
                    cells = [str(value) for value in row if value is not None]
                    if cells:
                        lines.append(" ".join(cells))
        finally:
            # Read-only workbooks keep the file open until explicitly closed.
            wb.close()
        print("✅ Excel text extraction completed.")
        return "\n".join(lines) if lines else "⚠️ No text found."
    except Exception as e:
        return f"❌ Error reading Excel: {str(e)}"


def _to_rgb_image(image_file):
    """Convert an ndarray, PIL image, or path/file-like object to an RGB PIL image."""
    if isinstance(image_file, np.ndarray):
        return Image.fromarray(image_file).convert("RGB")
    if isinstance(image_file, Image.Image):
        return image_file.convert("RGB")
    with Image.open(_file_path(image_file)) as img:
        return img.convert("RGB")


def extract_text_from_image(image_file):
    """Run OCR on an image and return the recognised text joined by spaces.

    Accepts a numpy array (the default value type of ``gr.Image()``), a PIL
    image, or a path / file-like object.
    """
    print("🖼️ Extracting text from image...")
    if image_file is None:
        return "❌ No image uploaded."
    try:
        pixels = np.array(_to_rgb_image(image_file))
        # A near-uniform image (low pixel std-dev) is treated as having no content.
        if pixels.std() < 10:
            print("⚠️ Low contrast detected. No meaningful content.")
            return "⚠️ No meaningful content detected in the image."
        result = reader.readtext(pixels)
    except Exception as e:
        return f"❌ Error reading image: {str(e)}"
    print("✅ Image text extraction completed.")
    # Each easyocr result is indexed with [1] for its recognised text.
    return " ".join(res[1] for res in result) if result else "⚠️ No text found."
# Extension -> extractor for the document QA tab. validate_file_type() checks
# ALLOWED_EXTENSIONS before this table is consulted.
_DOCUMENT_EXTRACTORS = {
    "pdf": extract_text_from_pdf,
    "docx": extract_text_with_tika,
    "pptx": extract_text_from_pptx,
    "xlsx": extract_text_from_excel,
}

# Prefixes the extract_text_* helpers use for error / "nothing found" messages.
_EXTRACTION_MESSAGE_PREFIXES = ("❌", "⚠️")

# Upper bound on generated tokens so CPU inference stays responsive.
_MAX_NEW_TOKENS = 256


def _document_extension(file):
    """Return the lower-cased extension of an upload (path str or object with ``.name``)."""
    name = file if isinstance(file, str) else getattr(file, "name", "")
    return str(name).rsplit(".", 1)[-1].lower()


def _extraction_problem(text, empty_message):
    """Return a user-facing message if *text* is unusable as context, else None.

    Extractors report failures as truthy "❌"/"⚠️" strings. Those are passed
    through instead of being sent to the model as if they were document text.
    """
    if not text or not text.strip():
        return empty_message
    if text.startswith(_EXTRACTION_MESSAGE_PREFIXES):
        return text
    return None


def _generate_answer(question, context):
    """Ask the QA model *question* about *context*; return only the generated answer."""
    truncated_text = truncate_text(context)
    print("🤖 Generating response...")
    response = qa_pipeline(
        f"Question: {question}\nContext: {truncated_text}\nAnswer:",
        max_new_tokens=_MAX_NEW_TOKENS,
        # Without this, generated_text echoes the full prompt and context back.
        return_full_text=False,
    )
    print("✅ AI response generated.")
    return response[0]["generated_text"].strip()


def answer_question_from_document(file, question):
    """Answer *question* using text extracted from an uploaded document.

    Returns the model's answer, or a "❌"/"⚠️" message when there is no file
    or question, the type is unsupported, or no usable text was extracted.
    """
    print("📂 Processing document for QA...")
    if file is None:
        return "❌ No file uploaded."
    if not question or not question.strip():
        return "⚠️ Please enter a question."
    validation_error = validate_file_type(file)
    if validation_error:
        return validation_error

    extractor = _DOCUMENT_EXTRACTORS.get(_document_extension(file))
    if extractor is None:
        return "❌ Unsupported file format!"
    text = extractor(file)

    problem = _extraction_problem(text, "⚠️ No text extracted from the document.")
    if problem:
        return problem
    return _generate_answer(question, text)


def answer_question_from_image(image, question):
    """Answer *question* using OCR text read from *image*.

    Returns the model's answer, or a "❌"/"⚠️" message when there is no image
    or question, or OCR produced no usable text.
    """
    print("🖼️ Processing image for QA...")
    # Identity check: the value may be a numpy array, whose truthiness is ambiguous.
    if image is None:
        return "❌ No image uploaded."
    if not question or not question.strip():
        return "⚠️ Please enter a question."
    image_text = extract_text_from_image(image)

    problem = _extraction_problem(
        image_text, "⚠️ No meaningful content detected in the image."
    )
    if problem:
        return problem
    return _generate_answer(question, image_text)


print("✅ Models loaded successfully.")

doc_interface = gr.Interface(
    fn=answer_question_from_document,
    inputs=[gr.File(), gr.Textbox()],
    outputs="text",
)
img_interface = gr.Interface(
    fn=answer_question_from_image,
    inputs=[gr.Image(), gr.Textbox()],
    outputs="text",
)
demo = gr.TabbedInterface([doc_interface, img_interface], ["Document QA", "Image QA"])

# Gradio is mounted at "/" and therefore serves the root URL itself.
app = gr.mount_gradio_app(app, demo, path="/")


def home():
    """Legacy root handler, intentionally not registered as a route.

    It was previously registered with ``@app.get("/")`` after Gradio had been
    mounted at "/". Routes are matched in registration order, so the mount
    shadowed it. Had it ever been reached, it would have redirected "/" to
    itself in a loop. The name is kept only for compatibility.
    """
    return RedirectResponse(url="/")