File size: 11,408 Bytes
a40ee8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
# models/summary_models.py
import hashlib
import logging
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

class SummaryModelManager:
    """Manages a seq2seq summarization model (T5, with a BART fallback).

    Provides plain summaries, confidence-adjusted executive summaries, and
    combined text+image analysis reports.  Optional collaborators handle
    token budgeting, result caching, and energy/efficiency metrics; all are
    duck-typed and may be None.
    """

    def __init__(self, token_manager=None, cache_manager=None, metrics_calculator=None):
        """Initialize the SummaryModelManager with optional utilities.

        Args:
            token_manager: Optional budget tracker exposing register_model,
                request_tokens, log_usage and calculate_energy_usage.
            cache_manager: Optional cache exposing get/put with namespaces.
            metrics_calculator: Optional metrics sink for cache/energy stats.
        """
        self.logger = logging.getLogger(__name__)
        self.token_manager = token_manager
        self.cache_manager = cache_manager
        self.metrics_calculator = metrics_calculator

        # Model/tokenizer are loaded lazily by initialize_model()
        self.model = None
        self.tokenizer = None

        # Primary model name; may be replaced by the fallback during init
        self.model_name = "t5-small"

        # Guards against double initialization
        self.initialized = False

        # Default generation parameters; callers may override via `params`
        self.default_params = {
            "max_length": 150,
            "min_length": 40,
            "length_penalty": 2.0,
            "num_beams": 4,
            "early_stopping": True,
        }

    def initialize_model(self):
        """Load the summarization model, falling back to BART on failure.

        Idempotent: returns immediately if already initialized.

        Raises:
            Exception: the fallback loader's error, if both loads fail.
        """
        if self.initialized:
            return

        try:
            # Register with token manager if available
            if self.token_manager:
                self.token_manager.register_model(
                    self.model_name, "summarization")

            # Load primary model and tokenizer
            self.logger.info(f"Loading summary model: {self.model_name}")
            self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
            self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)

            self.initialized = True
            self.logger.info("Summary model initialized successfully")

        except Exception as e:
            # T5's tokenizer requires SentencePiece; BART does not, so try it
            # as a fallback.  Log the primary failure instead of discarding it.
            self.logger.warning(
                f"Failed to initialize summary model '{self.model_name}': {e}; "
                "trying fallback")
            try:
                fallback_model = "facebook/bart-base"
                self.logger.info(f"Trying fallback model: {fallback_model}")

                from transformers import BartTokenizer, BartForConditionalGeneration

                self.tokenizer = BartTokenizer.from_pretrained(fallback_model)
                self.model = BartForConditionalGeneration.from_pretrained(fallback_model)
                self.model_name = fallback_model

                # Register fallback with token manager
                if self.token_manager:
                    self.token_manager.register_model(
                        self.model_name, "summarization")

                self.initialized = True
                self.logger.info("Fallback summary model initialized successfully")
            except Exception as fallback_error:
                self.logger.error(f"Failed to initialize fallback model: {fallback_error}")
                raise

    def generate_summary(self, text: str, prefix: str = "summarize: ",
                        agent_name: str = "report_generation",
                        params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Generate a summary of the given text.

        Args:
            text: Source text to summarize.
            prefix: Task prefix prepended to the input (T5-style prompting).
            agent_name: Caller identity for token accounting.
            params: Optional overrides merged over self.default_params.

        Returns:
            Dict with "summary", "input_length", "summary_length" and
            "compression_ratio"; or {"summary": ..., "error": ...} when the
            token budget is exceeded.
        """
        # Lazy model load on first use
        if not self.initialized:
            self.initialize_model()

        input_text = f"{prefix}{text}"

        # Check cache if available
        cache_key = None
        if self.cache_manager:
            # Stable content digest: builtin hash() is salted per process
            # (PYTHONHASHSEED), which would defeat cross-run caching.
            digest = hashlib.sha256(input_text.encode("utf-8")).hexdigest()
            cache_key = input_text[:100] + digest
            cache_hit, cached_result = self.cache_manager.get(
                cache_key, namespace="summaries")

            if cache_hit:
                # Update metrics if available
                if self.metrics_calculator:
                    self.metrics_calculator.update_cache_metrics(1, 0, 0.005)  # Estimated energy saving
                return cached_result

        # Request token budget if available
        if self.token_manager:
            approved, reason = self.token_manager.request_tokens(
                agent_name, "summarization", input_text, self.model_name)

            if not approved:
                self.logger.warning(f"Token budget exceeded: {reason}")
                return {"summary": "Token budget exceeded", "error": reason}

        # Tokenize (truncate to the model's encoder limit)
        inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)

        # Merge default and custom parameters
        generation_params = self.default_params.copy()
        if params:
            generation_params.update(params)

        # Generate summary; pass attention_mask so padded positions are ignored
        with torch.no_grad():
            output_ids = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                **generation_params
            )

        summary = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

        # Word-level compression ratio (guard against empty summaries)
        input_length = len(text.split())
        summary_length = len(summary.split())
        compression_ratio = input_length / max(summary_length, 1)

        result = {
            "summary": summary,
            "input_length": input_length,
            "summary_length": summary_length,
            "compression_ratio": compression_ratio
        }

        # Log token usage if available
        if self.token_manager:
            input_tokens = len(inputs.input_ids[0])
            output_tokens = len(output_ids[0])
            total_tokens = input_tokens + output_tokens

            self.token_manager.log_usage(
                agent_name, "summarization", total_tokens, self.model_name)

            # Log energy usage if metrics calculator is available
            if self.metrics_calculator:
                energy_usage = self.token_manager.calculate_energy_usage(
                    total_tokens, self.model_name)
                self.metrics_calculator.log_energy_usage(
                    energy_usage, self.model_name, agent_name, "summarization")

        # Store in cache if available
        if self.cache_manager:
            self.cache_manager.put(cache_key, result, namespace="summaries")

        return result

    def generate_executive_summary(self, detailed_content: str, confidence_level: float,
                                  agent_name: str = "report_generation") -> Dict[str, Any]:
        """Generate an executive summary with confidence indication.

        Higher confidence yields longer, more assertive summaries; lower
        confidence shortens the output and softens the prompt.

        Args:
            detailed_content: Text to condense.
            confidence_level: Confidence score in [0, 1].
            agent_name: Caller identity for token accounting.

        Returns:
            generate_summary() result augmented with "confidence_level"
            and "confidence_statement".
        """
        # Choose prompt and length bounds by confidence band
        if confidence_level >= 0.7:
            prefix = "summarize with high confidence: "
            params = {"min_length": 30, "max_length": 100}
        elif confidence_level >= 0.4:
            prefix = "summarize with moderate confidence: "
            params = {"min_length": 20, "max_length": 80}
        else:
            prefix = "summarize with low confidence: "
            params = {"min_length": 15, "max_length": 60}

        result = self.generate_summary(detailed_content, prefix=prefix,
                                      agent_name=agent_name, params=params)

        result["confidence_level"] = confidence_level
        result["confidence_statement"] = self._generate_confidence_statement(confidence_level)

        return result

    def _generate_confidence_statement(self, confidence_level: float) -> str:
        """Return a caveat sentence matching the given confidence level.

        Bands (inclusive lower bounds): 0.8 high, 0.6 good, 0.4 moderate,
        0.2 limited, else very low.
        """
        if confidence_level >= 0.8:
            return "This analysis is provided with high confidence based on strong evidence in the provided materials."
        elif confidence_level >= 0.6:
            return "This analysis is provided with good confidence based on substantial evidence in the provided materials."
        elif confidence_level >= 0.4:
            return "This analysis is provided with moderate confidence. Some aspects may require additional verification."
        elif confidence_level >= 0.2:
            return "This analysis is provided with limited confidence due to sparse relevant information in the provided materials."
        else:
            return "This analysis is provided with very low confidence due to insufficient relevant information in the provided materials."

    def combine_analyses(self, text_analyses: List[Dict[str, Any]],
                        image_analyses: List[Dict[str, Any]],
                        topic: str, agent_name: str = "report_generation") -> Dict[str, Any]:
        """Combine text and image analyses into a coherent report.

        Args:
            text_analyses: Dicts with optional "summary"/"confidence" keys;
                entries containing "error" are skipped.
            image_analyses: Dicts with optional "caption"/"confidence" keys;
                entries containing "error" are skipped.
            topic: Report topic, embedded in the generation prompt.
            agent_name: Caller identity for token accounting.

        Returns:
            Dict with the executive summary, detailed report, confidence
            breakdown, and source counts.
        """
        # Build combined content
        combined_content = f"Topic: {topic}\n\n"

        combined_content += "Text Analysis:\n"
        for i, analysis in enumerate(text_analyses):
            if "error" in analysis:
                continue
            combined_content += f"- Document {i+1}: {analysis.get('summary', 'No summary available')}\n"

        combined_content += "\nImage Analysis:\n"
        for i, analysis in enumerate(image_analyses):
            if "error" in analysis:
                continue
            combined_content += f"- Image {i+1}: {analysis.get('caption', 'No caption available')}\n"

        # Mean confidence per modality (max(...,1) guards empty lists;
        # errored entries still count toward the denominator, as before)
        text_confidence = sum(a.get("confidence", 0) for a in text_analyses) / max(len(text_analyses), 1)
        image_confidence = sum(a.get("confidence", 0) for a in image_analyses) / max(len(image_analyses), 1)

        # Weight confidence (text analyses typically more important for deep dives)
        overall_confidence = 0.7 * text_confidence + 0.3 * image_confidence

        # Generate detailed report
        detailed_report = self.generate_summary(
            combined_content,
            prefix=f"generate detailed report about {topic}: ",
            agent_name=agent_name,
            params={"max_length": 300, "min_length": 100}
        )

        # Generate executive summary of the detailed report
        executive_summary = self.generate_executive_summary(
            detailed_report["summary"],
            overall_confidence,
            agent_name
        )

        return {
            "topic": topic,
            "executive_summary": executive_summary["summary"],
            "confidence_statement": executive_summary["confidence_statement"],
            "detailed_report": detailed_report["summary"],
            "confidence_level": overall_confidence,
            "text_confidence": text_confidence,
            "image_confidence": image_confidence,
            "source_count": {
                "text": len(text_analyses),
                "images": len(image_analyses)
            }
        }