ikraamkb committed · Commit 7abb419 (verified) · 1 Parent(s): 9f09ae7

Update appImage.py

Files changed (1)
  1. appImage.py: +13 -45
appImage.py CHANGED
@@ -49,7 +49,8 @@ def answer_question_from_image(image, question):
 
     predicted_id = outputs.logits.argmax(-1).item()
     return vqa_model.config.id2label[predicted_id]"""
-from fastapi import FastAPI, UploadFile, Form
+### appImage.py Image QA Backend (Cleaned)
+from fastapi import FastAPI
 from fastapi.responses import RedirectResponse, JSONResponse, FileResponse
 import os
 from PIL import Image
@@ -63,20 +64,18 @@ from io import BytesIO
 
 app = FastAPI()
 
-# Load models
 vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
 vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
 captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
 reader = easyocr.Reader(['en', 'fr'])
 
 def classify_question(question: str):
-    question_lower = question.lower()
-    if any(word in question_lower for word in ["text", "say", "written", "read"]):
+    q = question.lower()
+    if any(w in q for w in ["text", "say", "written", "read"]):
         return "ocr"
-    elif any(word in question_lower for word in ["caption", "describe", "what is in the image"]):
+    if any(w in q for w in ["caption", "describe", "what is in the image"]):
         return "caption"
-    else:
-        return "vqa"
+    return "vqa"
 
 def answer_question_from_image(image, question):
     if image is None or not question.strip():
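For reference, the rewritten classify_question routes purely on keyword matching, so any question that mentions none of the trigger words falls through to the VQA branch. A quick, hypothetical sanity check of that behaviour (assuming appImage.py is importable; importing it will also load the models):

# Hypothetical sanity check for the keyword router; not part of the commit.
from appImage import classify_question  # assumes appImage.py is on the import path

assert classify_question("What does the sign say?") == "ocr"      # "say" triggers the OCR branch
assert classify_question("Describe the scene") == "caption"       # "describe" triggers captioning
assert classify_question("How many dogs are there?") == "vqa"     # no keyword -> default VQA branch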
@@ -84,59 +83,28 @@ def answer_question_from_image(image, question):
 
     mode = classify_question(question)
 
-    if mode == "ocr":
-        try:
+    try:
+        if mode == "ocr":
             result = reader.readtext(np.array(image))
-            text = " ".join([entry[1] for entry in result])
-            answer = text.strip() or "No readable text found."
-        except Exception as e:
-            answer = f"OCR Error: {e}"
+            answer = " ".join([entry[1] for entry in result]) or "No readable text found."
 
-    elif mode == "caption":
-        try:
+        elif mode == "caption":
             answer = captioner(image)[0]['generated_text']
-        except Exception as e:
-            answer = f"Captioning error: {e}"
 
-    else:
-        try:
+        else:
             inputs = vqa_processor(image, question, return_tensors="pt")
             with torch.no_grad():
                 outputs = vqa_model(**inputs)
             predicted_id = outputs.logits.argmax(-1).item()
             answer = vqa_model.config.id2label[predicted_id]
-        except Exception as e:
-            answer = f"VQA error: {e}"
 
-    try:
         tts = gTTS(text=answer)
         with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
             tts.save(tmp.name)
-            audio_path = tmp.name
-    except Exception as e:
-        return f"Answer: {answer}\n\n⚠️ Audio generation error: {e}", None
-
-    return answer, audio_path
-
-@app.post("/predict")
-async def predict(question: str = Form(...), file: UploadFile = Form(...)):
-    try:
-        image_data = await file.read()
-        image = Image.open(BytesIO(image_data)).convert("RGB")
-        answer, audio_path = answer_question_from_image(image, question)
-
-        if audio_path and os.path.exists(audio_path):
-            return JSONResponse({"answer": answer, "audio": f"/audio/{os.path.basename(audio_path)}"})
-        else:
-            return JSONResponse({"answer": answer})
+        return answer, tmp.name
 
     except Exception as e:
-        return JSONResponse({"error": str(e)})
-
-@app.get("/audio/{filename}")
-async def get_audio(filename: str):
-    filepath = os.path.join(tempfile.gettempdir(), filename)
-    return FileResponse(filepath, media_type="audio/mpeg")
+        return f"Error: {e}", None
 
 @app.get("/")
 def home():
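One detail the consolidated branch above relies on: NamedTemporaryFile(delete=False, ...) keeps the file on disk after the with block exits, which is what makes returning tmp.name safe. A minimal, hypothetical sketch of just that pattern in isolation:

# Hypothetical helper showing the delete=False temp-file pattern; not part of the commit.
import tempfile
from gtts import gTTS  # the actual TTS request happens when the audio is written

def synthesize_to_mp3(answer: str) -> str:
    tts = gTTS(text=answer)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
        tts.save(tmp.name)   # write the MP3 to the temp file's path
    return tmp.name          # file persists because delete=False; the caller cleans it up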
 
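With the /predict and /audio/{filename} routes removed, the only programmatic entry point this diff shows (besides the @app.get("/") route, whose body is outside the hunk) is answer_question_from_image itself, which now returns (answer, mp3_path) on success and ("Error: ...", None) on any failure. A hypothetical local driver, with placeholder file name and question and appImage.py assumed importable:

# Hypothetical local driver for the cleaned-up function; paths and question are placeholders.
from PIL import Image
from appImage import answer_question_from_image  # assumes appImage.py is on the import path

image = Image.open("sample.jpg").convert("RGB")
answer, audio_path = answer_question_from_image(image, "What does the sign say?")
print("Answer:", answer)
if audio_path:
    print("Spoken answer saved to:", audio_path)  # gTTS output, an .mp3 in the temp directory
else:
    print("No audio was generated.")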