ikraamkb committed
Commit 70c714b · verified · 1 Parent(s): 6b9ec82

Update appImage.py

Files changed (1):
  1. appImage.py +28 -13
appImage.py CHANGED
@@ -63,10 +63,7 @@ from fastapi import FastAPI
 from fastapi.responses import RedirectResponse, JSONResponse, FileResponse
 import os
 from PIL import Image
-from transformers import (
-    ViltProcessor, ViltForQuestionAnswering,
-    AutoProcessor, GitForCausalLM
-)
+from transformers import ViltProcessor, ViltForQuestionAnswering, AutoProcessor, AutoModelForCausalLM
 from gtts import gTTS
 import easyocr
 import torch
@@ -76,14 +73,22 @@ from io import BytesIO
 
 app = FastAPI()
 
-# Models
+# Initialize models with optimized settings
 vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
 vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
-caption_processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
-caption_model = GitForCausalLM.from_pretrained("microsoft/git-large-coco")
-reader = easyocr.Reader(['en', 'fr'])
+
+# Load GIT model with performance optimizations
+git_processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
+git_model = AutoModelForCausalLM.from_pretrained(
+    "microsoft/git-large-coco",
+    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+    device_map="auto"
+)
+
+reader = easyocr.Reader(['en', 'fr'], gpu=torch.cuda.is_available())
 
 def classify_question(question: str):
+    """Optimized question classification"""
     q = question.lower()
     if any(w in q for w in ["text", "say", "written", "read"]):
         return "ocr"
@@ -91,6 +96,17 @@ def classify_question(question: str):
         return "caption"
     return "vqa"
 
+@torch.inference_mode()
+def generate_caption(image):
+    """Optimized caption generation with GIT model"""
+    try:
+        inputs = git_processor(images=image, return_tensors="pt").to(git_model.device)
+        outputs = git_model.generate(**inputs, max_length=50)
+        return git_processor.batch_decode(outputs, skip_special_tokens=True)[0]
+    except Exception as e:
+        print(f"Caption generation error: {e}")
+        return "Could not generate caption"
+
 def answer_question_from_image(image, question):
     if image is None or not question.strip():
         return "Please upload an image and ask a question.", None
@@ -103,17 +119,16 @@ def answer_question_from_image(image, question):
         answer = " ".join([entry[1] for entry in result]) or "No readable text found."
 
     elif mode == "caption":
-        image_tensor = caption_processor(images=image, return_tensors="pt").pixel_values
-        generated_ids = caption_model.generate(image_tensor, max_new_tokens=64)
-        answer = caption_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        answer = generate_caption(image)
 
-    else:
+    else:  # VQA mode
         inputs = vqa_processor(image, question, return_tensors="pt")
         with torch.no_grad():
             outputs = vqa_model(**inputs)
         predicted_id = outputs.logits.argmax(-1).item()
         answer = vqa_model.config.id2label[predicted_id]
 
+    # Generate audio response
     tts = gTTS(text=answer)
     with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
         tts.save(tmp.name)
@@ -124,4 +139,4 @@ def answer_question_from_image(image, question):
 
 @app.get("/")
 def home():
-    return RedirectResponse(url="/templates/home.html")
+    return RedirectResponse(url="/templates/home.html")
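A review note on the new loading pattern: device_map="auto" requires the accelerate package at runtime, and a model loaded in float16 expects float16 inputs, while the processor returns float32 pixel values; on GPU, the committed generate_caption can therefore hit a dtype mismatch. Below is a minimal dtype-safe sketch of the caption path, assuming the same checkpoint as the diff; it also swaps max_length=50 for max_new_tokens=50, which bounds only the generated tokens rather than prompt plus output.

import torch
from transformers import AutoProcessor, AutoModelForCausalLM

git_processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
git_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/git-large-coco",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",  # needs `pip install accelerate`
)

@torch.inference_mode()
def generate_caption(image):
    # The processor returns float32 pixel values; cast them to the model's
    # device and dtype so a float16 model on GPU does not raise a mismatch.
    pixel_values = git_processor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(git_model.device, git_model.dtype)
    output_ids = git_model.generate(pixel_values=pixel_values, max_new_tokens=50)
    return git_processor.batch_decode(output_ids, skip_special_tokens=True)[0]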
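Also worth noting: @torch.inference_mode() disables autograd like torch.no_grad() but additionally skips view and version-counter tracking, making it the slightly faster choice for pure inference. The VQA branch still wraps only the forward pass in torch.no_grad(); a sketch of the same treatment there, reusing the names from the diff (answer_vqa itself is a hypothetical helper, not part of the commit):

@torch.inference_mode()
def answer_vqa(image, question):
    # Same logic as the diff's VQA branch, factored out for symmetry
    # with generate_caption.
    inputs = vqa_processor(image, question, return_tensors="pt")
    outputs = vqa_model(**inputs)
    predicted_id = outputs.logits.argmax(-1).item()
    return vqa_model.config.id2label[predicted_id]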
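Finally, a quick local smoke test for the updated module. The (answer, audio_path) return shape is inferred from the early "return ..., None" in the diff rather than shown in it, and sample.jpg is a placeholder:

from PIL import Image
from appImage import answer_question_from_image

img = Image.open("sample.jpg")  # placeholder test image
answer, audio_path = answer_question_from_image(img, "What is written here?")
print(answer)      # "written" routes the question to the OCR branch
print(audio_path)  # path of the gTTS mp3 generated for the answer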