# Project: ai_agents_sustainable — models/summary_models.py (revision a40ee8d)
# models/summary_models.py
import hashlib
import logging
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
class SummaryModelManager:
    """Manage a seq2seq summarization model (t5-small with a BART fallback).

    Wraps lazy model loading, cached summary generation, confidence-aware
    executive summaries, and combination of text/image analyses into a report.

    Optional collaborators (duck-typed; each may be None):
        token_manager      -- token budget approval, usage logging, energy estimates
        cache_manager      -- get/put of previously generated summaries
        metrics_calculator -- cache-hit and energy-usage metrics
    """

    def __init__(self, token_manager=None, cache_manager=None, metrics_calculator=None):
        """Initialize the SummaryModelManager with optional utilities."""
        self.logger = logging.getLogger(__name__)
        self.token_manager = token_manager
        self.cache_manager = cache_manager
        self.metrics_calculator = metrics_calculator
        # Model and tokenizer are loaded lazily by initialize_model().
        self.model = None
        self.tokenizer = None
        self.model_name = "t5-small"
        self.initialized = False
        # Default generation parameters (beam search with length penalty).
        self.default_params = {
            "max_length": 150,
            "min_length": 40,
            "length_penalty": 2.0,
            "num_beams": 4,
            "early_stopping": True,
        }

    def initialize_model(self):
        """Load the summarization model and tokenizer (idempotent).

        Tries t5-small first; if that fails (e.g. SentencePiece is not
        installed), falls back to facebook/bart-base. Re-raises the
        fallback error if both loads fail.
        """
        if self.initialized:
            return
        try:
            # Register with token manager if available.
            if self.token_manager:
                self.token_manager.register_model(
                    self.model_name, "summarization")
            self.logger.info(f"Loading summary model: {self.model_name}")
            self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
            self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)
            self.initialized = True
            self.logger.info("Summary model initialized successfully")
        except Exception as primary_error:
            # Log the primary failure (previously discarded silently), then
            # try a fallback model that doesn't require SentencePiece.
            self.logger.warning(
                f"Failed to initialize summary model {self.model_name}: {primary_error}")
            try:
                fallback_model = "facebook/bart-base"
                self.logger.info(f"Trying fallback model: {fallback_model}")
                from transformers import BartTokenizer, BartForConditionalGeneration
                self.tokenizer = BartTokenizer.from_pretrained(fallback_model)
                self.model = BartForConditionalGeneration.from_pretrained(fallback_model)
                self.model_name = fallback_model
                # Register fallback with token manager.
                if self.token_manager:
                    self.token_manager.register_model(
                        self.model_name, "summarization")
                self.initialized = True
                self.logger.info("Fallback summary model initialized successfully")
            except Exception as fallback_error:
                self.logger.error(f"Failed to initialize fallback model: {fallback_error}")
                raise

    def generate_summary(self, text: str, prefix: str = "summarize: ",
                         agent_name: str = "report_generation",
                         params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Generate a summary of ``text`` and return it with metadata.

        Args:
            text: Source text to summarize.
            prefix: Task prefix prepended to the input (T5-style conditioning).
            agent_name: Name reported to the token/metrics collaborators.
            params: Per-call overrides merged over ``self.default_params``.

        Returns:
            Dict with "summary", "input_length", "summary_length" and
            "compression_ratio" (word-based), or a dict containing "error"
            when the token budget is exceeded.
        """
        # Initialize model if needed.
        if not self.initialized:
            self.initialize_model()
        # Prepare input text.
        input_text = f"{prefix}{text}"

        # Check cache if available. NOTE: builtin hash() is salted per
        # process (PYTHONHASHSEED), so a stable sha256 digest is used to
        # keep cache keys valid across runs.
        cache_key = None
        if self.cache_manager:
            digest = hashlib.sha256(input_text.encode("utf-8")).hexdigest()
            cache_key = input_text[:100] + digest
            cache_hit, cached_result = self.cache_manager.get(
                cache_key, namespace="summaries")
            if cache_hit:
                if self.metrics_calculator:
                    # Estimated energy saving for a cache hit.
                    self.metrics_calculator.update_cache_metrics(1, 0, 0.005)
                return cached_result

        # Request token budget if available.
        if self.token_manager:
            approved, reason = self.token_manager.request_tokens(
                agent_name, "summarization", input_text, self.model_name)
            if not approved:
                self.logger.warning(f"Token budget exceeded: {reason}")
                return {"summary": "Token budget exceeded", "error": reason}

        # Tokenize, truncating to the model's 512-token input window.
        inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)

        # Merge default and custom generation parameters.
        generation_params = self.default_params.copy()
        if params:
            generation_params.update(params)

        # Generate; pass the attention mask explicitly so any padding is
        # ignored (also silences the transformers warning).
        with torch.no_grad():
            output_ids = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                **generation_params
            )
        # Decode summary.
        summary = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

        # Word-level compression ratio; guard against an empty summary.
        input_length = len(text.split())
        summary_length = len(summary.split())
        compression_ratio = input_length / max(summary_length, 1)

        result = {
            "summary": summary,
            "input_length": input_length,
            "summary_length": summary_length,
            "compression_ratio": compression_ratio
        }

        # Log token usage (and derived energy usage) if available.
        if self.token_manager:
            input_tokens = len(inputs.input_ids[0])
            output_tokens = len(output_ids[0])
            total_tokens = input_tokens + output_tokens
            self.token_manager.log_usage(
                agent_name, "summarization", total_tokens, self.model_name)
            if self.metrics_calculator:
                energy_usage = self.token_manager.calculate_energy_usage(
                    total_tokens, self.model_name)
                self.metrics_calculator.log_energy_usage(
                    energy_usage, self.model_name, agent_name, "summarization")

        # Store in cache if available.
        if self.cache_manager:
            self.cache_manager.put(cache_key, result, namespace="summaries")
        return result

    def generate_executive_summary(self, detailed_content: str, confidence_level: float,
                                   agent_name: str = "report_generation") -> Dict[str, Any]:
        """Generate an executive summary whose length tracks confidence.

        Higher confidence yields a longer, more assertive summary; the
        result also carries "confidence_level" and a human-readable
        "confidence_statement".
        """
        # Prompt and length budget depend on the confidence band.
        if confidence_level >= 0.7:
            prefix = "summarize with high confidence: "
            params = {"min_length": 30, "max_length": 100}
        elif confidence_level >= 0.4:
            prefix = "summarize with moderate confidence: "
            params = {"min_length": 20, "max_length": 80}
        else:
            prefix = "summarize with low confidence: "
            params = {"min_length": 15, "max_length": 60}

        result = self.generate_summary(detailed_content, prefix=prefix,
                                       agent_name=agent_name, params=params)
        result["confidence_level"] = confidence_level
        result["confidence_statement"] = self._generate_confidence_statement(confidence_level)
        return result

    def _generate_confidence_statement(self, confidence_level: float) -> str:
        """Map a confidence level in [0, 1] to a canned disclosure sentence."""
        if confidence_level >= 0.8:
            return "This analysis is provided with high confidence based on strong evidence in the provided materials."
        elif confidence_level >= 0.6:
            return "This analysis is provided with good confidence based on substantial evidence in the provided materials."
        elif confidence_level >= 0.4:
            return "This analysis is provided with moderate confidence. Some aspects may require additional verification."
        elif confidence_level >= 0.2:
            return "This analysis is provided with limited confidence due to sparse relevant information in the provided materials."
        else:
            return "This analysis is provided with very low confidence due to insufficient relevant information in the provided materials."

    def combine_analyses(self, text_analyses: List[Dict[str, Any]],
                         image_analyses: List[Dict[str, Any]],
                         topic: str, agent_name: str = "report_generation") -> Dict[str, Any]:
        """Combine text and image analyses into a coherent report.

        Analyses containing an "error" key are skipped. Overall confidence
        is a 70/30 weighted mean of text/image confidences (text weighted
        higher as it typically drives the report). Returns the detailed
        report, executive summary, and confidence metadata.
        """
        # Build combined content.
        combined_content = f"Topic: {topic}\n\n"
        combined_content += "Text Analysis:\n"
        for i, analysis in enumerate(text_analyses):
            if "error" in analysis:
                continue
            combined_content += f"- Document {i+1}: {analysis.get('summary', 'No summary available')}\n"
        combined_content += "\nImage Analysis:\n"
        for i, analysis in enumerate(image_analyses):
            if "error" in analysis:
                continue
            combined_content += f"- Image {i+1}: {analysis.get('caption', 'No caption available')}\n"

        # Mean confidence per modality; max(..., 1) guards empty lists.
        text_confidence = sum(a.get("confidence", 0) for a in text_analyses) / max(len(text_analyses), 1)
        image_confidence = sum(a.get("confidence", 0) for a in image_analyses) / max(len(image_analyses), 1)
        overall_confidence = 0.7 * text_confidence + 0.3 * image_confidence

        # Detailed report, then an executive summary distilled from it.
        detailed_report = self.generate_summary(
            combined_content,
            prefix=f"generate detailed report about {topic}: ",
            agent_name=agent_name,
            params={"max_length": 300, "min_length": 100}
        )
        executive_summary = self.generate_executive_summary(
            detailed_report["summary"],
            overall_confidence,
            agent_name
        )

        return {
            "topic": topic,
            "executive_summary": executive_summary["summary"],
            "confidence_statement": executive_summary["confidence_statement"],
            "detailed_report": detailed_report["summary"],
            "confidence_level": overall_confidence,
            "text_confidence": text_confidence,
            "image_confidence": image_confidence,
            "source_count": {
                "text": len(text_analyses),
                "images": len(image_analyses)
            }
        }