# models/image_models.py
import hashlib
import logging
import os
from typing import Any, Dict, List, Union

import numpy as np
import torch
from PIL import Image
from transformers import (
    Blip2ForConditionalGeneration,
    Blip2Processor,
    BlipForConditionalGeneration,
    BlipProcessor,
)


class ImageModelManager:
    def __init__(self, token_manager=None, cache_manager=None, metrics_calculator=None):
        """Initialize the ImageModelManager with optional utilities."""
        self.logger = logging.getLogger(__name__)
        self.token_manager = token_manager
        self.cache_manager = cache_manager
        self.metrics_calculator = metrics_calculator

        # Model instances
        self.lightweight_model = None
        self.lightweight_processor = None
        self.advanced_model = None
        self.advanced_processor = None

        # Model names
        self.lightweight_model_name = "Salesforce/blip-image-captioning-base"
        self.advanced_model_name = "Salesforce/blip2-opt-2.7b"

        # Track initialization state
        self.initialized = {
            "lightweight": False,
            "advanced": False,
        }

        # Default complexity thresholds
        self.complexity_thresholds = {
            "entropy": 4.5,        # Higher entropy suggests a more complex image
            "edge_density": 0.15,  # Higher edge density suggests more detail
            "size": 500000,        # Larger images may contain more information
        }
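        # Example: a 1024x768 photo has 786,432 pixels and exceeds the
        # 500,000-pixel threshold on size alone, while a 640x480 image
        # (307,200 pixels) must cross the entropy or edge-density
        # threshold to be routed to the advanced model.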

    def initialize_lightweight_model(self):
        """Initialize the lightweight image captioning model."""
        if self.initialized["lightweight"]:
            return
        try:
            # Register with token manager if available
            if self.token_manager:
                self.token_manager.register_model(
                    self.lightweight_model_name, "image_captioning")

            # Load model and processor
            self.logger.info(f"Loading lightweight image model: {self.lightweight_model_name}")
            self.lightweight_processor = BlipProcessor.from_pretrained(self.lightweight_model_name)
            self.lightweight_model = BlipForConditionalGeneration.from_pretrained(
                self.lightweight_model_name, torch_dtype=torch.float32)
            self.initialized["lightweight"] = True
            self.logger.info("Lightweight image model initialized successfully")
        except Exception as e:
            self.logger.error(f"Failed to initialize lightweight image model: {e}")
            raise

    def initialize_advanced_model(self):
        """Initialize the advanced image captioning model."""
        if self.initialized["advanced"]:
            return
        try:
            # Register with token manager if available
            if self.token_manager:
                self.token_manager.register_model(
                    self.advanced_model_name, "image_captioning")

            # Load model and processor
            self.logger.info(f"Loading advanced image model: {self.advanced_model_name}")
            self.advanced_processor = Blip2Processor.from_pretrained(self.advanced_model_name)
            self.advanced_model = Blip2ForConditionalGeneration.from_pretrained(
                self.advanced_model_name, torch_dtype=torch.float32)
            self.initialized["advanced"] = True
            self.logger.info("Advanced image model initialized successfully")
        except Exception as e:
            self.logger.error(f"Failed to initialize advanced image model: {e}")
            raise

    def determine_image_complexity(self, image: Image.Image) -> Dict[str, float]:
        """
        Determine the complexity of an image to guide model selection.

        Returns complexity metrics.
        """
        # Convert to a grayscale numpy array for analysis
        img_array = np.array(image.convert("L"))

        # Calculate image entropy (measure of randomness/information)
        histogram = np.histogram(img_array, bins=256, range=(0, 256))[0]
        histogram = histogram / histogram.sum()
        non_zero = histogram > 0
        entropy = -np.sum(histogram[non_zero] * np.log2(histogram[non_zero]))
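        # For 8-bit grayscale, entropy ranges from 0 (a flat, single-valued
        # image) to 8 bits (a perfectly uniform histogram, since
        # log2(256) = 8), so the default threshold of 4.5 sits mid-range.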

        # Calculate edge density using a simple gradient method.
        # Cast to float first: uint8 arithmetic would wrap around on
        # negative differences and overflow when squared.
        img_float = img_array.astype(np.float32)
        gradient_x = np.abs(np.diff(img_float, axis=1, prepend=0))
        gradient_y = np.abs(np.diff(img_float, axis=0, prepend=0))
        gradient_magnitude = np.sqrt(gradient_x**2 + gradient_y**2)
        edge_density = np.mean(gradient_magnitude > 30)  # Threshold for edge detection

        # Get image size in pixels
        size = image.width * image.height

        return {
            "entropy": float(entropy),
            "edge_density": float(edge_density),
            "size": size,
        }

    def select_captioning_model(self, image: Image.Image) -> str:
        """
        Select the appropriate captioning model based on image complexity.

        Returns model type ("lightweight" or "advanced").
        """
        # Get complexity metrics
        complexity = self.determine_image_complexity(image)

        # Decision logic for model selection
        use_advanced = (
            complexity["entropy"] > self.complexity_thresholds["entropy"] or
            complexity["edge_density"] > self.complexity_thresholds["edge_density"] or
            complexity["size"] > self.complexity_thresholds["size"]
        )
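        # Any single metric above its threshold routes the image to the
        # advanced model; the lightweight model is used only when all
        # three metrics fall below their thresholds.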

        # Log selection decision
        model_type = "advanced" if use_advanced else "lightweight"
        self.logger.info(f"Selected {model_type} model for image captioning (complexity: {complexity})")

        # If we stayed on the lightweight model, record the downgrade and
        # the energy saved relative to running the advanced model
        if not use_advanced and self.metrics_calculator:
            energy_saved = 0.01  # Approximate difference in watt-hours
            self.metrics_calculator.log_model_downgrade(
                self.advanced_model_name, self.lightweight_model_name, energy_saved)

        return model_type

    def generate_image_caption(self, image: Union[str, Image.Image],
                               agent_name: str = "image_processing") -> Dict[str, Any]:
        """
        Generate a caption for an image, selecting the appropriate model
        based on complexity.

        Returns the caption and metadata.
        """
        # Handle string input (file path)
        if isinstance(image, str):
            if os.path.exists(image):
                image = Image.open(image).convert("RGB")
            else:
                raise ValueError(f"Image file not found: {image}")

        # Ensure image is a PIL Image
        if not isinstance(image, Image.Image):
            raise TypeError("Image must be a PIL Image or a valid file path")

        # Check cache if available. Use a content hash rather than the
        # built-in hash(), which is salted per process and would produce
        # different cache keys across runs.
        image_hash = hashlib.sha256(image.tobytes()).hexdigest()
        if self.cache_manager:
            cache_hit, cached_result = self.cache_manager.get(
                image_hash, namespace="image_captions")
            if cache_hit:
                # Update metrics if available
                if self.metrics_calculator:
                    self.metrics_calculator.update_cache_metrics(1, 0, 0.01)  # Estimated energy saving
                return cached_result

        # Select model based on image complexity
        model_type = self.select_captioning_model(image)

        # Initialize the selected model if needed
        if model_type == "advanced":
            if not self.initialized["advanced"]:
                self.initialize_advanced_model()
            processor = self.advanced_processor
            model = self.advanced_model
            model_name = self.advanced_model_name
        else:
            if not self.initialized["lightweight"]:
                self.initialize_lightweight_model()
            processor = self.lightweight_processor
            model = self.lightweight_model
            model_name = self.lightweight_model_name

        # Process image
        inputs = processor(image, return_tensors="pt")

        # Request token budget if available
        if self.token_manager:
            approved, reason = self.token_manager.request_tokens(
                agent_name, "image_captioning", "", model_name)
            if not approved:
                self.logger.warning(f"Token budget exceeded: {reason}")
                return {"caption": "Token budget exceeded", "error": reason}

        # Generate caption
        with torch.no_grad():
            if model_type == "advanced":
                pixel_values = inputs.pixel_values.to(torch.float32)
                generated_ids = model.generate(
                    pixel_values=pixel_values,
                    max_new_tokens=50,  # max_new_tokens bounds only the generated text
                    num_beams=5,
                )
                caption = processor.decode(generated_ids[0], skip_special_tokens=True)
            else:
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=50,  # max_new_tokens bounds only the generated text
                    num_beams=5,
                )
                caption = processor.decode(outputs[0], skip_special_tokens=True)

        # Prepare result
        result = {
            "caption": caption,
            "model_used": model_type,
            "complexity": self.determine_image_complexity(image),
            "confidence": 0.9 if model_type == "advanced" else 0.7,  # Estimated confidence
        }

        # Log token usage if available
        if self.token_manager:
            # Approximate token count based on output length
            token_count = len(caption.split()) + 20  # Base tokens + output
            self.token_manager.log_usage(
                agent_name, "image_captioning", token_count, model_name)

            # Log energy usage if metrics calculator is available
            if self.metrics_calculator:
                energy_usage = self.token_manager.calculate_energy_usage(
                    token_count, model_name)
                self.metrics_calculator.log_energy_usage(
                    energy_usage, model_name, agent_name, "image_captioning")

        # Store in cache if available
        if self.cache_manager:
            self.cache_manager.put(image_hash, result, namespace="image_captions")

        return result

    def match_images_to_topic(self, topic: str, image_captions: List[Dict[str, Any]],
                              text_model_manager=None) -> List[float]:
        """
        Match image captions to the user's topic using semantic similarity.

        Returns relevance scores for each image.
        """
        if not text_model_manager:
            self.logger.warning("No text model manager provided for semantic matching")
            return [0.5] * len(image_captions)  # Default mid-range relevance

        # Extract captions
        captions = [item["caption"] for item in image_captions]

        # Use text model to compute similarity
        similarities = text_model_manager.compute_similarity(topic, captions)
        return similarities
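

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module's public API. It assumes
    # an image exists at the hypothetical path "example.jpg" and that the
    # BLIP weights can be fetched from the Hugging Face Hub. The optional
    # token and metrics managers are omitted; DictCacheManager below is a
    # stand-in implementing only the get/put interface this module calls.
    class DictCacheManager:
        def __init__(self):
            self._store = {}

        def get(self, key, namespace="default"):
            # Return (hit, value), matching the tuple this module unpacks
            entry = self._store.get((namespace, key))
            return (entry is not None), entry

        def put(self, key, value, namespace="default"):
            self._store[(namespace, key)] = value

    logging.basicConfig(level=logging.INFO)
    manager = ImageModelManager(cache_manager=DictCacheManager())
    result = manager.generate_image_caption("example.jpg")
    print(f"{result['model_used']}: {result['caption']}")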