# models/text_models.py
import logging
from typing import Dict, List, Optional, Tuple, Union, Any

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
from sentence_transformers import SentenceTransformer


class TextModelManager:
    """Manages the text models used for analysis: an embedding model for
    topic-document relevance and an understanding model for document analysis."""

    def __init__(self, token_manager=None, cache_manager=None, metrics_calculator=None):
        """Initialize the TextModelManager with optional utilities."""
        self.logger = logging.getLogger(__name__)
        self.token_manager = token_manager
        self.cache_manager = cache_manager
        self.metrics_calculator = metrics_calculator

        # Model instances
        self.embedding_model = None
        self.understanding_model = None
        self.embedding_tokenizer = None
        self.understanding_tokenizer = None

        # Model names
        self.embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
        self.understanding_model_name = "microsoft/deberta-v3-small"

        # Track initialization state
        self.initialized = {
            "embedding": False,
            "understanding": False
        }

    def initialize_embedding_model(self):
        """Initialize the embedding model for topic-document relevance."""
        if self.initialized["embedding"]:
            return

        try:
            # Register with token manager if available
            if self.token_manager:
                self.token_manager.register_model(
                    self.embedding_model_name, "embedding")

            # Load model
            self.logger.info(f"Loading embedding model: {self.embedding_model_name}")
            self.embedding_model = SentenceTransformer(self.embedding_model_name)

            # Also load tokenizer separately for token counting
            self.embedding_tokenizer = AutoTokenizer.from_pretrained(
                self.embedding_model_name)

            self.initialized["embedding"] = True
            self.logger.info("Embedding model initialized successfully")
        except Exception as e:
            self.logger.error(f"Failed to initialize embedding model: {e}")
            raise

    def initialize_understanding_model(self):
        """Initialize the understanding model for document analysis."""
        if self.initialized["understanding"]:
            return

        try:
            model_name = self.understanding_model_name

            # Register with token manager if available
            if self.token_manager:
                self.token_manager.register_model(model_name, "understanding")

            # Load model and tokenizer
            self.logger.info(f"Loading understanding model: {model_name}")
            self.understanding_tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.understanding_model = AutoModel.from_pretrained(model_name)

            self.initialized["understanding"] = True
            self.logger.info("Understanding model initialized successfully")
        except Exception as e:
            self.logger.error(f"Failed to initialize understanding model: {e}")
            # Try fallback model if primary fails
            try:
                fallback_model = "distilbert-base-uncased"
                self.logger.info(f"Trying fallback model: {fallback_model}")
                self.understanding_tokenizer = AutoTokenizer.from_pretrained(fallback_model)
                self.understanding_model = AutoModel.from_pretrained(fallback_model)
                self.understanding_model_name = fallback_model
                self.initialized["understanding"] = True
                self.logger.info("Fallback understanding model initialized successfully")
            except Exception as fallback_error:
                self.logger.error(f"Failed to initialize fallback model: {fallback_error}")
                raise

    def get_embeddings(self, texts: Union[str, List[str]],
                       agent_name: str = "text_analysis") -> np.ndarray:
        """
        Generate embeddings for the given texts.
        Optimized with caching and token tracking.
        """
        # Initialize model if needed
        if not self.initialized["embedding"]:
            self.initialize_embedding_model()

        # Handle single text, remembering the input type for the return value
        single_input = isinstance(texts, str)
        if single_input:
            texts = [texts]

        results = []
        cache_hits = 0
        tokens_used = 0

        for text in texts:
            # Check cache if available
            if self.cache_manager:
                cache_hit, cached_embedding = self.cache_manager.get(
                    text, namespace="embeddings")
                if cache_hit:
                    results.append(cached_embedding)
                    cache_hits += 1
                    continue

            # Request token budget if available
            if self.token_manager:
                approved, reason = self.token_manager.request_tokens(
                    agent_name, "embedding", text, self.embedding_model_name)
                if not approved:
                    self.logger.warning(f"Token budget exceeded: {reason}")
                    # Return zeros as fallback
                    results.append(np.zeros(384))  # Default embedding dimension
                    continue

            # Generate embedding
            with torch.no_grad():
                embedding = self.embedding_model.encode(text)

            # Store in cache if available
            if self.cache_manager:
                self.cache_manager.put(text, embedding, namespace="embeddings",
                                       embedding=embedding)

            # Log token usage if available
            if self.token_manager:
                token_count = len(self.embedding_tokenizer.encode(text))
                self.token_manager.log_usage(
                    agent_name, "embedding", token_count, self.embedding_model_name)
                tokens_used += token_count

                # Log energy usage if metrics calculator is available
                if self.metrics_calculator:
                    energy_usage = self.token_manager.calculate_energy_usage(
                        token_count, self.embedding_model_name)
                    self.metrics_calculator.log_energy_usage(
                        energy_usage, self.embedding_model_name,
                        agent_name, "embedding")

            results.append(embedding)

        # Update cache metrics if available
        if self.cache_manager and self.metrics_calculator:
            # Estimate energy saved through cache hits
            if cache_hits > 0 and tokens_used > 0:
                avg_tokens_per_text = tokens_used / (len(texts) - cache_hits)
                estimated_tokens_saved = avg_tokens_per_text * cache_hits
                if self.token_manager:
                    energy_saved = self.token_manager.calculate_energy_usage(
                        estimated_tokens_saved, self.embedding_model_name)
                    self.metrics_calculator.update_cache_metrics(
                        cache_hits, len(texts) - cache_hits, energy_saved)

        # Return a single embedding or an array of embeddings based on the input type
        if single_input:
            return results[0]
        return np.array(results)

    def compute_similarity(self, topic: str, documents: List[str],
                           agent_name: str = "text_analysis") -> List[float]:
        """
        Compute semantic similarity between topic and documents.
        Returns list of similarity scores (0-1).
        """
        # Get embeddings
        topic_embedding = self.get_embeddings(topic, agent_name)
        doc_embeddings = self.get_embeddings(documents, agent_name)

        # Compute similarities
        similarities = []
        for doc_embedding in doc_embeddings:
            # Cosine similarity; guard against zero-norm fallback embeddings
            norm_product = np.linalg.norm(topic_embedding) * np.linalg.norm(doc_embedding)
            if norm_product == 0:
                similarities.append(0.0)
                continue
            similarity = np.dot(topic_embedding, doc_embedding) / norm_product
            similarities.append(float(similarity))
        return similarities

    def analyze_document(self, document: str, query: Optional[str] = None,
                         agent_name: str = "text_analysis") -> Dict[str, Any]:
        """
        Analyze document content using the understanding model.
        If query is provided, focuses analysis on that query.
        """
        # Initialize model if needed
        if not self.initialized["understanding"]:
            self.initialize_understanding_model()

        # Check cache if available
        cache_key = f"{document}::{query}" if query else document
        if self.cache_manager:
            cache_hit, cached_result = self.cache_manager.get(
                cache_key, namespace="document_analysis")
            if cache_hit:
                # Update metrics if available
                if self.metrics_calculator:
                    self.metrics_calculator.update_cache_metrics(1, 0, 0.005)  # Estimated energy saving
                return cached_result

        # Prepare input
        if query:
            input_text = f"Query: {query}\nDocument: {document}"
        else:
            input_text = document

        # Request token budget if available
        if self.token_manager:
            approved, reason = self.token_manager.request_tokens(
                agent_name, "understanding", input_text, self.understanding_model_name)
            if not approved:
                self.logger.warning(f"Token budget exceeded: {reason}")
                return {"error": reason, "summary": "Token budget exceeded"}

        # Tokenize
        inputs = self.understanding_tokenizer(
            input_text, return_tensors="pt", truncation=True, max_length=512)

        # Get model outputs
        with torch.no_grad():
            outputs = self.understanding_model(**inputs)

        # Process outputs - using last hidden state for analysis
        last_hidden_state = outputs.last_hidden_state

        # Extract key information (simplified for demonstration)
        # In a real implementation, we'd use more sophisticated analysis
        mean_embedding = torch.mean(last_hidden_state, dim=1).squeeze().numpy()

        # Create analysis result
        result = {
            "document_length": len(document.split()),
            "embedding_norm": float(np.linalg.norm(mean_embedding)),
            "content_vector": mean_embedding.tolist()[:10]  # First 10 dims as sample
        }

        # Log token usage if available
        if self.token_manager:
            token_count = len(inputs.input_ids[0])
            self.token_manager.log_usage(
                agent_name, "understanding", token_count, self.understanding_model_name)

            # Log energy usage if metrics calculator is available
            if self.metrics_calculator:
                energy_usage = self.token_manager.calculate_energy_usage(
                    token_count, self.understanding_model_name)
                self.metrics_calculator.log_energy_usage(
                    energy_usage, self.understanding_model_name,
                    agent_name, "understanding")

        # Store in cache if available
        if self.cache_manager:
            self.cache_manager.put(cache_key, result, namespace="document_analysis")

        return result
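

if __name__ == "__main__":
    # Minimal usage sketch: runs the manager without a token_manager, cache_manager,
    # or metrics_calculator (all three are optional constructor arguments). The topic
    # and documents below are illustrative placeholders; executing this block will
    # download the models from the Hugging Face Hub on first use.
    logging.basicConfig(level=logging.INFO)

    manager = TextModelManager()

    topic = "renewable energy policy"
    documents = [
        "Governments are expanding subsidies for solar and wind power.",
        "The recipe calls for two cups of flour and a pinch of salt.",
    ]

    # Topic-document relevance via cosine similarity of sentence embeddings
    scores = manager.compute_similarity(topic, documents)
    for doc, score in zip(documents, scores):
        print(f"{score:.3f}  {doc}")

    # Lightweight document analysis with the understanding model
    analysis = manager.analyze_document(documents[0], query=topic)
    print(analysis["document_length"], analysis["embedding_norm"])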