# models/summary_models.py
import logging
from typing import Dict, List, Optional, Tuple, Union, Any

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration


class SummaryModelManager:
    def __init__(self, token_manager=None, cache_manager=None, metrics_calculator=None):
        """Initialize the SummaryModelManager with optional utilities."""
        self.logger = logging.getLogger(__name__)
        self.token_manager = token_manager
        self.cache_manager = cache_manager
        self.metrics_calculator = metrics_calculator

        # Model instance
        self.model = None
        self.tokenizer = None

        # Model name
        self.model_name = "t5-small"

        # Track initialization state
        self.initialized = False

        # Default generation parameters
        self.default_params = {
            "max_length": 150,
            "min_length": 40,
            "length_penalty": 2.0,
            "num_beams": 4,
            "early_stopping": True
        }

    def initialize_model(self):
        """Initialize the summarization model."""
        if self.initialized:
            return

        try:
            # Register with token manager if available
            if self.token_manager:
                self.token_manager.register_model(
                    self.model_name, "summarization")

            # Load model and tokenizer
            self.logger.info(f"Loading summary model: {self.model_name}")
            self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
            self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)

            self.initialized = True
            self.logger.info("Summary model initialized successfully")
        except Exception as e:
            # Primary model failed (e.g. SentencePiece not installed for the T5 tokenizer);
            # log it and try a fallback model that doesn't require SentencePiece.
            self.logger.warning(f"Failed to initialize summary model: {e}")
            try:
                fallback_model = "facebook/bart-base"
                self.logger.info(f"Trying fallback model: {fallback_model}")

                from transformers import BartTokenizer, BartForConditionalGeneration
                self.tokenizer = BartTokenizer.from_pretrained(fallback_model)
                self.model = BartForConditionalGeneration.from_pretrained(fallback_model)
                self.model_name = fallback_model

                # Register fallback with token manager
                if self.token_manager:
                    self.token_manager.register_model(
                        self.model_name, "summarization")

                self.initialized = True
                self.logger.info("Fallback summary model initialized successfully")
            except Exception as fallback_error:
                self.logger.error(f"Failed to initialize fallback model: {fallback_error}")
                raise

    def generate_summary(self, text: str, prefix: str = "summarize: ",
                         agent_name: str = "report_generation",
                         params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Generate a summary of the given text.

        Returns the summary and metadata.
        """
        # Initialize model if needed
        if not self.initialized:
            self.initialize_model()

        # Prepare input text
        input_text = f"{prefix}{text}"

        # Check cache if available
        if self.cache_manager:
            cache_key = input_text[:100] + str(hash(input_text))  # Use prefix of text + hash as key
            cache_hit, cached_result = self.cache_manager.get(
                cache_key, namespace="summaries")
            if cache_hit:
                # Update metrics if available
                if self.metrics_calculator:
                    self.metrics_calculator.update_cache_metrics(1, 0, 0.005)  # Estimated energy saving
                return cached_result

        # Request token budget if available
        if self.token_manager:
            approved, reason = self.token_manager.request_tokens(
                agent_name, "summarization", input_text, self.model_name)
            if not approved:
                self.logger.warning(f"Token budget exceeded: {reason}")
                return {"summary": "Token budget exceeded", "error": reason}

        # Tokenize
        inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)

        # Merge default and custom parameters
        generation_params = self.default_params.copy()
        if params:
            generation_params.update(params)

        # Generate summary
        with torch.no_grad():
            output_ids = self.model.generate(
                inputs.input_ids,
                **generation_params
            )

        # Decode summary
        summary = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

        # Calculate compression ratio
        input_length = len(text.split())
        summary_length = len(summary.split())
        compression_ratio = input_length / max(summary_length, 1)

        # Prepare result
        result = {
            "summary": summary,
            "input_length": input_length,
            "summary_length": summary_length,
            "compression_ratio": compression_ratio
        }

        # Log token usage if available
        if self.token_manager:
            input_tokens = len(inputs.input_ids[0])
            output_tokens = len(output_ids[0])
            total_tokens = input_tokens + output_tokens
            self.token_manager.log_usage(
                agent_name, "summarization", total_tokens, self.model_name)

            # Log energy usage if metrics calculator is available
            if self.metrics_calculator:
                energy_usage = self.token_manager.calculate_energy_usage(
                    total_tokens, self.model_name)
                self.metrics_calculator.log_energy_usage(
                    energy_usage, self.model_name, agent_name, "summarization")

        # Store in cache if available
        if self.cache_manager:
            self.cache_manager.put(cache_key, result, namespace="summaries")

        return result

    def generate_executive_summary(self, detailed_content: str, confidence_level: float,
                                   agent_name: str = "report_generation") -> Dict[str, Any]:
        """
        Generate an executive summary with confidence indication.

        Adjusts detail level based on confidence.
        """
        # Prepare prompt based on confidence
        if confidence_level >= 0.7:
            prefix = "summarize with high confidence: "
            params = {"min_length": 30, "max_length": 100}
        elif confidence_level >= 0.4:
            prefix = "summarize with moderate confidence: "
            params = {"min_length": 20, "max_length": 80}
        else:
            prefix = "summarize with low confidence: "
            params = {"min_length": 15, "max_length": 60}

        # Generate summary
        result = self.generate_summary(detailed_content, prefix=prefix,
                                       agent_name=agent_name, params=params)

        # Add confidence level to result
        result["confidence_level"] = confidence_level

        # Add confidence statement
        confidence_statement = self._generate_confidence_statement(confidence_level)
        result["confidence_statement"] = confidence_statement

        return result

    def _generate_confidence_statement(self, confidence_level: float) -> str:
        """Generate an appropriate confidence statement based on the level."""
        if confidence_level >= 0.8:
            return "This analysis is provided with high confidence based on strong evidence in the provided materials."
        elif confidence_level >= 0.6:
            return "This analysis is provided with good confidence based on substantial evidence in the provided materials."
        elif confidence_level >= 0.4:
            return "This analysis is provided with moderate confidence. Some aspects may require additional verification."
        elif confidence_level >= 0.2:
            return "This analysis is provided with limited confidence due to sparse relevant information in the provided materials."
        else:
            return "This analysis is provided with very low confidence due to insufficient relevant information in the provided materials."

    def combine_analyses(self, text_analyses: List[Dict[str, Any]],
                         image_analyses: List[Dict[str, Any]],
                         topic: str, agent_name: str = "report_generation") -> Dict[str, Any]:
        """
        Combine text and image analyses into a coherent report.

        Returns the combined report with metadata.
        """
        # Build combined content
        combined_content = f"Topic: {topic}\n\n"

        # Add text analyses
        combined_content += "Text Analysis:\n"
        for i, analysis in enumerate(text_analyses):
            if "error" in analysis:
                continue
            combined_content += f"- Document {i+1}: {analysis.get('summary', 'No summary available')}\n"

        # Add image analyses
        combined_content += "\nImage Analysis:\n"
        for i, analysis in enumerate(image_analyses):
            if "error" in analysis:
                continue
            combined_content += f"- Image {i+1}: {analysis.get('caption', 'No caption available')}\n"

        # Calculate overall confidence based on analyses
        text_confidence = sum(a.get("confidence", 0) for a in text_analyses) / max(len(text_analyses), 1)
        image_confidence = sum(a.get("confidence", 0) for a in image_analyses) / max(len(image_analyses), 1)

        # Weight confidence (text analyses typically more important for deep dives)
        overall_confidence = 0.7 * text_confidence + 0.3 * image_confidence

        # Generate detailed report
        detailed_report = self.generate_summary(
            combined_content,
            prefix=f"generate detailed report about {topic}: ",
            agent_name=agent_name,
            params={"max_length": 300, "min_length": 100}
        )

        # Generate executive summary
        executive_summary = self.generate_executive_summary(
            detailed_report["summary"],
            overall_confidence,
            agent_name
        )

        # Combine results
        result = {
            "topic": topic,
            "executive_summary": executive_summary["summary"],
            "confidence_statement": executive_summary["confidence_statement"],
            "detailed_report": detailed_report["summary"],
            "confidence_level": overall_confidence,
            "text_confidence": text_confidence,
            "image_confidence": image_confidence,
            "source_count": {
                "text": len(text_analyses),
                "images": len(image_analyses)
            }
        }

        return result
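

# A minimal usage sketch, not part of the original module: it instantiates the manager
# without the optional token_manager, cache_manager, and metrics_calculator collaborators
# (so budgeting, caching, and metrics are skipped), and the input text and analysis dicts
# below are purely illustrative.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager = SummaryModelManager()

    # Single-document summary with the default "summarize: " prefix.
    summary = manager.generate_summary(
        "The quarterly report shows revenue growth across all regions, driven primarily "
        "by strong performance in the cloud services division and reduced operating costs."
    )
    print(summary["summary"])
    print("compression ratio:", summary["compression_ratio"])

    # Combined report from hypothetical upstream text and image analysis results.
    report = manager.combine_analyses(
        text_analyses=[{"summary": "Revenue grew year over year in all regions.", "confidence": 0.8}],
        image_analyses=[{"caption": "Bar chart of quarterly revenue by region.", "confidence": 0.6}],
        topic="Quarterly performance",
    )
    print(report["executive_summary"])
    print(report["confidence_statement"])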