# ai_agents_sustainable/utils/cache_manager.py
# Chamin09's picture
# initial commit
# 7de43ca verified
# utils/cache_manager.py
import hashlib
import logging
import time
from typing import Dict, Any, Optional, Tuple, List, Union
from datetime import datetime, timedelta
import numpy as np
class CacheManager:
    """In-memory cache with TTL expiry, LRU eviction and optional semantic lookup.

    Entries are keyed by the MD5 digest of the input text, optionally prefixed
    with ``"<namespace>:"``. Each entry may carry an embedding vector so that a
    later query with a *similar* embedding can reuse the cached value.

    Attributes:
        cache: Maps cache key -> entry dict with the keys ``'value'``,
            ``'expiry'``, ``'last_accessed'`` and ``'access_count'``.
        embedding_cache: Maps namespace -> {cache key -> embedding}.
        stats: Counters for ``'hits'``, ``'misses'``, ``'entries'`` and
            ``'evictions'`` (expired entries are counted as evictions).
    """

    def __init__(self, config: Optional[Dict] = None):
        """Initialize the CacheManager with optional configuration.

        Args:
            config: Optional settings. Recognised keys are ``'max_entries'``
                (default 1000), ``'ttl'`` in seconds (default 3600) and
                ``'semantic_threshold'`` (default 0.85).
        """
        self.config = config or {}
        self.logger = logging.getLogger(__name__)

        # Main cache storage
        self.cache = {}

        # Cache statistics
        self.stats = {
            'hits': 0,
            'misses': 0,
            'entries': 0,
            'evictions': 0,
        }

        # Cache configuration
        self.max_entries = self.config.get('max_entries', 1000)
        self.ttl = self.config.get('ttl', 3600)  # Time to live in seconds
        self.semantic_threshold = self.config.get('semantic_threshold', 0.85)

        # For semantic caching
        self.embedding_cache = {}

    def _generate_key(self, data: Union[str, bytes], namespace: str = '') -> str:
        """Generate a cache key for the given data.

        Args:
            data: Text (encoded as UTF-8) or raw bytes to hash.
            namespace: Optional prefix; when non-empty the key becomes
                ``"<namespace>:<md5 hex digest>"``.

        Returns:
            The cache key string.
        """
        if isinstance(data, str):
            data = data.encode('utf-8')
        key = hashlib.md5(data).hexdigest()
        if namespace:
            key = f"{namespace}:{key}"
        return key

    @staticmethod
    def _namespace_of(key: str) -> str:
        """Return the namespace part of a key produced by ``_generate_key``."""
        # A hex digest never contains ':', so the LAST ':' is the separator
        # even when the namespace itself contains colons. Keys without a
        # namespace have no ':' and yield ''.
        return key.rpartition(':')[0]

    def _remove_entry(self, key: str) -> None:
        """Drop ``key`` and its embedding, keeping the entry counter in sync."""
        self.cache.pop(key, None)
        namespace_embeddings = self.embedding_cache.get(self._namespace_of(key))
        if namespace_embeddings is not None:
            namespace_embeddings.pop(key, None)
        self.stats['entries'] = len(self.cache)

    def _live_entry(self, key: str) -> Optional[Dict[str, Any]]:
        """Return the unexpired entry for ``key`` or ``None``.

        An entry whose expiry time has passed is removed (together with its
        embedding) and counted as an eviction. Hit/miss counters are left to
        the caller so one logical lookup is counted exactly once.
        """
        entry = self.cache.get(key)
        if entry is None:
            return None
        if datetime.now() > entry['expiry']:
            self._remove_entry(key)
            self.stats['evictions'] += 1
            return None
        return entry

    def _record_hit(self, entry: Dict[str, Any]) -> Any:
        """Mark ``entry`` as just used, count a hit and return its value."""
        entry['last_accessed'] = datetime.now()
        entry['access_count'] = entry.get('access_count', 0) + 1
        self.stats['hits'] += 1
        return entry['value']

    def get(self, data: str, namespace: str = '') -> Tuple[bool, Any]:
        """Try to retrieve data from cache by exact match.

        Args:
            data: The text whose cached value is wanted.
            namespace: Namespace the value was stored under.

        Returns:
            ``(hit, value)`` where ``hit`` is a boolean indicating cache
            hit/miss; ``value`` is ``None`` on a miss.
        """
        entry = self._live_entry(self._generate_key(data, namespace))
        if entry is None:
            self.stats['misses'] += 1
            return False, None
        return True, self._record_hit(entry)

    def _closest_entry(self, embedding: np.ndarray,
                       namespace: str) -> Optional[Dict[str, Any]]:
        """Return the live entry in ``namespace`` most similar to ``embedding``.

        Uses cosine similarity and returns ``None`` when no live candidate
        reaches ``self.semantic_threshold``.
        """
        candidates = self.embedding_cache.get(namespace)
        if not candidates:
            return None

        # Hoisted out of the loop; a zero (or NaN) norm has no direction, so
        # cosine similarity is undefined and nothing can match.
        query_norm = np.linalg.norm(embedding)
        if not query_norm > 0:
            return None

        best_similarity = 0
        best_entry = None
        # Iterate over a snapshot: stale embeddings are purged while scanning.
        for key, stored_embedding in list(candidates.items()):
            entry = self._live_entry(key)
            if entry is None:
                # Expired or orphaned embedding: skip it instead of letting it
                # shadow a valid candidate.
                candidates.pop(key, None)
                continue
            stored_norm = np.linalg.norm(stored_embedding)
            if not stored_norm > 0:
                continue
            similarity = np.dot(embedding, stored_embedding) / (
                query_norm * stored_norm)
            if similarity > best_similarity:
                best_similarity = similarity
                best_entry = entry

        if best_entry is not None and best_similarity >= self.semantic_threshold:
            return best_entry
        return None

    def get_semantic(self, data: str, embedding: np.ndarray,
                     namespace: str = '') -> Tuple[bool, Any]:
        """Try to retrieve data from cache using semantic similarity.

        An exact match on ``data`` is tried first; otherwise the stored
        embeddings of ``namespace`` are compared with ``embedding``.
        The whole lookup is counted as exactly one hit or one miss.

        Args:
            data: The query text.
            embedding: Pre-computed embedding for the query.
            namespace: Namespace to search in.

        Returns:
            ``(hit, value)``; ``value`` is ``None`` on a miss.
        """
        entry = self._live_entry(self._generate_key(data, namespace))
        if entry is None:
            entry = self._closest_entry(embedding, namespace)
        if entry is None:
            self.stats['misses'] += 1
            return False, None
        return True, self._record_hit(entry)

    def put(self, data: str, value: Any, namespace: str = '',
            ttl: Optional[int] = None, embedding: Optional[np.ndarray] = None) -> None:
        """Store data in cache with optional embedding for semantic search.

        Args:
            data: The text used to derive the cache key.
            value: The value to cache.
            namespace: Optional namespace for the key.
            ttl: Time to live in seconds; falls back to the configured
                default when ``None``.
            embedding: Optional embedding stored for semantic lookups.
        """
        key = self._generate_key(data, namespace)

        # Only a brand-new key can grow the cache, so only then make room.
        if key not in self.cache and len(self.cache) >= self.max_entries:
            self._evict_oldest()

        now = datetime.now()
        lifetime = self.ttl if ttl is None else ttl
        self.cache[key] = {
            'value': value,
            'expiry': now + timedelta(seconds=lifetime),
            'last_accessed': now,
            'access_count': 1,
        }

        if embedding is not None:
            self.embedding_cache.setdefault(namespace, {})[key] = embedding

        # Derived from the real size so the counter cannot drift.
        self.stats['entries'] = len(self.cache)

    def _evict_oldest(self) -> None:
        """Evict the least recently used cache entry."""
        if not self.cache:
            return
        # min() always selects an entry, even when timestamps tie or are not
        # older than the current time, so the size limit is always enforced.
        oldest_key = min(self.cache, key=lambda k: self.cache[k]['last_accessed'])
        self._remove_entry(oldest_key)
        self.stats['evictions'] += 1

    def clear(self, namespace: Optional[str] = None) -> None:
        """Clear the cache, optionally only for a specific namespace.

        Args:
            namespace: When given (and non-empty) only entries stored under
                exactly this namespace are removed; otherwise everything is.
        """
        if namespace:
            # Compare the full namespace rather than a prefix so clearing "a"
            # does not also wipe entries stored under "a:b".
            doomed = [key for key in self.cache
                      if self._namespace_of(key) == namespace]
            for key in doomed:
                del self.cache[key]
            self.embedding_cache.pop(namespace, None)
        else:
            self.cache = {}
            self.embedding_cache = {}
        self.stats['entries'] = len(self.cache)

        suffix = ' for namespace: ' + namespace if namespace else ''
        self.logger.info("Cleared cache%s", suffix)

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Returns:
            The raw counters plus ``'hit_rate'`` (hits / lookups, 0.0 when
            there were no lookups), ``'current_size'`` and ``'max_size'``.
        """
        lookups = self.stats['hits'] + self.stats['misses']
        hit_rate = self.stats['hits'] / lookups if lookups > 0 else 0.0
        return {
            **self.stats,
            'hit_rate': hit_rate,
            'current_size': len(self.cache),
            'max_size': self.max_entries,
        }