ikraamkb committed on
Commit 09a2e2c · verified · 1 Parent(s): 297dd8a

Update appImage.py

Files changed (1)
  1. appImage.py +51 -23
appImage.py CHANGED
@@ -1,32 +1,60 @@
-# appImage.py
-from transformers import pipeline
-import tempfile, os
+
+from fastapi import FastAPI
+from fastapi.responses import RedirectResponse, JSONResponse, FileResponse
+import os
 from PIL import Image
+from transformers import ViltProcessor, ViltForQuestionAnswering, pipeline
 from gtts import gTTS
+import easyocr
+import torch
+import tempfile
+import numpy as np
+from io import BytesIO
 
-captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
+app = FastAPI()
 
-async def caption_image(file):
-    contents = await file.read()
-    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
-        tmp.write(contents)
-        image_path = tmp.name
+vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
+vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
+captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
+reader = easyocr.Reader(['en', 'fr'])
 
-    captions = captioner(image_path)
-    caption = captions[0]['generated_text'] if captions else "No caption generated."
+def classify_question(question: str):
+    q = question.lower()
+    if any(w in q for w in ["text", "say", "written", "read"]):
+        return "ocr"
+    if any(w in q for w in ["caption", "describe", "what is in the image"]):
+        return "caption"
+    return "vqa"
 
-    audio_path = text_to_speech(caption)
+def answer_question_from_image(image, question):
+    if image is None or not question.strip():
+        return "Please upload an image and ask a question.", None
 
-    result = {"caption": caption}
-    if audio_path:
-        result["audioUrl"] = f"/files/{os.path.basename(audio_path)}"
-    return result
+    mode = classify_question(question)
 
-def text_to_speech(text: str):
     try:
-        tts = gTTS(text)
-        temp_audio = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
-        tts.save(temp_audio.name)
-        return temp_audio.name
-    except:
-        return ""
+        if mode == "ocr":
+            result = reader.readtext(np.array(image))
+            answer = " ".join([entry[1] for entry in result]) or "No readable text found."
+
+        elif mode == "caption":
+            answer = captioner(image)[0]['generated_text']
+
+        else:
+            inputs = vqa_processor(image, question, return_tensors="pt")
+            with torch.no_grad():
+                outputs = vqa_model(**inputs)
+            predicted_id = outputs.logits.argmax(-1).item()
+            answer = vqa_model.config.id2label[predicted_id]
+
+        tts = gTTS(text=answer)
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
+            tts.save(tmp.name)
+        return answer, tmp.name
+
+    except Exception as e:
+        return f"Error: {e}", None
+
+@app.get("/")
+def home():
+    return RedirectResponse(url="/templates/home.html")
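
A minimal local check of the new handler could look like the sketch below; the image path and the question are placeholders, and the function is called directly rather than through an HTTP route, since this commit only adds the "/" redirect endpoint.

# Hypothetical usage sketch, not part of the commit: assumes appImage.py is importable
# from the working directory and that "sample.png" exists locally.
from PIL import Image
from appImage import answer_question_from_image

img = Image.open("sample.png").convert("RGB")
# A question containing "written" is routed to the OCR branch by classify_question().
answer, audio_path = answer_question_from_image(img, "What is written in this image?")
print(answer)       # extracted text, a caption, or a VQA label
print(audio_path)   # path to the gTTS mp3, or None on error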