import io
import os
import traceback

import torch
from PIL import Image, UnidentifiedImageError

from .model_loader import ModelManager


class VQAInference:
    """
    Class to perform inference with Visual Question Answering models
    """

    def __init__(self, model_name="blip", cache_dir=None):
        """
        Initialize the VQA inference

        Args:
            model_name (str, optional): Name of model to use. Defaults to "blip".
            cache_dir (str, optional): Directory to cache models. Defaults to None.
        """
        self.model_name = model_name
        self.model_manager = ModelManager(cache_dir=cache_dir)
        self.processor, self.model = self.model_manager.get_model(model_name)
        self.device = self.model_manager.device

    def predict(self, image, question):
        """
        Perform VQA prediction on an image with a question

        Args:
            image (PIL.Image.Image or str): Image to analyze or path to image
            question (str): Question to ask about the image

        Returns:
            str: Answer to the question
        """
        # Handle image input - could be a file path or PIL Image
        if isinstance(image, str):
            try:
                # Check if file exists
                if not os.path.exists(image):
                    raise FileNotFoundError(f"Image file not found: {image}")

                # Try multiple approaches to load the image
                try:
                    # Try the standard approach first
                    image = Image.open(image).convert("RGB")
                    print(
                        f"Successfully opened image: {image.size}, mode: {image.mode}"
                    )
                except Exception as img_err:
                    print(
                        f"Standard image loading failed: {img_err}, trying alternative method..."
                    )
                    # Try alternative approach with binary mode explicitly
                    with open(image, "rb") as img_file:
                        img_data = img_file.read()
                    image = Image.open(io.BytesIO(img_data)).convert("RGB")
                    print(
                        f"Alternative image loading succeeded: {image.size}, mode: {image.mode}"
                    )
            except UnidentifiedImageError as e:
                # Specific error when image format cannot be identified
                raise ValueError(f"Cannot identify image format: {str(e)}")
            except Exception as e:
                # Provide detailed error information
                error_details = traceback.format_exc()
                print(f"Error details: {error_details}")
                raise ValueError(f"Could not open image file: {str(e)}")

        # Make sure image is a PIL Image
        if not isinstance(image, Image.Image):
            raise ValueError("Image must be a PIL Image or a file path")

        # Process based on model type
        if self.model_name.lower() == "blip":
            return self._predict_with_blip(image, question)
        elif self.model_name.lower() == "vilt":
            return self._predict_with_vilt(image, question)
        else:
            raise ValueError(f"Prediction not implemented for model: {self.model_name}")

    def _predict_with_blip(self, image, question):
        """
        Perform prediction with BLIP model

        Args:
            image (PIL.Image.Image): Image to analyze
            question (str): Question to ask about the image

        Returns:
            str: Answer to the question
        """
        try:
            # Process image and text inputs
            inputs = self.processor(
                images=image, text=question, return_tensors="pt"
            ).to(self.device)

            # Generate answer
            with torch.no_grad():
                outputs = self.model.generate(**inputs)

            # Decode the output to text
            answer = self.processor.decode(outputs[0], skip_special_tokens=True)
            return answer
        except Exception as e:
            error_details = traceback.format_exc()
            print(f"Error in BLIP prediction: {str(e)}")
            print(f"Error details: {error_details}")
            raise RuntimeError(f"BLIP model prediction failed: {str(e)}")

    def _predict_with_vilt(self, image, question):
        """
        Perform prediction with ViLT model

        Args:
            image (PIL.Image.Image): Image to analyze
            question (str): Question to ask about the image

        Returns:
            str: Answer to the question
        """
        try:
            # Process image and text inputs
            encoding = self.processor(images=image, text=question, return_tensors="pt")

            # Move inputs to device
            for k, v in encoding.items():
                encoding[k] = v.to(self.device)

            # Forward pass
            with torch.no_grad():
                outputs = self.model(**encoding)
            logits = outputs.logits

            # Get the predicted answer idx
            idx = logits.argmax(-1).item()

            # Convert to answer text
            answer = self.model.config.id2label[idx]
            return answer
        except Exception as e:
            error_details = traceback.format_exc()
            print(f"Error in ViLT prediction: {str(e)}")
            print(f"Error details: {error_details}")
            raise RuntimeError(f"ViLT model prediction failed: {str(e)}")
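

# Minimal usage sketch (not part of the original module): it assumes ModelManager can
# download/load the "blip" weights in this environment, and "example.jpg" is a
# placeholder path used only for illustration. Because of the relative import above,
# this file must be run as part of its package (e.g. `python -m <package>.<module>`).
if __name__ == "__main__":
    # Build the BLIP-backed pipeline; pass model_name="vilt" to use ViLT instead.
    vqa = VQAInference(model_name="blip")

    # Ask one question about one image; predict() accepts a file path or a PIL Image.
    answer = vqa.predict("example.jpg", "What is shown in this picture?")
    print(f"Answer: {answer}")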