ikraamkb committed
Commit c255de1 · verified · 1 Parent(s): 9e9ecd2

Update appImage.py

Files changed (1): appImage.py +23 -35
appImage.py CHANGED
@@ -49,40 +49,26 @@ def answer_question_from_image(image, question):
     predicted_id = outputs.logits.argmax(-1).item()
     return vqa_model.config.id2label[predicted_id]"""
-from fastapi import FastAPI, Request, UploadFile, Form
-from fastapi.responses import RedirectResponse, FileResponse, HTMLResponse, JSONResponse
-from fastapi.staticfiles import StaticFiles
-from fastapi.templating import Jinja2Templates
+from fastapi import FastAPI, UploadFile, Form
+from fastapi.responses import RedirectResponse, JSONResponse, FileResponse
 import os
-import shutil
 from PIL import Image
 from transformers import ViltProcessor, ViltForQuestionAnswering, pipeline
 from gtts import gTTS
 import easyocr
 import torch
 import tempfile
-import gradio as gr
 import numpy as np
+from io import BytesIO
 
 app = FastAPI()
 
-# Setup templates and static
-app.mount("/static", StaticFiles(directory="static"), name="static")
-app.mount("/resources", StaticFiles(directory="resources"), name="resources")
-templates = Jinja2Templates(directory="templates")
-
-# Serve custom HTML at /
-@app.get("/", response_class=HTMLResponse)
-def serve_home(request: Request):
-    return templates.TemplateResponse("home.html", {"request": request})
-
-# Load Models
+# Load models
 vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
 vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
 captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
-reader = easyocr.Reader(['en'], gpu=False, download_enabled=True)
+reader = easyocr.Reader(['en', 'fr'])
 
-# Determine which feature to use
 def classify_question(question: str):
     question_lower = question.lower()
     if any(word in question_lower for word in ["text", "say", "written", "read"]):
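The reader change in this hunk drops `gpu=False, download_enabled=True` and adds French alongside English, so easyocr falls back to its defaults (GPU if available, model download enabled). The OCR branch that consumes `reader` sits outside the hunks; a minimal sketch of how it likely builds `text`, assuming `readtext` is called with `detail=0` so it returns plain strings (the variable names are guesses):

results = reader.readtext(np.array(image), detail=0)  # PIL image -> ndarray; detail=0 yields strings only
text = " ".join(results)  # joined into the single string that a later hunk strips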
@@ -92,7 +78,6 @@ def classify_question(question: str):
     else:
         return "vqa"
 
-# Answer logic
 def answer_question_from_image(image, question):
     if image is None or not question.strip():
         return "Please upload an image and ask a question.", None
@@ -106,11 +91,13 @@ def answer_question_from_image(image, question):
             answer = text.strip() or "No readable text found."
         except Exception as e:
             answer = f"OCR Error: {e}"
+
     elif mode == "caption":
         try:
             answer = captioner(image)[0]['generated_text']
         except Exception as e:
             answer = f"Captioning error: {e}"
+
     else:
         try:
             inputs = vqa_processor(image, question, return_tensors="pt")
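Between this hunk and the next, the diff skips the text-to-speech step that produces `audio_path`. Given the `gtts` and `tempfile` imports and the `/audio` route serving from the temp directory, it plausibly looks like this sketch (a hypothetical reconstruction, not shown in the commit):

tts = gTTS(text=answer)  # gTTS renders the answer as spoken audio
audio_path = os.path.join(tempfile.gettempdir(), "answer.mp3")  # filename is a placeholder
tts.save(audio_path)  # written where GET /audio/{filename} can find it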
@@ -131,25 +118,26 @@ def answer_question_from_image(image, question):
 
     return answer, audio_path
 
-# API Endpoint for frontend
 @app.post("/predict")
-async def predict(file: UploadFile = Form(...), question: str = Form(...)):
+async def predict(question: str = Form(...), file: UploadFile = Form(...)):
     try:
-        file_ext = file.filename.split(".")[-1].lower()
-        image = Image.open(file.file)
+        image_data = await file.read()
+        image = Image.open(BytesIO(image_data)).convert("RGB")
         answer, audio_path = answer_question_from_image(image, question)
 
-        return JSONResponse({
-            "answer": answer,
-            "audio": f"/audio/{os.path.basename(audio_path)}" if audio_path else None
-        })
+        if audio_path and os.path.exists(audio_path):
+            return JSONResponse({"answer": answer, "audio": f"/audio/{os.path.basename(audio_path)}"})
+        else:
+            return JSONResponse({"answer": answer})
+
     except Exception as e:
-        return JSONResponse({"error": f"Server error: {e}"}, status_code=500)
+        return JSONResponse({"error": str(e)})
 
-# Serve audio responses
 @app.get("/audio/{filename}")
-def serve_audio(filename: str):
-    audio_path = os.path.join(tempfile.gettempdir(), filename)
-    if os.path.exists(audio_path):
-        return FileResponse(audio_path, media_type="audio/mpeg")
-    return JSONResponse({"error": "File not found"}, status_code=404)
+async def get_audio(filename: str):
+    filepath = os.path.join(tempfile.gettempdir(), filename)
+    return FileResponse(filepath, media_type="audio/mpeg")
+
+@app.get("/")
+def home():
+    return RedirectResponse(url="/static/home.html")
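With the rewritten endpoint reading the raw upload bytes and converting to RGB, any common image format can be posted as multipart form data. A minimal client sketch, assuming a local server on port 8000 (the URL and filename are placeholders):

import requests

resp = requests.post(
    "http://localhost:8000/predict",
    data={"question": "What does the sign say?"},  # plain form field
    files={"file": open("sample.jpg", "rb")},      # multipart file field
)
print(resp.json())  # {"answer": ...} plus an "audio" path when TTS succeeded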
 
 
 
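One caveat worth noting: the new `/` route redirects to `/static/home.html`, yet this same commit removes the `StaticFiles` mounts, so the redirect only resolves if a static mount still exists elsewhere in the app. A minimal sketch of the mount that would be required, assuming a local `static/` directory containing `home.html`:

from fastapi.staticfiles import StaticFiles

# Assumes ./static/home.html exists; without a mount like this, GET / redirects to a 404.
app.mount("/static", StaticFiles(directory="static"), name="static")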