# ai_agents_sustainable/utils/token_manager.py
# Author: Chamin09 — initial commit (7de43ca, verified)
# utils/token_manager.py
import logging
from typing import Dict, Optional, Tuple, Any
from transformers import AutoTokenizer
class TokenManager:
    """Track, budget, and estimate token usage (and derived energy cost)
    for model calls made by agents.

    Budgets are per operation type; actual usage is accumulated per agent
    and per operation type in ``token_counters``.
    """

    def __init__(self, config: Optional[Dict] = None):
        """Initialize the TokenManager with optional configuration.

        Args:
            config: Optional dict. If it contains a 'budgets' key mapping
                operation types to per-request token budgets, those budgets
                are used; otherwise built-in defaults apply.
        """
        self.config = config or {}
        # agent_name -> {operation_type -> cumulative token count}
        self.token_counters: Dict[str, Dict[str, int]] = {}
        self.token_budgets = self.config.get('budgets', {})
        # model_name -> tokenizer instance, or None when loading failed
        self.tokenizer_cache: Dict[str, Any] = {}
        self.logger = logging.getLogger(__name__)

        # Default budgets if not specified in config
        if not self.token_budgets:
            self.token_budgets = {
                'text_analysis': 5000,
                'image_captioning': 3000,
                'report_generation': 4000,
                'default': 2000
            }

    def register_model(self, model_name: str, model_type: str) -> None:
        """Register a model and load its tokenizer for accurate token counting.

        On load failure the model is cached with a ``None`` tokenizer so that
        token counting falls back to the character-based approximation.
        """
        if model_name not in self.tokenizer_cache:
            try:
                self.tokenizer_cache[model_name] = AutoTokenizer.from_pretrained(model_name)
                self.logger.info("Registered tokenizer for model: %s", model_name)
            except Exception as e:
                self.logger.error("Failed to load tokenizer for %s: %s", model_name, e)
                # Fallback to approximate counting
                self.tokenizer_cache[model_name] = None

    def estimate_tokens(self, text: str, model_name: str) -> int:
        """Estimate token count for given text and model.

        Uses the cached model tokenizer when available; otherwise falls back
        to a rough 4-characters-per-token heuristic.
        """
        if not text:
            return 0

        tokenizer = self.tokenizer_cache.get(model_name)
        if tokenizer:
            # encode() yields the same count as inspecting input_ids but
            # avoids requiring torch just to read a tensor shape.
            return len(tokenizer.encode(text))
        # Fallback: approximate token count (4 chars ≈ 1 token)
        return len(text) // 4 + 1

    def request_tokens(self, agent_name: str, operation_type: str,
                       text: str, model_name: str) -> Tuple[bool, str]:
        """Request token budget for an operation. Returns (approved, reason)."""
        # Budget for this operation type; falls back to the 'default' entry,
        # then to a hard floor of 1000 if no default was configured.
        budget = self.token_budgets.get(operation_type,
                                        self.token_budgets.get('default', 1000))

        estimated_tokens = self.estimate_tokens(text, model_name)

        # Reject requests that would exceed the per-operation budget.
        if estimated_tokens > budget:
            reason = f"Token budget exceeded: {estimated_tokens} > {budget}"
            return False, reason

        # Make sure the agent has a counter bucket for later logging.
        if agent_name not in self.token_counters:
            self.token_counters[agent_name] = {}

        return True, "Token budget approved"

    def log_usage(self, agent_name: str, operation_type: str,
                  token_count: int, model_name: str) -> None:
        """Log actual token usage after operation completes."""
        agent_counters = self.token_counters.setdefault(agent_name, {})
        agent_counters[operation_type] = agent_counters.get(operation_type, 0) + token_count
        self.logger.info("Logged %d tokens for %s.%s",
                         token_count, agent_name, operation_type)

    def get_usage_stats(self) -> Dict[str, Any]:
        """Return current token usage statistics.

        Returns:
            Dict with per-agent counters ('by_agent'), the grand total
            ('total_usage'), and the active budgets ('budgets').
        """
        total_usage = sum(
            sum(operations.values()) for operations in self.token_counters.values()
        )
        return {
            'by_agent': self.token_counters,
            'total_usage': total_usage,
            'budgets': self.token_budgets
        }

    def optimize_prompt(self, prompt: str, model_name: str,
                        max_tokens: Optional[int] = None) -> str:
        """Apply token optimization techniques to prompt.

        Currently a simple truncation to ``max_tokens`` tokens; returns the
        prompt unchanged when no limit is given, no tokenizer is cached, or
        the prompt already fits.
        """
        if not max_tokens:
            return prompt

        tokenizer = self.tokenizer_cache.get(model_name)
        if not tokenizer:
            # Can't optimize without an exact tokenizer.
            return prompt

        current_tokens = self.estimate_tokens(prompt, model_name)
        if current_tokens <= max_tokens:
            return prompt

        # Simple truncation strategy (as a basic implementation).
        # In a real system, we'd use more sophisticated techniques.
        token_ids = tokenizer.encode(prompt)[:max_tokens]
        optimized_prompt = tokenizer.decode(token_ids)

        self.logger.info("Optimized prompt from %d to %d tokens",
                         current_tokens, max_tokens)
        return optimized_prompt

    def calculate_energy_usage(self, token_count: int, model_name: str) -> float:
        """
        Calculate approximate energy usage based on token count and model.
        Returns energy usage in watt-hours.
        """
        # Model energy coefficients (approximate watt-hours per 1K tokens)
        # Based on research estimates for different model sizes
        energy_coefficients = {
            # Small models
            'sentence-transformers/all-MiniLM-L6-v2': 0.0001,
            'microsoft/deberta-v3-small': 0.0005,
            'google/flan-t5-small': 0.0007,
            # Medium models
            'Salesforce/blip-image-captioning-base': 0.003,
            't5-small': 0.001,
            # Large models
            'Salesforce/BLIP-2': 0.015,
            # Default for unknown models (conservative estimate)
            'default': 0.005
        }

        coefficient = energy_coefficients.get(model_name, energy_coefficients['default'])
        # Coefficients are per 1K tokens, so scale the count down.
        energy_usage = (token_count / 1000) * coefficient

        self.logger.info(
            "Estimated energy usage for %d tokens with %s: %.6f watt-hours",
            token_count, model_name, energy_usage)
        return energy_usage

    def adjust_budget(self, operation_type: str, new_budget: int) -> None:
        """
        Dynamically adjust token budget for an operation type.
        This allows for runtime optimization based on task priority or resource constraints.
        """
        if new_budget <= 0:
            self.logger.warning(
                "Invalid budget value: %d. Budget must be positive.", new_budget)
            return

        old_budget = self.token_budgets.get(operation_type,
                                            self.token_budgets.get('default', 0))
        self.token_budgets[operation_type] = new_budget

        # Log the change; guard against division by zero when no old budget.
        change_percent = ((new_budget - old_budget) / old_budget * 100) if old_budget else 100
        # BUGFIX: original message rendered old/new budgets with no separator.
        self.logger.info(
            "Adjusted budget for %s: %d -> %d tokens (%.1f%% change)",
            operation_type, old_budget, new_budget, change_percent)

        # If this is a significant reduction, notify dependent systems.
        if new_budget < old_budget * 0.8:  # More than 20% reduction
            self.logger.warning(
                "Significant budget reduction for %s. Dependent operations may be affected.",
                operation_type)