Update appImage.py
appImage.py  CHANGED  (+28 -13)
@@ -63,10 +63,7 @@ from fastapi import FastAPI
 from fastapi.responses import RedirectResponse, JSONResponse, FileResponse
 import os
 from PIL import Image
-from transformers import (
-    ViltProcessor, ViltForQuestionAnswering,
-    AutoProcessor, GitForCausalLM
-)
+from transformers import ViltProcessor, ViltForQuestionAnswering, AutoProcessor, AutoModelForCausalLM
 from gtts import gTTS
 import easyocr
 import torch
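Note on the import change: the GIT-specific `GitForCausalLM` class is dropped in favor of the generic `AutoModelForCausalLM`, which inspects a checkpoint's config and instantiates the matching architecture. A minimal sketch of that equivalence (the checkpoint name is the one this commit loads below):

    from transformers import AutoModelForCausalLM

    # The auto class dispatches on the checkpoint config, so a GIT
    # checkpoint still comes back as a GIT model instance.
    model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
    print(type(model).__name__)  # GitForCausalLM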
@@ -76,14 +73,22 @@ from io import BytesIO
 
 app = FastAPI()
 
-#
+# Initialize models with optimized settings
 vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
 vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
-caption_processor = AutoProcessor.from_pretrained(...)
-caption_model = GitForCausalLM.from_pretrained(...)
-reader = easyocr.Reader(...)
+
+# Load GIT model with performance optimizations
+git_processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
+git_model = AutoModelForCausalLM.from_pretrained(
+    "microsoft/git-large-coco",
+    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+    device_map="auto"
+)
+
+reader = easyocr.Reader(['en', 'fr'], gpu=torch.cuda.is_available())
 
 def classify_question(question: str):
+    """Optimized question classification"""
     q = question.lower()
     if any(w in q for w in ["text", "say", "written", "read"]):
         return "ocr"
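Loading note: `torch_dtype=torch.float16` halves the caption model's memory footprint on GPU, and `device_map="auto"` (which requires the `accelerate` package) places the weights automatically. The OCR reader is likewise built once at startup so EasyOCR's detection and recognition weights load only one time. A short usage sketch matching the `result`/`entry[1]` access pattern used later in the file (the input file name is hypothetical):

    # readtext returns (bbox, text, confidence) tuples per detected region.
    result = reader.readtext("sample.png")  # hypothetical image file
    text = " ".join(entry[1] for entry in result)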
@@ -91,6 +96,17 @@ def classify_question(question: str):
         return "caption"
     return "vqa"
 
+@torch.inference_mode()
+def generate_caption(image):
+    """Optimized caption generation with GIT model"""
+    try:
+        inputs = git_processor(images=image, return_tensors="pt").to(git_model.device)
+        outputs = git_model.generate(**inputs, max_length=50)
+        return git_processor.batch_decode(outputs, skip_special_tokens=True)[0]
+    except Exception as e:
+        print(f"Caption generation error: {e}")
+        return "Could not generate caption"
+
 def answer_question_from_image(image, question):
     if image is None or not question.strip():
         return "Please upload an image and ask a question.", None
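The `@torch.inference_mode()` decorator disables autograd tracking for the whole call, a slightly cheaper variant of wrapping the body in `torch.no_grad()`. A quick usage sketch for the new helper (the image path is hypothetical):

    from PIL import Image

    image = Image.open("example.jpg").convert("RGB")  # hypothetical file
    print(generate_caption(image))  # a short caption string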
@@ -103,17 +119,16 @@ def answer_question_from_image(image, question):
         answer = " ".join([entry[1] for entry in result]) or "No readable text found."
 
     elif mode == "caption":
-        image_tensor = caption_processor(...)
-        generated_ids = caption_model.generate(image_tensor, max_new_tokens=64)
-        answer = caption_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        answer = generate_caption(image)
 
-    else:
+    else:  # VQA mode
         inputs = vqa_processor(image, question, return_tensors="pt")
         with torch.no_grad():
             outputs = vqa_model(**inputs)
         predicted_id = outputs.logits.argmax(-1).item()
         answer = vqa_model.config.id2label[predicted_id]
 
+    # Generate audio response
     tts = gTTS(text=answer)
     with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
         tts.save(tmp.name)
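After the branches, the answer is synthesized to MP3 in a temp file, and the function presumably returns the text together with the audio path (the early return above already yields a two-tuple). A hedged sketch of how an upload endpoint might call it; the route, parameter names, and response shape are assumptions, not part of this commit:

    from io import BytesIO
    from fastapi import UploadFile, Form

    @app.post("/ask")  # hypothetical route
    async def ask(file: UploadFile, question: str = Form(...)):
        image = Image.open(BytesIO(await file.read())).convert("RGB")
        answer, audio_path = answer_question_from_image(image, question)
        return JSONResponse({"answer": answer, "audio": audio_path})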
@@ -124,4 +139,4 @@ def answer_question_from_image(image, question):
 
 @app.get("/")
 def home():
-    return RedirectResponse(url="/templates/home.html")
+    return RedirectResponse(url="/templates/home.html")
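The changed line reads identically on both sides, so this is likely a whitespace-only edit. Note the redirect target `/templates/home.html` is only reachable if the app actually serves that directory; a hedged sketch of one way to do so (this mount is an assumption, not shown in the diff):

    from fastapi.staticfiles import StaticFiles

    # Hypothetical mount so /templates/home.html resolves after the redirect.
    app.mount("/templates", StaticFiles(directory="templates"), name="templates")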