# models/text_models.py
import logging
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer


class TextModelManager:
    def __init__(self, token_manager=None, cache_manager=None, metrics_calculator=None):
        """Initialize the TextModelManager with optional utilities."""
        self.logger = logging.getLogger(__name__)
        self.token_manager = token_manager
        self.cache_manager = cache_manager
        self.metrics_calculator = metrics_calculator

        # Model instances
        self.embedding_model = None
        self.understanding_model = None
        self.embedding_tokenizer = None
        self.understanding_tokenizer = None

        # Model names
        self.embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
        self.understanding_model_name = "microsoft/deberta-v3-small"

        # Track initialization state
        self.initialized = {
            "embedding": False,
            "understanding": False,
        }

    def initialize_embedding_model(self):
        """Initialize the embedding model for topic-document relevance."""
        if self.initialized["embedding"]:
            return

        try:
            # Register with token manager if available
            if self.token_manager:
                self.token_manager.register_model(
                    self.embedding_model_name, "embedding")

            # Load model
            self.logger.info(f"Loading embedding model: {self.embedding_model_name}")
            self.embedding_model = SentenceTransformer(self.embedding_model_name)

            # Also load the tokenizer separately for token counting
            self.embedding_tokenizer = AutoTokenizer.from_pretrained(
                self.embedding_model_name)

            self.initialized["embedding"] = True
            self.logger.info("Embedding model initialized successfully")
        except Exception as e:
            self.logger.error(f"Failed to initialize embedding model: {e}")
            raise

    def initialize_understanding_model(self):
        """Initialize the understanding model for document analysis."""
        if self.initialized["understanding"]:
            return

        try:
            # Register with token manager if available
            if self.token_manager:
                self.token_manager.register_model(
                    self.understanding_model_name, "understanding")

            # Load model and tokenizer
            self.logger.info(f"Loading understanding model: {self.understanding_model_name}")
            self.understanding_tokenizer = AutoTokenizer.from_pretrained(
                self.understanding_model_name)
            self.understanding_model = AutoModel.from_pretrained(
                self.understanding_model_name)

            self.initialized["understanding"] = True
            self.logger.info("Understanding model initialized successfully")
        except Exception as e:
            self.logger.error(f"Failed to initialize understanding model: {e}")
            # Try a fallback model if the primary fails
            try:
                fallback_model = "distilbert-base-uncased"
                self.logger.info(f"Trying fallback model: {fallback_model}")
                self.understanding_tokenizer = AutoTokenizer.from_pretrained(fallback_model)
                self.understanding_model = AutoModel.from_pretrained(fallback_model)
                self.understanding_model_name = fallback_model
                self.initialized["understanding"] = True
                self.logger.info("Fallback understanding model initialized successfully")
            except Exception as fallback_error:
                self.logger.error(f"Failed to initialize fallback model: {fallback_error}")
                raise
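    # Note: the zero-vector fallback in get_embeddings() below assumes the
    # 384-dimensional output of all-MiniLM-L6-v2. If embedding_model_name is
    # changed, keep that dimension in sync (SentenceTransformer exposes it
    # via get_sentence_embedding_dimension()).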
""" # Initialize model if needed if not self.initialized["embedding"]: self.initialize_embedding_model() # Handle single text if isinstance(texts, str): texts = [texts] results = [] cache_hits = 0 tokens_used = 0 for text in texts: # Check cache if available if self.cache_manager: cache_hit, cached_embedding = self.cache_manager.get( text, namespace="embeddings") if cache_hit: results.append(cached_embedding) cache_hits += 1 continue # Request token budget if available if self.token_manager: approved, reason = self.token_manager.request_tokens( agent_name, "embedding", text, self.embedding_model_name) if not approved: self.logger.warning(f"Token budget exceeded: {reason}") # Return zeros as fallback results.append(np.zeros(384)) # Default embedding dimension continue # Generate embedding with torch.no_grad(): embedding = self.embedding_model.encode(text) # Store in cache if available if self.cache_manager: self.cache_manager.put(text, embedding, namespace="embeddings", embedding=embedding) # Log token usage if available if self.token_manager: token_count = len(self.embedding_tokenizer.encode(text)) self.token_manager.log_usage( agent_name, "embedding", token_count, self.embedding_model_name) tokens_used += token_count # Log energy usage if metrics calculator is available if self.metrics_calculator: energy_usage = self.token_manager.calculate_energy_usage( token_count, self.embedding_model_name) self.metrics_calculator.log_energy_usage( energy_usage, self.embedding_model_name, agent_name, "embedding") results.append(embedding) # Update cache metrics if available if self.cache_manager and self.metrics_calculator: # Estimate energy saved through cache hits if cache_hits > 0 and tokens_used > 0: avg_tokens_per_text = tokens_used / (len(texts) - cache_hits) estimated_tokens_saved = avg_tokens_per_text * cache_hits if self.token_manager: energy_saved = self.token_manager.calculate_energy_usage( estimated_tokens_saved, self.embedding_model_name) self.metrics_calculator.update_cache_metrics( cache_hits, len(texts) - cache_hits, energy_saved) # Return single embedding or list based on input if len(results) == 1 and isinstance(texts, str): return results[0] return np.array(results) def compute_similarity(self, topic: str, documents: List[str], agent_name: str = "text_analysis") -> List[float]: """ Compute semantic similarity between topic and documents. Returns list of similarity scores (0-1). """ # Get embeddings topic_embedding = self.get_embeddings(topic, agent_name) doc_embeddings = self.get_embeddings(documents, agent_name) # Compute similarities similarities = [] for doc_embedding in doc_embeddings: # Cosine similarity similarity = np.dot(topic_embedding, doc_embedding) / ( np.linalg.norm(topic_embedding) * np.linalg.norm(doc_embedding)) similarities.append(float(similarity)) return similarities def analyze_document(self, document: str, query: str = None, agent_name: str = "text_analysis") -> Dict[str, Any]: """ Analyze document content using the understanding model. If query is provided, focuses analysis on that query. 
""" # Initialize model if needed if not self.initialized["understanding"]: self.initialize_understanding_model() # Check cache if available cache_key = f"{document}::{query}" if query else document if self.cache_manager: cache_hit, cached_result = self.cache_manager.get( cache_key, namespace="document_analysis") if cache_hit: # Update metrics if available if self.metrics_calculator: self.metrics_calculator.update_cache_metrics(1, 0, 0.005) # Estimated energy saving return cached_result # Prepare input if query: input_text = f"Query: {query}\nDocument: {document}" else: input_text = document # Request token budget if available if self.token_manager: approved, reason = self.token_manager.request_tokens( agent_name, "understanding", input_text, self.understanding_model_name) if not approved: self.logger.warning(f"Token budget exceeded: {reason}") return {"error": reason, "summary": "Token budget exceeded"} # Tokenize inputs = self.understanding_tokenizer( input_text, return_tensors="pt", truncation=True, max_length=512) # Get model outputs with torch.no_grad(): outputs = self.understanding_model(**inputs) # Process outputs - using last hidden state for analysis last_hidden_state = outputs.last_hidden_state # Extract key information (simplified for demonstration) # In a real implementation, we'd use more sophisticated analysis mean_embedding = torch.mean(last_hidden_state, dim=1).squeeze().numpy() # Create analysis result result = { "document_length": len(document.split()), "embedding_norm": float(np.linalg.norm(mean_embedding)), "content_vector": mean_embedding.tolist()[:10] # First 10 dims as sample } # Log token usage if available if self.token_manager: token_count = len(inputs.input_ids[0]) self.token_manager.log_usage( agent_name, "understanding", token_count, self.understanding_model_name) # Log energy usage if metrics calculator is available if self.metrics_calculator: energy_usage = self.token_manager.calculate_energy_usage( token_count, self.understanding_model_name) self.metrics_calculator.log_energy_usage( energy_usage, self.understanding_model_name, agent_name, "understanding") # Store in cache if available if self.cache_manager: self.cache_manager.put(cache_key, result, namespace="document_analysis") return result