"""from fastapi import FastAPI | |
from fastapi.responses import RedirectResponse | |
import gradio as gr | |
from transformers import pipeline, ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM | |
from PIL import Image | |
import torch | |
import fitz # PyMuPDF for PDF | |
app = FastAPI() | |
# ========== Image QA Setup ========== | |
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa") | |
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa") | |
def answer_question_from_image(image, question): | |
if image is None or not question.strip(): | |
return "Please upload an image and ask a question." | |
inputs = vqa_processor(image, question, return_tensors="pt") | |
with torch.no_grad(): | |
outputs = vqa_model(**inputs) | |
predicted_id = outputs.logits.argmax(-1).item() | |
return vqa_model.config.id2label[predicted_id] | |
# ========== Gradio Interfaces ========== | |
img_interface = gr.Interface( | |
fn=answer_question_from_image, | |
inputs=[gr.Image(label="Upload Image"), gr.Textbox(label="Ask a Question")], | |
outputs="text", | |
title="Image Question Answering" | |
) | |
# ========== Combine and Mount ========== | |
demo = gr.TabbedInterface( img_interface , "Image QA") | |
app = gr.mount_gradio_app(app, demo, path="/") | |
@app.get("/") | |
def root(): | |
return RedirectResponse(url="/") """ | |
"""from transformers import ViltProcessor, ViltForQuestionAnswering | |
import torch | |
# Load image QA model once | |
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa") | |
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa") | |
def answer_question_from_image(image, question): | |
if image is None or not question.strip(): | |
return "Please upload an image and ask a question." | |
inputs = vqa_processor(image, question, return_tensors="pt") | |
with torch.no_grad(): | |
outputs = vqa_model(**inputs) | |
predicted_id = outputs.logits.argmax(-1).item() | |
return vqa_model.config.id2label[predicted_id]""" | |
from fastapi import FastAPI, Request, UploadFile, File, Form
from fastapi.responses import RedirectResponse, FileResponse, HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import os
import shutil
from PIL import Image
from transformers import ViltProcessor, ViltForQuestionAnswering, pipeline
from gtts import gTTS
import easyocr
import torch
import tempfile
import gradio as gr
import numpy as np
app = FastAPI()

# Setup templates and static
app.mount("/static", StaticFiles(directory="static"), name="static")
app.mount("/resources", StaticFiles(directory="resources"), name="resources")
templates = Jinja2Templates(directory="templates")
# Serve custom HTML at /
@app.get("/", response_class=HTMLResponse)
def serve_home(request: Request):
    return templates.TemplateResponse("home.html", {"request": request})
# Load Models
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
reader = easyocr.Reader(['en', 'fr'])
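# Optional sketch (an assumption, not in the original app): move the VQA model to a GPU when
# one is available. As written, everything below runs on CPU, which also works on free Spaces.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# vqa_model.to(device)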
# Determine which feature to use
def classify_question(question: str):
    question_lower = question.lower()
    if any(word in question_lower for word in ["text", "say", "written", "read"]):
        return "ocr"
    elif any(word in question_lower for word in ["caption", "describe", "what is in the image"]):
        return "caption"
    else:
        return "vqa"
# Answer logic
def answer_question_from_image(image, question):
    if image is None or not question.strip():
        return "Please upload an image and ask a question.", None

    mode = classify_question(question)

    if mode == "ocr":
        try:
            result = reader.readtext(np.array(image))
            text = " ".join([entry[1] for entry in result])
            answer = text.strip() or "No readable text found."
        except Exception as e:
            answer = f"OCR Error: {e}"
    elif mode == "caption":
        try:
            answer = captioner(image)[0]['generated_text']
        except Exception as e:
            answer = f"Captioning error: {e}"
    else:
        try:
            inputs = vqa_processor(image, question, return_tensors="pt")
            with torch.no_grad():
                outputs = vqa_model(**inputs)
            predicted_id = outputs.logits.argmax(-1).item()
            answer = vqa_model.config.id2label[predicted_id]
        except Exception as e:
            answer = f"VQA error: {e}"

    try:
        tts = gTTS(text=answer)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
            tts.save(tmp.name)
            audio_path = tmp.name
    except Exception as e:
        return f"Answer: {answer}\n\n⚠️ Audio generation error: {e}", None

    return answer, audio_path
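# Minimal local sanity check (a sketch; "sample.jpg" is a hypothetical file, not part of this
# repo). The helper can be called directly, outside the FastAPI routes:
#   img = Image.open("sample.jpg")
#   text_answer, audio_file = answer_question_from_image(img, "What is in the image?")
#   print(text_answer, audio_file)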
# API Endpoint for frontend
@app.post("/predict")  # route path assumed; adjust it to match the URL the frontend posts to
async def predict(file: UploadFile = File(...), question: str = Form(...)):
    try:
        file_ext = file.filename.split(".")[-1].lower()
        image = Image.open(file.file).convert("RGB")
        answer, audio_path = answer_question_from_image(image, question)
        return JSONResponse({
            "answer": answer,
            "audio": f"/audio/{os.path.basename(audio_path)}" if audio_path else None
        })
    except Exception as e:
        return JSONResponse({"error": f"Server error: {e}"}, status_code=500)
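# Example request to the endpoint above (hypothetical file name and question; assumes the
# server is listening locally on port 7860):
#   curl -X POST http://localhost:7860/predict \
#        -F "file=@photo.jpg" -F "question=What color is the car?"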
# Serve audio responses
@app.get("/audio/{filename}")
def serve_audio(filename: str):
    audio_path = os.path.join(tempfile.gettempdir(), filename)
    if os.path.exists(audio_path):
        return FileResponse(audio_path, media_type="audio/mpeg")
    return JSONResponse({"error": "File not found"}, status_code=404)