# models/summary_models.py

import logging
from typing import Any, Dict, List, Optional

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer


class SummaryModelManager:
    def __init__(self, token_manager=None, cache_manager=None, metrics_calculator=None):
        """Initialize the SummaryModelManager with optional utilities."""
        self.logger = logging.getLogger(__name__)
        self.token_manager = token_manager
        self.cache_manager = cache_manager
        self.metrics_calculator = metrics_calculator

        # Model instance
        self.model = None
        self.tokenizer = None

        # Model name
        self.model_name = "t5-small"

        # Track initialization state
        self.initialized = False

        # Default generation parameters
        self.default_params = {
            "max_length": 150,
            "min_length": 40,
            "length_penalty": 2.0,
            "num_beams": 4,
            "early_stopping": True,
        }

    def initialize_model(self):
        """Initialize the summarization model, falling back to BART if T5 fails."""
        if self.initialized:
            return

        try:
            # Register with token manager if available
            if self.token_manager:
                self.token_manager.register_model(self.model_name, "summarization")

            # Load model and tokenizer
            self.logger.info(f"Loading summary model: {self.model_name}")
            self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
            self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)

            self.initialized = True
            self.logger.info("Summary model initialized successfully")
        except Exception as e:
            self.logger.warning(f"Failed to initialize summary model: {e}")

            # Try a fallback model that doesn't require SentencePiece
            try:
                fallback_model = "facebook/bart-base"
                self.logger.info(f"Trying fallback model: {fallback_model}")

                from transformers import BartForConditionalGeneration, BartTokenizer

                self.tokenizer = BartTokenizer.from_pretrained(fallback_model)
                self.model = BartForConditionalGeneration.from_pretrained(fallback_model)
                self.model_name = fallback_model

                # Register fallback with token manager
                if self.token_manager:
                    self.token_manager.register_model(self.model_name, "summarization")

                self.initialized = True
                self.logger.info("Fallback summary model initialized successfully")
            except Exception as fallback_error:
                self.logger.error(f"Failed to initialize fallback model: {fallback_error}")
                raise

    def generate_summary(self, text: str, prefix: str = "summarize: ",
                         agent_name: str = "report_generation",
                         params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Generate a summary of the given text.

        Returns the summary and metadata.
""" # Initialize model if needed if not self.initialized: self.initialize_model() # Prepare input text input_text = f"{prefix}{text}" # Check cache if available if self.cache_manager: cache_key = input_text[:100] + str(hash(input_text)) # Use prefix of text + hash as key cache_hit, cached_result = self.cache_manager.get( cache_key, namespace="summaries") if cache_hit: # Update metrics if available if self.metrics_calculator: self.metrics_calculator.update_cache_metrics(1, 0, 0.005) # Estimated energy saving return cached_result # Request token budget if available if self.token_manager: approved, reason = self.token_manager.request_tokens( agent_name, "summarization", input_text, self.model_name) if not approved: self.logger.warning(f"Token budget exceeded: {reason}") return {"summary": "Token budget exceeded", "error": reason} # Tokenize inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True) # Merge default and custom parameters generation_params = self.default_params.copy() if params: generation_params.update(params) # Generate summary with torch.no_grad(): output_ids = self.model.generate( inputs.input_ids, **generation_params ) # Decode summary summary = self.tokenizer.decode(output_ids[0], skip_special_tokens=True) # Calculate compression ratio input_length = len(text.split()) summary_length = len(summary.split()) compression_ratio = input_length / max(summary_length, 1) # Prepare result result = { "summary": summary, "input_length": input_length, "summary_length": summary_length, "compression_ratio": compression_ratio } # Log token usage if available if self.token_manager: input_tokens = len(inputs.input_ids[0]) output_tokens = len(output_ids[0]) total_tokens = input_tokens + output_tokens self.token_manager.log_usage( agent_name, "summarization", total_tokens, self.model_name) # Log energy usage if metrics calculator is available if self.metrics_calculator: energy_usage = self.token_manager.calculate_energy_usage( total_tokens, self.model_name) self.metrics_calculator.log_energy_usage( energy_usage, self.model_name, agent_name, "summarization") # Store in cache if available if self.cache_manager: self.cache_manager.put(cache_key, result, namespace="summaries") return result def generate_executive_summary(self, detailed_content: str, confidence_level: float, agent_name: str = "report_generation") -> Dict[str, Any]: """ Generate an executive summary with confidence indication. Adjusts detail level based on confidence. """ # Prepare prompt based on confidence if confidence_level >= 0.7: prefix = "summarize with high confidence: " params = {"min_length": 30, "max_length": 100} elif confidence_level >= 0.4: prefix = "summarize with moderate confidence: " params = {"min_length": 20, "max_length": 80} else: prefix = "summarize with low confidence: " params = {"min_length": 15, "max_length": 60} # Generate summary result = self.generate_summary(detailed_content, prefix=prefix, agent_name=agent_name, params=params) # Add confidence level to result result["confidence_level"] = confidence_level # Add confidence statement confidence_statement = self._generate_confidence_statement(confidence_level) result["confidence_statement"] = confidence_statement return result def _generate_confidence_statement(self, confidence_level: float) -> str: """Generate an appropriate confidence statement based on the level.""" if confidence_level >= 0.8: return "This analysis is provided with high confidence based on strong evidence in the provided materials." 
        elif confidence_level >= 0.6:
            return ("This analysis is provided with good confidence based on "
                    "substantial evidence in the provided materials.")
        elif confidence_level >= 0.4:
            return ("This analysis is provided with moderate confidence. "
                    "Some aspects may require additional verification.")
        elif confidence_level >= 0.2:
            return ("This analysis is provided with limited confidence due to "
                    "sparse relevant information in the provided materials.")
        else:
            return ("This analysis is provided with very low confidence due to "
                    "insufficient relevant information in the provided materials.")

    def combine_analyses(self, text_analyses: List[Dict[str, Any]],
                         image_analyses: List[Dict[str, Any]], topic: str,
                         agent_name: str = "report_generation") -> Dict[str, Any]:
        """
        Combine text and image analyses into a coherent report.

        Returns the combined report with metadata.
        """
        # Build combined content
        combined_content = f"Topic: {topic}\n\n"

        # Add text analyses
        combined_content += "Text Analysis:\n"
        for i, analysis in enumerate(text_analyses):
            if "error" in analysis:
                continue
            combined_content += f"- Document {i+1}: {analysis.get('summary', 'No summary available')}\n"

        # Add image analyses
        combined_content += "\nImage Analysis:\n"
        for i, analysis in enumerate(image_analyses):
            if "error" in analysis:
                continue
            combined_content += f"- Image {i+1}: {analysis.get('caption', 'No caption available')}\n"

        # Calculate overall confidence based on analyses
        text_confidence = sum(a.get("confidence", 0) for a in text_analyses) / max(len(text_analyses), 1)
        image_confidence = sum(a.get("confidence", 0) for a in image_analyses) / max(len(image_analyses), 1)

        # Weight confidence (text analyses are typically more important for deep dives)
        overall_confidence = 0.7 * text_confidence + 0.3 * image_confidence

        # Generate detailed report
        detailed_report = self.generate_summary(
            combined_content,
            prefix=f"generate detailed report about {topic}: ",
            agent_name=agent_name,
            params={"max_length": 300, "min_length": 100},
        )

        # Generate executive summary
        executive_summary = self.generate_executive_summary(
            detailed_report["summary"], overall_confidence, agent_name)

        # Combine results
        result = {
            "topic": topic,
            "executive_summary": executive_summary["summary"],
            "confidence_statement": executive_summary["confidence_statement"],
            "detailed_report": detailed_report["summary"],
            "confidence_level": overall_confidence,
            "text_confidence": text_confidence,
            "image_confidence": image_confidence,
            "source_count": {
                "text": len(text_analyses),
                "images": len(image_analyses),
            },
        }

        return result
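

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module's public API): a minimal
# end-to-end run of the manager. It assumes `torch` and `transformers` are
# installed and that the optional token/cache/metrics utilities are omitted
# (all three default to None). The sample text, analyses, and topic below are
# hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager = SummaryModelManager()

    # Plain summarization: returns the summary plus length/compression metadata.
    article = (
        "Transformer-based language models have become the dominant approach "
        "to abstractive summarization. They are pretrained on large corpora "
        "and fine-tuned on summarization datasets such as CNN/DailyMail."
    )
    print(manager.generate_summary(article))

    # Confidence-aware executive summary: lower confidence yields a shorter
    # summary and a more cautious confidence statement.
    print(manager.generate_executive_summary(article, confidence_level=0.45))

    # Combining hypothetical per-document and per-image analyses into a report.
    text_analyses = [{"summary": "Survey of abstractive summarizers.", "confidence": 0.8}]
    image_analyses = [{"caption": "Diagram of a transformer encoder-decoder.", "confidence": 0.6}]
    report = manager.combine_analyses(text_analyses, image_analyses, topic="summarization models")
    print(report["executive_summary"])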