# models/image_models.py
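"""Image captioning model manager.

Chooses between a lightweight BLIP captioning model and a larger BLIP-2 model
based on simple image-complexity heuristics (entropy, edge density, pixel count),
and integrates with optional token-budget, cache, and sustainability-metrics
utilities when they are supplied.
"""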
import hashlib
import logging
import os
import torch
from typing import Dict, List, Optional, Tuple, Union, Any
from PIL import Image
import numpy as np
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import Blip2Processor, Blip2ForConditionalGeneration

class ImageModelManager:
    def __init__(self, token_manager=None, cache_manager=None, metrics_calculator=None):
        """Initialize the ImageModelManager with optional utilities."""
        self.logger = logging.getLogger(__name__)
        self.token_manager = token_manager
        self.cache_manager = cache_manager
        self.metrics_calculator = metrics_calculator
        
        # Model instances
        self.lightweight_model = None
        self.lightweight_processor = None
        self.advanced_model = None
        self.advanced_processor = None
        
        # Model names
        self.lightweight_model_name = "Salesforce/blip-image-captioning-base"
        self.advanced_model_name = "Salesforce/blip2-opt-2.7b"
        
        # Track initialization state
        self.initialized = {
            "lightweight": False,
            "advanced": False
        }
        
        # Default complexity thresholds
        self.complexity_thresholds = {
            "entropy": 4.5,       # Higher entropy suggests more complex image
            "edge_density": 0.15,  # Higher edge density suggests more details
            "size": 500000        # Larger images may contain more information
        }
        
    def initialize_lightweight_model(self):
        """Initialize the lightweight image captioning model."""
        if self.initialized["lightweight"]:
            return
            
        try:
            # Register with token manager if available
            if self.token_manager:
                self.token_manager.register_model(
                    self.lightweight_model_name, "image_captioning")
                
            # Load model and processor
            self.logger.info(f"Loading lightweight image model: {self.lightweight_model_name}")
            self.lightweight_processor = BlipProcessor.from_pretrained(self.lightweight_model_name)
            self.lightweight_model = BlipForConditionalGeneration.from_pretrained(
                self.lightweight_model_name, torch_dtype=torch.float32)
                
            self.initialized["lightweight"] = True
            self.logger.info("Lightweight image model initialized successfully")
            
        except Exception as e:
            self.logger.error(f"Failed to initialize lightweight image model: {e}")
            raise
            
    def initialize_advanced_model(self):
        """Initialize the advanced image captioning model."""
        if self.initialized["advanced"]:
            return
            
        try:
            # Register with token manager if available
            if self.token_manager:
                self.token_manager.register_model(
                    self.advanced_model_name, "image_captioning")
                
            # Load model and processor
            self.logger.info(f"Loading advanced image model: {self.advanced_model_name}")
            self.advanced_processor = Blip2Processor.from_pretrained(self.advanced_model_name)
            self.advanced_model = Blip2ForConditionalGeneration.from_pretrained(
                self.advanced_model_name, torch_dtype=torch.float32)
                
            self.initialized["advanced"] = True
            self.logger.info("Advanced image model initialized successfully")
            
        except Exception as e:
            self.logger.error(f"Failed to initialize advanced image model: {e}")
            raise
    
    def determine_image_complexity(self, image: Image.Image) -> Dict[str, float]:
        """
        Determine the complexity of an image to guide model selection.
        Returns complexity metrics.
        """
        # Convert to a grayscale float array; uint8 differences would wrap
        # around and corrupt the gradient-based edge-density estimate below
        img_array = np.array(image.convert("L"), dtype=np.float32)
        
        # Calculate image entropy (measure of randomness/information)
        histogram = np.histogram(img_array, bins=256, range=(0, 256))[0]
        histogram = histogram / histogram.sum()
        non_zero = histogram > 0
        entropy = -np.sum(histogram[non_zero] * np.log2(histogram[non_zero]))
        
        # Calculate edge density using simple gradient method
        gradient_x = np.abs(np.diff(img_array, axis=1, prepend=0))
        gradient_y = np.abs(np.diff(img_array, axis=0, prepend=0))
        gradient_magnitude = np.sqrt(gradient_x**2 + gradient_y**2)
        edge_density = np.mean(gradient_magnitude > 30) # Threshold for edge detection
        
        # Get image size in pixels
        size = image.width * image.height
        
        return {
            "entropy": float(entropy),
            "edge_density": float(edge_density),
            "size": size
        }
    
    def select_captioning_model(self, image: Image.Image) -> str:
        """
        Select the appropriate captioning model based on image complexity.
        Returns model type ("lightweight" or "advanced").
        """
        # Get complexity metrics
        complexity = self.determine_image_complexity(image)
        
        # Decision logic for model selection
        use_advanced = (
            complexity["entropy"] > self.complexity_thresholds["entropy"] or
            complexity["edge_density"] > self.complexity_thresholds["edge_density"] or
            complexity["size"] > self.complexity_thresholds["size"]
        )
        
        # Log selection decision
        model_type = "advanced" if use_advanced else "lightweight"
        self.logger.info(f"Selected {model_type} model for image captioning (complexity: {complexity})")
        
        # If the lightweight model was selected, record the downgrade and the
        # approximate energy saved compared with running the advanced model
        if not use_advanced and self.metrics_calculator:
            energy_saved = 0.01  # Approximate difference in watt-hours
            self.metrics_calculator.log_model_downgrade(
                self.advanced_model_name, self.lightweight_model_name, energy_saved)
        
        return model_type
    
    def generate_image_caption(self, image: Union[str, Image.Image], 
                               agent_name: str = "image_processing") -> Dict[str, Any]:
        """
        Generate caption for an image, selecting appropriate model based on complexity.
        Returns caption and metadata.
        """
        # Handle string input (file path)
        if isinstance(image, str):
            if os.path.exists(image):
                image = Image.open(image).convert('RGB')
            else:
                raise ValueError(f"Image file not found: {image}")
                
        # Ensure image is PIL Image
        if not isinstance(image, Image.Image):
            raise TypeError("Image must be a PIL Image or a valid file path")
            
        # Check cache if available (hash the raw pixel bytes for a stable key;
        # the built-in hash() is salted per process and would defeat reuse)
        image_hash = hashlib.md5(image.tobytes()).hexdigest()
        if self.cache_manager:
            cache_hit, cached_result = self.cache_manager.get(
                image_hash, namespace="image_captions")
                
            if cache_hit:
                # Update metrics if available
                if self.metrics_calculator:
                    self.metrics_calculator.update_cache_metrics(1, 0, 0.01)  # Estimated energy saving
                return cached_result
        
        # Select model based on image complexity
        model_type = self.select_captioning_model(image)
        
        # Initialize selected model if needed
        if model_type == "advanced":
            if not self.initialized["advanced"]:
                self.initialize_advanced_model()
                
            processor = self.advanced_processor
            model = self.advanced_model
            model_name = self.advanced_model_name
        else:
            if not self.initialized["lightweight"]:
                self.initialize_lightweight_model()
                
            processor = self.lightweight_processor
            model = self.lightweight_model
            model_name = self.lightweight_model_name
            
        # Process image
        inputs = processor(image, return_tensors="pt")
        
        # Request a token budget before running generation, if a manager is set
        if self.token_manager:
            approved, reason = self.token_manager.request_tokens(
                agent_name, "image_captioning", "", model_name)

            if not approved:
                self.logger.warning(f"Token budget exceeded: {reason}")
                return {"caption": "Token budget exceeded", "error": reason}

        # Generate caption
        with torch.no_grad():
            if model_type == "advanced":
                # Cast pixel values to the dtype the model was loaded with
                pixel_values = inputs.pixel_values.to(torch.float32)
                generated_ids = model.generate(
                    pixel_values=pixel_values,
                    max_new_tokens=50,  # Bounds only the generated tokens, not the prompt
                    num_beams=5
                )
                caption = processor.decode(generated_ids[0], skip_special_tokens=True)
            else:
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=50,  # Bounds only the generated tokens, not the prompt
                    num_beams=5
                )
                caption = processor.decode(outputs[0], skip_special_tokens=True)
        
        # Prepare result
        result = {
            "caption": caption,
            "model_used": model_type,
            "complexity": self.determine_image_complexity(image),
            "confidence": 0.9 if model_type == "advanced" else 0.7  # Estimated confidence
        }
        
        # Log token usage if available
        if self.token_manager:
            # Approximate token count based on output length
            token_count = len(caption.split()) + 20  # Base tokens + output
            self.token_manager.log_usage(
                agent_name, "image_captioning", token_count, model_name)
                
            # Log energy usage if metrics calculator is available
            if self.metrics_calculator:
                energy_usage = self.token_manager.calculate_energy_usage(
                    token_count, model_name)
                self.metrics_calculator.log_energy_usage(
                    energy_usage, model_name, agent_name, "image_captioning")
        
        # Store in cache if available
        if self.cache_manager:
            self.cache_manager.put(image_hash, result, namespace="image_captions")
            
        return result
    
    def match_images_to_topic(self, topic: str, image_captions: List[Dict[str, Any]], 
                             text_model_manager=None) -> List[float]:
        """
        Match image captions to the user's topic using semantic similarity.
        Returns relevance scores for each image.
        """
        if not text_model_manager:
            self.logger.warning("No text model manager provided for semantic matching")
            return [0.5] * len(image_captions)  # Default mid-range relevance
            
        # Extract captions
        captions = [item["caption"] for item in image_captions]
        
        # Use text model to compute similarity
        similarities = text_model_manager.compute_similarity(topic, captions)
        
        return similarities
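

if __name__ == "__main__":
    # Minimal usage sketch (an example, not part of the library): caption one
    # local image without the optional token/cache/metrics utilities.
    # "sample.jpg" is a hypothetical placeholder path.
    logging.basicConfig(level=logging.INFO)

    manager = ImageModelManager()
    sample_path = "sample.jpg"
    if os.path.exists(sample_path):
        result = manager.generate_image_caption(sample_path, agent_name="demo")
        print(f"Caption ({result['model_used']}): {result['caption']}")
    else:
        print(f"Place an image at '{sample_path}' to run this example.")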