# qtAnswering/app.py
"""import gradio as gr
import numpy as np
import fitz # PyMuPDF
import torch
import asyncio
from fastapi import FastAPI
from transformers import pipeline
from PIL import Image
from starlette.responses import RedirectResponse
from openpyxl import load_workbook
from docx import Document
from pptx import Presentation
# Initialize FastAPI
app = FastAPI()
# Use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"βœ… Using device: {device}")
# Function to load models lazily
def get_qa_pipeline():
print("πŸ”„ Loading QA pipeline model...")
return pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device=device, torch_dtype=torch.float16)
def get_image_captioning_pipeline():
print("πŸ”„ Loading Image Captioning model...")
return pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")

ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}
MAX_INPUT_LENGTH = 1024  # Limit input length for faster processing

# ✅ Validate File Type
def validate_file_type(file):
    if hasattr(file, "name"):
        ext = file.name.split(".")[-1].lower()
        print(f"📁 File extension detected: {ext}")
        if ext not in ALLOWED_EXTENSIONS:
            print(f"❌ Unsupported file format: {ext}")
            return f"❌ Unsupported file format: {ext}"
        return None
    print("❌ Invalid file format!")
    return "❌ Invalid file format!"
# ✅ Extract Text from PDF
async def extract_text_from_pdf(file):
    print(f"📄 Extracting text from PDF: {file.name}")
    loop = asyncio.get_event_loop()
    text = await loop.run_in_executor(None, lambda: "\n".join([page.get_text() for page in fitz.open(file.name)]))
    print(f"✅ Extracted {len(text)} characters from PDF")
    return text
# ✅ Extract Text from DOCX
async def extract_text_from_docx(file):
    print(f"📄 Extracting text from DOCX: {file.name}")
    loop = asyncio.get_event_loop()
    # Open by path (file.name), consistent with the PDF and Excel extractors
    text = await loop.run_in_executor(None, lambda: "\n".join([p.text for p in Document(file.name).paragraphs]))
    print(f"✅ Extracted {len(text)} characters from DOCX")
    return text
# ✅ Extract Text from PPTX
async def extract_text_from_pptx(file):
    print(f"📄 Extracting text from PPTX: {file.name}")
    loop = asyncio.get_event_loop()
    # Open by path (file.name), consistent with the other extractors
    text = await loop.run_in_executor(None, lambda: "\n".join([shape.text for slide in Presentation(file.name).slides for shape in slide.shapes if hasattr(shape, "text")]))
    print(f"✅ Extracted {len(text)} characters from PPTX")
    return text
# ✅ Extract Text from Excel
async def extract_text_from_excel(file):
    print(f"📄 Extracting text from Excel: {file.name}")
    loop = asyncio.get_event_loop()
    text = await loop.run_in_executor(None, lambda: "\n".join([" ".join(str(cell) for cell in row if cell) for sheet in load_workbook(file.name, data_only=True).worksheets for row in sheet.iter_rows(values_only=True)]))
    print(f"✅ Extracted {len(text)} characters from Excel")
    return text
# ✅ Truncate Long Text
def truncate_text(text):
    print(f"✂️ Truncating text to {MAX_INPUT_LENGTH} characters (if needed)...")
    return text[:MAX_INPUT_LENGTH] if len(text) > MAX_INPUT_LENGTH else text
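
# Hedged sketch (not in the original code): MAX_INPUT_LENGTH above counts characters,
# while the model's real context limit is in tokens. If token-accurate truncation were
# wanted, the model's own tokenizer could be used instead; truncate_to_tokens and
# MAX_INPUT_TOKENS below are illustrative names, not part of the app.
from transformers import AutoTokenizer

MAX_INPUT_TOKENS = 1024
_tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

def truncate_to_tokens(text):
    ids = _tokenizer.encode(text)
    return _tokenizer.decode(ids[:MAX_INPUT_TOKENS]) if len(ids) > MAX_INPUT_TOKENS else text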

# ✅ Answer Questions from Image or Document
async def answer_question(file, question: str):
    print(f"❓ Question received: {question}")

    if isinstance(file, np.ndarray):  # Image Processing
        print("🖼️ Processing image for captioning...")
        image = Image.fromarray(file)
        image_captioning = get_image_captioning_pipeline()
        caption = image_captioning(image)[0]['generated_text']
        print(f"📝 Generated caption: {caption}")

        qa = get_qa_pipeline()
        print("🤖 Running QA model...")
        response = qa(f"Question: {question}\nContext: {caption}")
        print(f"✅ Model response: {response[0]['generated_text']}")
        return response[0]["generated_text"]

    validation_error = validate_file_type(file)
    if validation_error:
        return validation_error

    file_ext = file.name.split(".")[-1].lower()

    # Extract text asynchronously
    if file_ext == "pdf":
        text = await extract_text_from_pdf(file)
    elif file_ext == "docx":
        text = await extract_text_from_docx(file)
    elif file_ext == "pptx":
        text = await extract_text_from_pptx(file)
    elif file_ext == "xlsx":
        text = await extract_text_from_excel(file)
    else:
        print("❌ Unsupported file format!")
        return "❌ Unsupported file format!"

    if not text:
        print("⚠️ No text extracted from the document.")
        return "⚠️ No text extracted from the document."

    truncated_text = truncate_text(text)

    # Run QA model asynchronously
    print("🤖 Running QA model...")
    loop = asyncio.get_event_loop()
    qa = get_qa_pipeline()
    response = await loop.run_in_executor(None, qa, f"Question: {question}\nContext: {truncated_text}")
    print(f"✅ Model response: {response[0]['generated_text']}")
    return response[0]["generated_text"]
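
# Hedged usage sketch (not part of the app): answer_question can also be exercised
# outside Gradio. "report.pdf" is a hypothetical local file; any object with a .name
# attribute (such as an open file handle) works for the document branch.
#
#     import asyncio
#     with open("report.pdf", "rb") as f:
#         print(asyncio.run(answer_question(f, "What is this document about?")))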

# ✅ Gradio Interface (Separate File & Image Inputs)
with gr.Blocks() as demo:
    gr.Markdown("## 📄 AI-Powered Document & Image QA")
    with gr.Row():
        file_input = gr.File(label="Upload Document")
        image_input = gr.Image(label="Upload Image")
    question_input = gr.Textbox(label="Ask a Question", placeholder="What is this document about?")
    answer_output = gr.Textbox(label="Answer")
    submit_btn = gr.Button("Get Answer")
    submit_btn.click(answer_question, inputs=[file_input, question_input], outputs=answer_output)
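
    # Hedged sketch (not in the original wiring): image_input is defined above but never
    # passed to the click handler, so the np.ndarray branch of answer_question is not
    # reachable from this UI. One way to connect it would be a small async wrapper, e.g.:
    #
    #     async def answer_any(file, image, question):
    #         return await answer_question(image if image is not None else file, question)
    #
    #     submit_btn.click(answer_any, inputs=[file_input, image_input, question_input],
    #                      outputs=answer_output)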

# ✅ Mount Gradio with FastAPI
app = gr.mount_gradio_app(app, demo, path="/")

@app.get("/")
def home():
    return RedirectResponse(url="/")
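
# Hedged sketch (not in the original): because Gradio is mounted at "/", the home()
# route above is shadowed by the mount, and if it were ever reached it would redirect
# "/" back to "/" in a loop. A common alternative is to mount the UI at a sub-path and
# point the root there, e.g.:
#
#     app = gr.mount_gradio_app(app, demo, path="/gradio")
#
#     @app.get("/")
#     def home():
#         return RedirectResponse(url="/gradio")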
"""
import torch
print("CUDA Available:", torch.cuda.is_available())
print("Torch Device Count:", torch.cuda.device_count())
print("Current Device:", torch.cuda.current_device() if torch.cuda.is_available() else "CPU")
print("CUDA Device Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None")