# models/image_models.py
import hashlib
import logging
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import Blip2Processor, Blip2ForConditionalGeneration


class ImageModelManager:
    """Caption images with one of two BLIP models, chosen by image complexity.

    A "lightweight" model (``Salesforce/blip-image-captioning-base``) handles
    simple images. An "advanced" model (``Salesforce/blip2-opt-2.7b``) is used
    when any metric from ``determine_image_complexity`` exceeds its entry in
    ``complexity_thresholds``. Models are loaded lazily on first use.

    The token manager, cache manager and metrics calculator are optional
    collaborators. Each one is used only when it was provided (not ``None``).
    """

    def __init__(self, token_manager=None, cache_manager=None, metrics_calculator=None):
        """Initialize the ImageModelManager with optional utilities."""
        self.logger = logging.getLogger(__name__)
        self.token_manager = token_manager
        self.cache_manager = cache_manager
        self.metrics_calculator = metrics_calculator

        # Model instances (populated lazily by the initialize_* methods).
        self.lightweight_model = None
        self.lightweight_processor = None
        self.advanced_model = None
        self.advanced_processor = None

        # Model names
        self.lightweight_model_name = "Salesforce/blip-image-captioning-base"
        self.advanced_model_name = "Salesforce/blip2-opt-2.7b"

        # Track initialization state
        self.initialized = {
            "lightweight": False,
            "advanced": False,
        }

        # Default complexity thresholds; exceeding ANY one selects the advanced model.
        self.complexity_thresholds = {
            "entropy": 4.5,  # Higher entropy suggests more complex image
            "edge_density": 0.15,  # Higher edge density suggests more details
            "size": 500000,  # Larger images may contain more information
        }

    def initialize_lightweight_model(self):
        """Load the lightweight BLIP captioning model and processor (idempotent).

        Raises:
            Exception: whatever the model loader raises; it is logged and re-raised.
        """
        if self.initialized["lightweight"]:
            return

        try:
            # Register with token manager if available
            if self.token_manager:
                self.token_manager.register_model(
                    self.lightweight_model_name, "image_captioning")

            # Load model and processor
            self.logger.info(f"Loading lightweight image model: {self.lightweight_model_name}")
            self.lightweight_processor = BlipProcessor.from_pretrained(self.lightweight_model_name)
            self.lightweight_model = BlipForConditionalGeneration.from_pretrained(
                self.lightweight_model_name, torch_dtype=torch.float32)

            self.initialized["lightweight"] = True
            self.logger.info("Lightweight image model initialized successfully")
        except Exception as e:
            self.logger.error(f"Failed to initialize lightweight image model: {e}")
            raise

    def initialize_advanced_model(self):
        """Load the advanced BLIP-2 captioning model and processor (idempotent).

        Raises:
            Exception: whatever the model loader raises; it is logged and re-raised.
        """
        if self.initialized["advanced"]:
            return

        try:
            # Register with token manager if available
            if self.token_manager:
                self.token_manager.register_model(
                    self.advanced_model_name, "image_captioning")

            # Load model and processor
            self.logger.info(f"Loading advanced image model: {self.advanced_model_name}")
            self.advanced_processor = Blip2Processor.from_pretrained(self.advanced_model_name)
            self.advanced_model = Blip2ForConditionalGeneration.from_pretrained(
                self.advanced_model_name, torch_dtype=torch.float32)

            self.initialized["advanced"] = True
            self.logger.info("Advanced image model initialized successfully")
        except Exception as e:
            self.logger.error(f"Failed to initialize advanced image model: {e}")
            raise

    def determine_image_complexity(self, image: Image.Image) -> Dict[str, float]:
        """Compute complexity metrics used to pick a captioning model.

        Args:
            image: Any PIL image; it is analysed in grayscale.

        Returns:
            Dict with keys:
            - ``"entropy"``: Shannon entropy of the grayscale histogram, in bits (0-8).
            - ``"edge_density"``: fraction of pixels whose gradient magnitude exceeds 30.
            - ``"size"``: pixel count (width * height).

            An empty image yields zero entropy and zero edge density.
        """
        # Widen to a signed type so pixel differences and squares cannot wrap.
        img_array = np.asarray(image.convert("L"), dtype=np.int32)
        size = image.width * image.height
        if img_array.size == 0:
            # Avoid 0/0 in the histogram normalisation and mean of an empty array.
            return {"entropy": 0.0, "edge_density": 0.0, "size": size}

        # Image entropy (measure of randomness/information)
        histogram = np.histogram(img_array, bins=256, range=(0, 256))[0]
        probabilities = histogram / histogram.sum()
        non_zero = probabilities > 0
        entropy = -np.sum(probabilities[non_zero] * np.log2(probabilities[non_zero]))

        # Forward-difference gradients. The first column/row is prepended to
        # itself so border pixels get a zero gradient instead of being compared
        # against an artificial black frame (which made blank images look edgy).
        gradient_x = np.abs(np.diff(img_array, axis=1, prepend=img_array[:, :1]))
        gradient_y = np.abs(np.diff(img_array, axis=0, prepend=img_array[:1, :]))
        gradient_magnitude = np.sqrt(gradient_x ** 2 + gradient_y ** 2)
        edge_density = np.mean(gradient_magnitude > 30)  # Threshold for edge detection

        return {
            "entropy": float(entropy),
            "edge_density": float(edge_density),
            "size": size,
        }

    def select_captioning_model(self, image: Image.Image,
                                complexity: Optional[Dict[str, float]] = None) -> str:
        """Select the captioning model for an image based on its complexity.

        Args:
            image: Image to caption.
            complexity: Precomputed output of ``determine_image_complexity``;
                computed from ``image`` when omitted.

        Returns:
            ``"advanced"`` if any metric exceeds its threshold, else ``"lightweight"``.
        """
        if complexity is None:
            complexity = self.determine_image_complexity(image)

        use_advanced = (
            complexity["entropy"] > self.complexity_thresholds["entropy"]
            or complexity["edge_density"] > self.complexity_thresholds["edge_density"]
            or complexity["size"] > self.complexity_thresholds["size"]
        )

        # Log selection decision
        model_type = "advanced" if use_advanced else "lightweight"
        self.logger.info(f"Selected {model_type} model for image captioning (complexity: {complexity})")

        # A downgrade (advanced -> lightweight) only happens when the cheaper
        # model was chosen. NOTE(review): assumes the signature is
        # log_model_downgrade(original, replacement, energy_saved) — confirm.
        if not use_advanced and self.metrics_calculator:
            energy_saved = 0.01  # Approximate watt-hours saved vs. the advanced model
            self.metrics_calculator.log_model_downgrade(
                self.advanced_model_name, self.lightweight_model_name, energy_saved)

        return model_type

    @staticmethod
    def _load_image(image: Union[str, Image.Image]) -> Image.Image:
        """Return ``image`` as a PIL image, loading (as RGB) and closing a file path."""
        if isinstance(image, str):
            if not os.path.exists(image):
                raise ValueError(f"Image file not found: {image}")
            # The context manager closes the file; convert() returns a loaded copy.
            with Image.open(image) as opened:
                image = opened.convert('RGB')

        if not isinstance(image, Image.Image):
            raise TypeError("Image must be a PIL Image or a valid file path")
        return image

    @staticmethod
    def _image_cache_key(image: Image.Image) -> str:
        """Return a process-independent cache key covering mode, size and pixels.

        The builtin ``hash()`` of bytes is salted per process, so it cannot be
        used for a cache that outlives the process.
        """
        digest = hashlib.sha256()
        digest.update(f"{image.mode}:{image.width}x{image.height}:".encode("ascii"))
        digest.update(image.tobytes())
        return digest.hexdigest()

    def _get_model(self, model_type: str) -> Tuple[Any, Any, str]:
        """Lazily initialize and return ``(processor, model, model_name)`` for ``model_type``."""
        if model_type == "advanced":
            self.initialize_advanced_model()
            return self.advanced_processor, self.advanced_model, self.advanced_model_name
        self.initialize_lightweight_model()
        return self.lightweight_processor, self.lightweight_model, self.lightweight_model_name

    def generate_image_caption(self, image: Union[str, Image.Image],
                               agent_name: str = "image_processing") -> Dict[str, Any]:
        """Caption an image with the model chosen by its complexity.

        Args:
            image: PIL image or path to an image file (loaded as RGB).
            agent_name: Agent name reported to the token manager and metrics.

        Returns:
            ``{"caption", "model_used", "complexity", "confidence"}`` on success
            (possibly served from the cache). If the token manager rejects the
            request, returns ``{"caption": "Token budget exceeded", "error": reason}``.

        Raises:
            ValueError: ``image`` is a path that does not exist.
            TypeError: ``image`` is neither a path nor a PIL image.
        """
        image = self._load_image(image)

        # Check cache if available
        image_hash = self._image_cache_key(image)
        if self.cache_manager:
            cache_hit, cached_result = self.cache_manager.get(
                image_hash, namespace="image_captions")
            if cache_hit:
                # Update metrics if available
                if self.metrics_calculator:
                    self.metrics_calculator.update_cache_metrics(1, 0, 0.01)  # Estimated energy saving
                return cached_result

        # Compute complexity once; it drives model selection and is returned.
        complexity = self.determine_image_complexity(image)
        model_type = self.select_captioning_model(image, complexity)
        processor, model, model_name = self._get_model(model_type)

        # Request token budget if available
        if self.token_manager:
            approved, reason = self.token_manager.request_tokens(
                agent_name, "image_captioning", "", model_name)
            if not approved:
                self.logger.warning(f"Token budget exceeded: {reason}")
                return {"caption": "Token budget exceeded", "error": reason}

        inputs = processor(image, return_tensors="pt")

        # Generate caption
        with torch.no_grad():
            if model_type == "advanced":
                # Match the float32 weights the advanced model was loaded with.
                pixel_values = inputs.pixel_values.to(torch.float32)
                generated_ids = model.generate(
                    pixel_values=pixel_values,
                    max_new_tokens=50,
                    num_beams=5,
                )
            else:
                generated_ids = model.generate(
                    **inputs,
                    max_new_tokens=50,
                    num_beams=5,
                )
        caption = processor.decode(generated_ids[0], skip_special_tokens=True)

        result = {
            "caption": caption,
            "model_used": model_type,
            "complexity": complexity,
            "confidence": 0.9 if model_type == "advanced" else 0.7,  # Estimated confidence
        }

        # Token and energy accounting both go through the token manager,
        # so energy is only logged when a token manager is present.
        if self.token_manager:
            # Approximate token count: whitespace-split caption words + 20 base tokens
            token_count = len(caption.split()) + 20
            self.token_manager.log_usage(
                agent_name, "image_captioning", token_count, model_name)

            if self.metrics_calculator:
                energy_usage = self.token_manager.calculate_energy_usage(
                    token_count, model_name)
                self.metrics_calculator.log_energy_usage(
                    energy_usage, model_name, agent_name, "image_captioning")

        # Store in cache if available
        if self.cache_manager:
            self.cache_manager.put(image_hash, result, namespace="image_captions")

        return result

    def match_images_to_topic(self, topic: str, image_captions: List[Dict[str, Any]],
                              text_model_manager=None) -> List[float]:
        """Score how relevant each captioned image is to ``topic``.

        Args:
            topic: The user's topic text.
            image_captions: Results from ``generate_image_caption``; each must
                have a ``"caption"`` key.
            text_model_manager: Object exposing ``compute_similarity(topic, captions)``.

        Returns:
            One relevance score per caption. Without a text model manager every
            score is 0.5. An empty ``image_captions`` returns ``[]``.
        """
        if not image_captions:
            return []

        if not text_model_manager:
            self.logger.warning("No text model manager provided for semantic matching")
            return [0.5] * len(image_captions)  # Default mid-range relevance

        captions = [item["caption"] for item in image_captions]
        return text_model_manager.compute_similarity(topic, captions)