# utils/token_manager.py
import logging
from typing import Any, Dict, Optional, Tuple

from transformers import AutoTokenizer
class TokenManager:
    """Track, budget, and report LLM token usage per agent and operation.

    Responsibilities:
      * Cache per-model tokenizers for accurate token counting (with a
        chars/4 approximation as a fallback when a tokenizer cannot load).
      * Enforce per-operation-type token budgets on incoming requests.
      * Accumulate actual usage per agent/operation for reporting.
      * Provide rough energy-usage estimates and prompt truncation helpers.
    """

    # Fallback per-operation budgets, used when the config supplies none.
    # Copied per instance in __init__ so adjust_budget() never mutates this.
    DEFAULT_BUDGETS = {
        'text_analysis': 5000,
        'image_captioning': 3000,
        'report_generation': 4000,
        'default': 2000,
    }

    # Approximate watt-hours per 1K tokens by model, based on research
    # estimates for different model sizes. Hoisted to a class constant so
    # calculate_energy_usage() does not rebuild the dict on every call.
    ENERGY_COEFFICIENTS = {
        # Small models
        'sentence-transformers/all-MiniLM-L6-v2': 0.0001,
        'microsoft/deberta-v3-small': 0.0005,
        'google/flan-t5-small': 0.0007,
        # Medium models
        'Salesforce/blip-image-captioning-base': 0.003,
        't5-small': 0.001,
        # Large models
        'Salesforce/BLIP-2': 0.015,
        # Default for unknown models (conservative estimate)
        'default': 0.005,
    }

    def __init__(self, config: Optional[Dict] = None):
        """Initialize the TokenManager with optional configuration.

        Args:
            config: Optional dict; the 'budgets' key, if present and
                non-empty, maps operation_type -> max tokens per request.
                Otherwise DEFAULT_BUDGETS is used.
        """
        self.config = config or {}
        # agent_name -> {operation_type -> cumulative token count}
        self.token_counters: Dict[str, Dict[str, int]] = {}
        self.token_budgets: Dict[str, int] = self.config.get('budgets', {})
        # model_name -> tokenizer instance, or None when loading failed
        self.tokenizer_cache: Dict[str, Any] = {}
        self.logger = logging.getLogger(__name__)
        if not self.token_budgets:
            # Copy so per-instance adjustments don't touch the class default.
            self.token_budgets = dict(self.DEFAULT_BUDGETS)

    def register_model(self, model_name: str, model_type: str) -> None:
        """Register a model and load its tokenizer for accurate counting.

        On load failure the model is cached as None, which makes
        estimate_tokens() fall back to approximate counting.

        Args:
            model_name: Hugging Face model identifier.
            model_type: Currently unused; kept for interface compatibility.
        """
        if model_name in self.tokenizer_cache:
            return  # Already registered; nothing to do.
        try:
            self.tokenizer_cache[model_name] = AutoTokenizer.from_pretrained(model_name)
            self.logger.info("Registered tokenizer for model: %s", model_name)
        except Exception as e:
            # Deliberate best-effort: record the failure and fall back to
            # approximate counting rather than propagating the error.
            self.logger.error("Failed to load tokenizer for %s: %s", model_name, e)
            self.tokenizer_cache[model_name] = None

    def estimate_tokens(self, text: str, model_name: str) -> int:
        """Estimate the token count of `text` for the given model.

        Uses the cached model tokenizer when available; otherwise applies
        the rough heuristic of 4 characters per token.

        Returns:
            Estimated (non-negative) token count; 0 for empty input.
        """
        if not text:
            return 0
        tokenizer = self.tokenizer_cache.get(model_name)
        if tokenizer:
            # Model-specific tokenizer gives an exact count.
            tokens = tokenizer(text, return_tensors="pt")
            return tokens.input_ids.shape[1]
        # Fallback: approximate token count (4 chars ≈ 1 token).
        return len(text) // 4 + 1

    def request_tokens(self, agent_name: str, operation_type: str,
                       text: str, model_name: str) -> Tuple[bool, str]:
        """Check whether an operation fits its token budget.

        NOTE(review): approval does not reserve tokens — budgets are
        per-request limits, not a depleting pool; actual usage is recorded
        separately via log_usage().

        Returns:
            (approved, reason) — False with an explanatory reason when the
            estimated token count exceeds the operation-type budget.
        """
        budget = self.token_budgets.get(operation_type,
                                        self.token_budgets.get('default', 1000))
        estimated_tokens = self.estimate_tokens(text, model_name)
        if estimated_tokens > budget:
            return False, f"Token budget exceeded: {estimated_tokens} > {budget}"
        # Pre-create the agent's counter so later log_usage() calls are cheap.
        self.token_counters.setdefault(agent_name, {})
        return True, "Token budget approved"

    def log_usage(self, agent_name: str, operation_type: str,
                  token_count: int, model_name: str) -> None:
        """Record actual token usage for an agent/operation after the fact.

        Args:
            model_name: Currently unused in accounting; kept for interface
                compatibility and potential per-model breakdowns.
        """
        ops = self.token_counters.setdefault(agent_name, {})
        ops[operation_type] = ops.get(operation_type, 0) + token_count
        self.logger.info("Logged %d tokens for %s.%s",
                         token_count, agent_name, operation_type)

    def get_usage_stats(self) -> Dict[str, Any]:
        """Return a snapshot of current token usage statistics.

        Returns:
            Dict with 'by_agent' (per-agent per-operation counts),
            'total_usage' (grand total), and 'budgets'. Copies are returned
            so callers cannot mutate the manager's internal state.
        """
        total_usage = sum(sum(ops.values()) for ops in self.token_counters.values())
        return {
            'by_agent': {agent: dict(ops) for agent, ops in self.token_counters.items()},
            'total_usage': total_usage,
            'budgets': dict(self.token_budgets),
        }

    def optimize_prompt(self, prompt: str, model_name: str,
                        max_tokens: Optional[int] = None) -> str:
        """Truncate `prompt` to at most `max_tokens` tokens when possible.

        Returns the prompt unchanged when no limit is given, when no
        tokenizer is registered for the model, or when the prompt already
        fits. Truncation is a simple token-slice; a production system would
        use smarter compression.
        """
        if not max_tokens:
            return prompt
        tokenizer = self.tokenizer_cache.get(model_name)
        if not tokenizer:
            # Can't truncate accurately without a tokenizer.
            return prompt
        current_tokens = self.estimate_tokens(prompt, model_name)
        if current_tokens <= max_tokens:
            return prompt
        # Simple truncation strategy (basic implementation).
        tokens = tokenizer(prompt, return_tensors="pt").input_ids[0]
        truncated_tokens = tokens[:max_tokens]
        # NOTE(review): decoded text may re-tokenize to a slightly different
        # count (special tokens survive the slice) — acceptable for an estimate.
        optimized_prompt = tokenizer.decode(truncated_tokens)
        self.logger.info("Optimized prompt from %d to %d tokens",
                         current_tokens, max_tokens)
        return optimized_prompt

    def calculate_energy_usage(self, token_count: int, model_name: str) -> float:
        """Estimate energy consumed by processing `token_count` tokens.

        Returns:
            Approximate energy usage in watt-hours, using per-model
            coefficients from ENERGY_COEFFICIENTS (conservative default
            for unknown models).
        """
        coefficient = self.ENERGY_COEFFICIENTS.get(
            model_name, self.ENERGY_COEFFICIENTS['default'])
        # Coefficients are per 1K tokens, hence the /1000.
        energy_usage = (token_count / 1000) * coefficient
        self.logger.info(
            "Estimated energy usage for %d tokens with %s: %.6f watt-hours",
            token_count, model_name, energy_usage)
        return energy_usage

    def adjust_budget(self, operation_type: str, new_budget: int) -> None:
        """Dynamically adjust the token budget for an operation type.

        Allows runtime tuning based on task priority or resource
        constraints. Non-positive budgets are rejected with a warning.
        """
        if new_budget <= 0:
            self.logger.warning(
                "Invalid budget value: %d. Budget must be positive.", new_budget)
            return
        old_budget = self.token_budgets.get(
            operation_type, self.token_budgets.get('default', 0))
        self.token_budgets[operation_type] = new_budget
        # Log the change (100% when there was no prior budget to compare to).
        change_percent = ((new_budget - old_budget) / old_budget * 100) if old_budget else 100
        self.logger.info("Adjusted budget for %s: %d → %d tokens (%.1f%% change)",
                         operation_type, old_budget, new_budget, change_percent)
        # Significant reductions may starve dependent operations — surface it.
        if new_budget < old_budget * 0.8:  # More than 20% reduction
            self.logger.warning(
                "Significant budget reduction for %s. Dependent operations may be affected.",
                operation_type)