Spaces:
Runtime error
Runtime error
import functools | |
import hashlib | |
import os | |
import tempfile | |
import time | |
from enum import Enum | |
from typing import Dict, List, Optional, Tuple | |
import numpy as np | |
from nodes.log import logger | |
# Combined size budget, in bytes, shared by all caches. Read once at import
# time from the CACHE_MAX_BYTES environment variable; defaults to 1 GiB.
CACHE_MAX_BYTES = int(os.environ.get("CACHE_MAX_BYTES", 1 << 30))

# Every NodeOutputCache appends itself here on construction so that eviction
# can compare entries across all caches rather than just one.
CACHE_REGISTRY: List["NodeOutputCache"] = list()
class CachedNumpyArray:
    """Keep a NumPy array's raw bytes in a temporary file instead of in memory.

    The array's bytes are written to an anonymous ``tempfile.TemporaryFile``;
    the shape and dtype are kept on the instance so ``value()`` can rebuild
    an equivalent array later.
    """

    def __init__(self, arr: np.ndarray):
        """Write *arr* to a fresh temporary file and remember its layout.

        Args:
            arr: The array to spill to disk. ``tobytes()`` serialises it in
                C order, so non-contiguous inputs round-trip correctly.
        """
        self.file = tempfile.TemporaryFile()
        self.file.write(arr.tobytes())
        self.shape = arr.shape
        self.dtype = arr.dtype

    def value(self) -> np.ndarray:
        """Read the bytes back and return them as a new, writable array.

        Returns:
            An array with the original shape and dtype. Each call returns an
            independent array; mutating it does not affect the stored copy.
        """
        self.file.seek(0)
        # np.frombuffer over immutable ``bytes`` yields a read-only array,
        # which would make cache hits behave differently from cache misses
        # for callers that modify the result in place. Decoding from a
        # bytearray gives a writable array.
        raw = bytearray(self.file.read())
        return np.frombuffer(raw, dtype=self.dtype).reshape(self.shape)
class NodeOutputCache: | |
def __init__(self): | |
self._data: Dict[Tuple, List] = {} | |
self._bytes: Dict[Tuple, int] = {} | |
self._access_time: Dict[Tuple, float] = {} | |
CACHE_REGISTRY.append(self) | |
def _args_to_key(args) -> Tuple: | |
key = [] | |
for arg in args: | |
if isinstance(arg, (int, float, bool, str, bytes)): | |
key.append(arg) | |
elif arg is None: | |
key.append(None) | |
elif isinstance(arg, Enum): | |
key.append(arg.value) | |
elif isinstance(arg, np.ndarray): | |
key.append(tuple(arg.shape)) | |
key.append(arg.dtype.str) | |
key.append(hashlib.sha256(arg.tobytes()).digest()) | |
elif hasattr(arg, "cache_key_func"): | |
key.append(arg.__class__.__name__) | |
key.append(arg.cache_key_func()) | |
else: | |
raise RuntimeError(f"Unexpected argument type {arg.__class__.__name__}") | |
return tuple(key) | |
def _estimate_bytes(output) -> int: | |
size = 0 | |
for out in output: | |
if isinstance(out, np.ndarray): | |
size += out.nbytes | |
else: | |
# any other type but numpy arrays is probably negligible, but here's an overestimate to handle | |
# pathological cases where someone has a pipeline with a million math nodes | |
size += 1024 # 1 KiB | |
return size | |
def empty(self): | |
return len(self._data) == 0 | |
def oldest(self) -> Tuple[Tuple, float]: | |
return min(self._access_time.items(), key=lambda x: x[1]) | |
def size(self): | |
return sum(self._bytes.values()) | |
def _enforce_limits(): | |
while True: | |
total_bytes = sum([cache.size() for cache in CACHE_REGISTRY]) | |
logger.debug( | |
f"Cache size: {total_bytes} ({100*total_bytes/CACHE_MAX_BYTES:0.1f}% of limit)" | |
) | |
if total_bytes <= CACHE_MAX_BYTES: | |
return | |
logger.debug("Dropping oldest cache key") | |
oldest_keys = [ | |
(cache, cache.oldest()) for cache in CACHE_REGISTRY if not cache.empty() | |
] | |
cache, (key, _) = min(oldest_keys, key=lambda x: x[1][1]) | |
cache.drop(key) | |
def _write_arrays_to_disk(output: List) -> List: | |
return [ | |
CachedNumpyArray(item) if isinstance(item, np.ndarray) else item | |
for item in output | |
] | |
def _read_arrays_from_disk(output: List) -> List: | |
return [ | |
item.value() if isinstance(item, CachedNumpyArray) else item | |
for item in output | |
] | |
def _output_to_list(output) -> List: | |
if isinstance(output, list): | |
return output | |
elif isinstance(output, tuple): | |
return list(output) | |
else: | |
return [output] | |
def _list_to_output(output: List): | |
if len(output) == 1: | |
return output[0] | |
return output | |
def get(self, args) -> Optional[List]: | |
key = self._args_to_key(args) | |
if key in self._data: | |
logger.debug("Cache hit") | |
self._access_time[key] = time.time() | |
return self._list_to_output(self._read_arrays_from_disk(self._data[key])) | |
logger.debug("Cache miss") | |
return None | |
def put(self, args, output): | |
key = self._args_to_key(args) | |
self._data[key] = self._write_arrays_to_disk(self._output_to_list(output)) | |
self._bytes[key] = self._estimate_bytes(output) | |
self._access_time[key] = time.time() | |
self._enforce_limits() | |
def drop(self, key): | |
del self._data[key] | |
del self._bytes[key] | |
del self._access_time[key] | |
def cached(run):
    """Decorate *run* so its results are memoised in a dedicated NodeOutputCache.

    Each decorated function gets its own cache instance. The wrapper accepts
    positional arguments only; they are used to build the cache key.

    Note: a lookup that yields ``None`` is treated as a miss, so a function
    whose result is ``None`` is re-run on every call.
    """
    cache = NodeOutputCache()

    # Preserve the wrapped function's name, docstring and other metadata.
    @functools.wraps(run)
    def _run(*args):
        out = cache.get(args)
        if out is not None:
            return out
        output = run(*args)
        cache.put(args, output)
        return output

    return _run