# agents/image_processing_agent.py
import gc
import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional, Tuple, Union, Any

import torch
from PIL import Image, UnidentifiedImageError


class ImageProcessingAgent:
    """Caption images, score them against a topic, and build a report.

    Each image is tracked as a plain dict ("image data") that the methods
    of this class update in place. Keys written by this class are:
    ``filename``, ``filepath``, ``image``, ``success``, ``error``,
    ``caption``, ``model_used``, ``complexity``, ``confidence``,
    ``relevance_score`` and ``is_relevant``.
    """

    def __init__(self, image_model_manager, text_model_manager=None,
                 token_manager=None, cache_manager=None,
                 metrics_calculator=None):
        """Initialize the ImageProcessingAgent with required model managers and utilities.

        Args:
            image_model_manager: Object providing ``generate_image_caption``
                and ``match_images_to_topic`` (both are called by this class).
            text_model_manager: Optional manager passed through to
                ``image_model_manager.match_images_to_topic``. When falsy,
                topic matching falls back to a neutral default score.
            token_manager: Stored on the instance; not used by this class.
            cache_manager: Stored on the instance; not used by this class.
            metrics_calculator: Stored on the instance; not used by this class.
        """
        self.logger = logging.getLogger(__name__)
        self.image_model_manager = image_model_manager
        self.text_model_manager = text_model_manager
        self.token_manager = token_manager
        self.cache_manager = cache_manager
        self.metrics_calculator = metrics_calculator

        # Default relevance threshold
        self.relevance_threshold = 0.4

        # Default confidence values
        self.confidence_high_threshold = 0.7
        self.confidence_low_threshold = 0.3

        # Agent name for logging
        self.agent_name = "image_processing_agent"

    def load_image(self, image_path: str) -> Tuple[Optional[Image.Image], bool, str]:
        """Load an image from a file path and convert it to RGB.

        Args:
            image_path: Path of the image file to open.

        Returns:
            A tuple ``(image, success, error_message)``. On failure the
            image is ``None``, ``success`` is ``False`` and the message
            describes the problem; on success the message is empty.
        """
        try:
            # Image.open() is lazy and can keep the underlying file open.
            # convert() produces an independent in-memory image, so the
            # source can be closed as soon as the conversion is done.
            with Image.open(image_path) as source:
                image = source.convert('RGB')
        except UnidentifiedImageError:
            error_msg = f"Unidentified image format: {image_path}"
            self.logger.error(error_msg)
            return None, False, error_msg
        except Exception as e:
            error_msg = f"Failed to load image: {str(e)}"
            self.logger.error(error_msg)
            return None, False, error_msg
        return image, True, ""

    def process_single_image(self, image_data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a caption and metadata for a single image.

        Args:
            image_data: Image dict; must hold the loaded image under
                ``"image"`` and a truthy ``"success"`` flag to be processed.

        Returns:
            The same dict, updated in place. Entries without an image or
            without a truthy ``"success"`` flag are returned untouched.
            If captioning raises, ``"error"`` is set, ``"caption"`` is set
            to a failure placeholder and ``"confidence"`` to 0.0.
        """
        image = image_data.get("image")
        if image is None or not image_data.get("success", False):
            return image_data

        # Generate caption using the image model manager
        try:
            result = self.image_model_manager.generate_image_caption(
                image, agent_name=self.agent_name)

            # Update image data with caption results
            image_data["caption"] = result.get("caption", "Failed to generate caption")
            image_data["model_used"] = result.get("model_used", "unknown")
            image_data["complexity"] = result.get("complexity", {})
            image_data["confidence"] = result.get("confidence", 0.5)
        except Exception as e:
            error_msg = f"Error generating caption: {str(e)}"
            self.logger.error(error_msg)
            image_data["error"] = error_msg
            image_data["caption"] = "Caption generation failed"
            image_data["confidence"] = 0.0
        finally:
            # Release memory after every captioning attempt, successful or not.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        return image_data

    @staticmethod
    def _mark_not_relevant(image_data_list: List[Dict[str, Any]]) -> None:
        """Set a zero relevance score and ``is_relevant=False`` on every entry."""
        for img_data in image_data_list:
            img_data["relevance_score"] = 0.0
            img_data["is_relevant"] = False

    def match_images_to_topic(self, topic: str,
                              image_data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Determine the relevance of each image to the given topic.

        Every entry leaves this method with ``"relevance_score"`` and
        ``"is_relevant"`` set:

        * Without a text model manager, all entries get 0.5 / ``True``.
        * Otherwise entries with a caption and no error are scored through
          ``image_model_manager.match_images_to_topic``; an entry is
          relevant when its score reaches ``self.relevance_threshold``.
        * All remaining entries get 0.0 / ``False``.

        Args:
            topic: Topic text the captions are compared against.
            image_data_list: Image dicts to score.

        Returns:
            ``image_data_list`` itself, updated in place. When scores were
            computed the list is also sorted in place by descending
            relevance score.
        """
        if not self.text_model_manager:
            self.logger.warning("No text model manager available for topic matching")
            # Set default relevance
            for img_data in image_data_list:
                img_data["relevance_score"] = 0.5
                img_data["is_relevant"] = True
            return image_data_list

        # Only images that were captioned without error can be scored.
        valid_images = [
            img_data for img_data in image_data_list
            if img_data.get("caption") and not img_data.get("error")
        ]

        if not valid_images:
            self.logger.warning("No valid images with captions for topic matching")
            # Keep the output schema consistent even when nothing can be scored.
            self._mark_not_relevant(image_data_list)
            return image_data_list

        # Compute similarity between topic and image captions
        similarities = self.image_model_manager.match_images_to_topic(
            topic, valid_images, self.text_model_manager)

        if len(similarities) != len(valid_images):
            self.logger.warning(
                "Expected %d similarity scores but received %d; "
                "unscored images are marked as not relevant",
                len(valid_images), len(similarities))

        # Start from "not relevant" so that failed images, and any valid
        # image the model returned no score for, still get both keys.
        self._mark_not_relevant(image_data_list)

        # valid_images holds the same dict objects as image_data_list, so
        # writing through them updates the original entries.
        for img_data, similarity in zip(valid_images, similarities):
            score = float(similarity)
            img_data["relevance_score"] = score
            # Compare the plain float so the flag is a built-in bool rather
            # than whatever boolean type the similarity object would yield.
            img_data["is_relevant"] = score >= self.relevance_threshold

        # Sort images by relevance
        image_data_list.sort(key=lambda x: x.get("relevance_score", 0), reverse=True)

        return image_data_list

    def process_images_batch(self, topic: str, image_data_list: List[Dict[str, Any]],
                             batch_size: int = 1, max_workers: int = 2) -> List[Dict[str, Any]]:
        """Caption images batch by batch, then score them against the topic.

        Args:
            topic: Topic passed on to ``match_images_to_topic``.
            image_data_list: Image dicts to process; updated in place.
            batch_size: Number of images submitted to the thread pool at a
                time. Values below 1 are treated as 1.
            max_workers: Thread pool size. Values below 1 are treated as 1.

        Returns:
            The processed image data list as returned by
            ``match_images_to_topic``.
        """
        # A non-positive batch size would either crash range() or skip all
        # images; a non-positive worker count would crash the executor.
        batch_size = max(1, batch_size)
        max_workers = max(1, max_workers)

        if image_data_list:
            # One pool is reused for all batches; each batch is still fully
            # drained before the next one is submitted, to limit memory use.
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                for start in range(0, len(image_data_list), batch_size):
                    batch = image_data_list[start:start + batch_size]
                    futures = {
                        executor.submit(self.process_single_image, img_data): img_data
                        for img_data in batch
                    }
                    for future in as_completed(futures):
                        try:
                            future.result()  # Image data is updated in-place
                        except Exception as e:
                            img_data = futures[future]
                            img_data["error"] = f"Processing error: {str(e)}"
                            img_data["confidence"] = 0.0
                            self.logger.error(f"Error processing image: {e}")

        # Match images to topic after all captions are generated
        return self.match_images_to_topic(topic, image_data_list)

    def generate_image_analysis_report(self, topic: str,
                                       image_data_list: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Generate a report of the image analysis findings.

        Args:
            topic: Topic the images were analysed for.
            image_data_list: Image dicts after captioning and topic matching.

        Returns:
            A dict with the keys ``topic``, ``total_images``,
            ``successfully_processed``, ``relevant_images``,
            ``overall_confidence``, ``confidence_level`` and
            ``image_analyses`` (one summary dict per relevant image).
        """
        # Filter for successfully processed images
        processed_images = [
            img for img in image_data_list
            if "caption" in img and not img.get("error")
        ]
        relevant_images = [img for img in processed_images if img.get("is_relevant", False)]

        # Averages are taken over relevant images only; both are 0.0 when
        # there are none.
        if relevant_images:
            count = len(relevant_images)
            avg_confidence = sum(img.get("confidence", 0) for img in relevant_images) / count
            avg_relevance = sum(img.get("relevance_score", 0) for img in relevant_images) / count
        else:
            avg_confidence = 0.0
            avg_relevance = 0.0

        # Combine for overall score
        overall_confidence = 0.7 * avg_confidence + 0.3 * avg_relevance

        # Prepare report content
        report = {
            "topic": topic,
            "total_images": len(image_data_list),
            "successfully_processed": len(processed_images),
            "relevant_images": len(relevant_images),
            "overall_confidence": overall_confidence,
            "confidence_level": self._get_confidence_level(overall_confidence),
            "image_analyses": [],
        }

        # Add individual image analyses
        for img in relevant_images:
            report["image_analyses"].append({
                "filename": img.get("filename", "Unknown"),
                "caption": img.get("caption", "No caption available"),
                "model_used": img.get("model_used", "unknown"),
                "relevance_score": img.get("relevance_score", 0),
                "confidence": img.get("confidence", 0),
            })

        return report

    def _get_confidence_level(self, confidence_score: float) -> str:
        """Convert numerical confidence to descriptive level.

        Returns ``"high"`` at or above ``self.confidence_high_threshold``,
        ``"medium"`` at or above ``self.confidence_low_threshold``, and
        ``"low"`` otherwise.
        """
        if confidence_score >= self.confidence_high_threshold:
            return "high"
        elif confidence_score >= self.confidence_low_threshold:
            return "medium"
        else:
            return "low"

    def process_image_files(self, topic: str, file_paths: List[str]) -> Dict[str, Any]:
        """Process a list of image files for a given topic.

        Loads every file, captions and scores the images, and builds the
        analysis report.

        Args:
            topic: Topic the images are analysed for.
            file_paths: Paths of the image files to process.

        Returns:
            The report from ``generate_image_analysis_report`` extended
            with ``processing_time`` (seconds) and ``processed_images``
            (the per-image dicts with the ``"image"`` object removed).
        """
        start_time = time.time()
        self.logger.info(f"Processing {len(file_paths)} image files for topic: {topic}")

        # Load all images
        image_data_list = []
        for file_path in file_paths:
            image, success, error_msg = self.load_image(file_path)
            img_data = {
                "filename": os.path.basename(file_path),
                "filepath": file_path,
                "image": image,
                "success": success,
            }
            if not success:
                img_data["error"] = error_msg
            image_data_list.append(img_data)

        # Process images
        processed_images = self.process_images_batch(topic, image_data_list)

        # Generate report
        report = self.generate_image_analysis_report(topic, processed_images)

        # Add processing metadata
        processing_time = time.time() - start_time
        report["processing_time"] = processing_time

        # Clean up PIL Image objects before returning (to avoid serialization issues)
        for img_data in processed_images:
            img_data.pop("image", None)

        report["processed_images"] = processed_images

        self.logger.info(f"Completed image analysis in {processing_time:.2f} seconds. " +
                         f"Found {report['relevant_images']} relevant images.")

        return report