ai_agents_sustainable / models /text_models.py
Chamin09's picture
Update models/text_models.py
543e824 verified
# models/text_models.py
import logging
from typing import Dict, List, Optional, Tuple, Union, Any
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
from sentence_transformers import SentenceTransformer
class TextModelManager:
def __init__(self, token_manager=None, cache_manager=None, metrics_calculator=None):
"""Initialize the TextModelManager with optional utilities."""
self.logger = logging.getLogger(__name__)
self.token_manager = token_manager
self.cache_manager = cache_manager
self.metrics_calculator = metrics_calculator
# Model instances
self.embedding_model = None
self.understanding_model = None
self.embedding_tokenizer = None
self.understanding_tokenizer = None
# Model names
self.embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
self.understanding_model_name = "microsoft/deberta-v3-small"
# Track initialization state
self.initialized = {
"embedding": False,
"understanding": False
}
def initialize_embedding_model(self):
"""Initialize the embedding model for topic-document relevance."""
if self.initialized["embedding"]:
return
try:
# Register with token manager if available
if self.token_manager:
self.token_manager.register_model(
self.embedding_model_name, "embedding")
# Load model
self.logger.info(f"Loading embedding model: {self.embedding_model_name}")
self.embedding_model = SentenceTransformer(self.embedding_model_name)
# Also load tokenizer separately for token counting
self.embedding_tokenizer = AutoTokenizer.from_pretrained(
self.embedding_model_name)
self.initialized["embedding"] = True
self.logger.info("Embedding model initialized successfully")
except Exception as e:
self.logger.error(f"Failed to initialize embedding model: {e}")
raise
def initialize_understanding_model(self):
"""Initialize the understanding model for document analysis."""
if self.initialized["understanding"]:
return
try:
model_name = self.understanding_model_name
# Register with token manager if available
if self.token_manager:
self.token_manager.register_model(
self.understanding_model_name, "understanding")
# Load model and tokenizer
self.logger.info(f"Loading understanding model: {self.understanding_model_name}")
self.understanding_tokenizer = AutoTokenizer.from_pretrained(
self.understanding_model_name)
self.understanding_model = AutoModel.from_pretrained(
self.understanding_model_name)
self.initialized["understanding"] = True
self.logger.info("Understanding model initialized successfully")
except Exception as e:
self.logger.error(f"Failed to initialize understanding model: {e}")
# Try fallback model if primary fails
try:
fallback_model = "distilbert-base-uncased"
self.logger.info(f"Trying fallback model: {fallback_model}")
self.understanding_tokenizer = AutoTokenizer.from_pretrained(fallback_model)
self.understanding_model = AutoModel.from_pretrained(fallback_model)
self.understanding_model_name = fallback_model
self.initialized["understanding"] = True
self.logger.info("Fallback understanding model initialized successfully")
except Exception as fallback_error:
self.logger.error(f"Failed to initialize fallback model: {fallback_error}")
raise
def get_embeddings(self, texts: Union[str, List[str]],
agent_name: str = "text_analysis") -> np.ndarray:
"""
Generate embeddings for the given texts.
Optimized with caching and token tracking.
"""
# Initialize model if needed
if not self.initialized["embedding"]:
self.initialize_embedding_model()
# Handle single text
if isinstance(texts, str):
texts = [texts]
results = []
cache_hits = 0
tokens_used = 0
for text in texts:
# Check cache if available
if self.cache_manager:
cache_hit, cached_embedding = self.cache_manager.get(
text, namespace="embeddings")
if cache_hit:
results.append(cached_embedding)
cache_hits += 1
continue
# Request token budget if available
if self.token_manager:
approved, reason = self.token_manager.request_tokens(
agent_name, "embedding", text, self.embedding_model_name)
if not approved:
self.logger.warning(f"Token budget exceeded: {reason}")
# Return zeros as fallback
results.append(np.zeros(384)) # Default embedding dimension
continue
# Generate embedding
with torch.no_grad():
embedding = self.embedding_model.encode(text)
# Store in cache if available
if self.cache_manager:
self.cache_manager.put(text, embedding, namespace="embeddings",
embedding=embedding)
# Log token usage if available
if self.token_manager:
token_count = len(self.embedding_tokenizer.encode(text))
self.token_manager.log_usage(
agent_name, "embedding", token_count, self.embedding_model_name)
tokens_used += token_count
# Log energy usage if metrics calculator is available
if self.metrics_calculator:
energy_usage = self.token_manager.calculate_energy_usage(
token_count, self.embedding_model_name)
self.metrics_calculator.log_energy_usage(
energy_usage, self.embedding_model_name,
agent_name, "embedding")
results.append(embedding)
# Update cache metrics if available
if self.cache_manager and self.metrics_calculator:
# Estimate energy saved through cache hits
if cache_hits > 0 and tokens_used > 0:
avg_tokens_per_text = tokens_used / (len(texts) - cache_hits)
estimated_tokens_saved = avg_tokens_per_text * cache_hits
if self.token_manager:
energy_saved = self.token_manager.calculate_energy_usage(
estimated_tokens_saved, self.embedding_model_name)
self.metrics_calculator.update_cache_metrics(
cache_hits, len(texts) - cache_hits, energy_saved)
# Return single embedding or list based on input
if len(results) == 1 and isinstance(texts, str):
return results[0]
return np.array(results)
def compute_similarity(self, topic: str, documents: List[str],
agent_name: str = "text_analysis") -> List[float]:
"""
Compute semantic similarity between topic and documents.
Returns list of similarity scores (0-1).
"""
# Get embeddings
topic_embedding = self.get_embeddings(topic, agent_name)
doc_embeddings = self.get_embeddings(documents, agent_name)
# Compute similarities
similarities = []
for doc_embedding in doc_embeddings:
# Cosine similarity
similarity = np.dot(topic_embedding, doc_embedding) / (
np.linalg.norm(topic_embedding) * np.linalg.norm(doc_embedding))
similarities.append(float(similarity))
return similarities
def analyze_document(self, document: str, query: str = None,
agent_name: str = "text_analysis") -> Dict[str, Any]:
"""
Analyze document content using the understanding model.
If query is provided, focuses analysis on that query.
"""
# Initialize model if needed
if not self.initialized["understanding"]:
self.initialize_understanding_model()
# Check cache if available
cache_key = f"{document}::{query}" if query else document
if self.cache_manager:
cache_hit, cached_result = self.cache_manager.get(
cache_key, namespace="document_analysis")
if cache_hit:
# Update metrics if available
if self.metrics_calculator:
self.metrics_calculator.update_cache_metrics(1, 0, 0.005) # Estimated energy saving
return cached_result
# Prepare input
if query:
input_text = f"Query: {query}\nDocument: {document}"
else:
input_text = document
# Request token budget if available
if self.token_manager:
approved, reason = self.token_manager.request_tokens(
agent_name, "understanding", input_text, self.understanding_model_name)
if not approved:
self.logger.warning(f"Token budget exceeded: {reason}")
return {"error": reason, "summary": "Token budget exceeded"}
# Tokenize
inputs = self.understanding_tokenizer(
input_text, return_tensors="pt", truncation=True, max_length=512)
# Get model outputs
with torch.no_grad():
outputs = self.understanding_model(**inputs)
# Process outputs - using last hidden state for analysis
last_hidden_state = outputs.last_hidden_state
# Extract key information (simplified for demonstration)
# In a real implementation, we'd use more sophisticated analysis
mean_embedding = torch.mean(last_hidden_state, dim=1).squeeze().numpy()
# Create analysis result
result = {
"document_length": len(document.split()),
"embedding_norm": float(np.linalg.norm(mean_embedding)),
"content_vector": mean_embedding.tolist()[:10] # First 10 dims as sample
}
# Log token usage if available
if self.token_manager:
token_count = len(inputs.input_ids[0])
self.token_manager.log_usage(
agent_name, "understanding", token_count, self.understanding_model_name)
# Log energy usage if metrics calculator is available
if self.metrics_calculator:
energy_usage = self.token_manager.calculate_energy_usage(
token_count, self.understanding_model_name)
self.metrics_calculator.log_energy_usage(
energy_usage, self.understanding_model_name,
agent_name, "understanding")
# Store in cache if available
if self.cache_manager:
self.cache_manager.put(cache_key, result, namespace="document_analysis")
return result