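"""Disk-backed output cache for node pipelines.

Cached outputs that contain numpy arrays are spilled to temporary files so
they do not hold RAM. All caches share a global byte budget
(CACHE_MAX_BYTES); once it is exceeded, the least recently accessed entry
across every registered cache is evicted.
"""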
import functools
import hashlib
import os
import tempfile
import time
from enum import Enum
from typing import Dict, List, Optional, Tuple

import numpy as np

from nodes.log import logger

CACHE_MAX_BYTES = int(os.environ.get("CACHE_MAX_BYTES", 1024**3))  # default 1 GiB
CACHE_REGISTRY: List["NodeOutputCache"] = []
class CachedNumpyArray:
    """A numpy array spilled to a temporary file to keep it out of RAM."""

    def __init__(self, arr: np.ndarray):
        self.file = tempfile.TemporaryFile()
        self.file.write(arr.tobytes())
        self.shape = arr.shape
        self.dtype = arr.dtype

    def value(self) -> np.ndarray:
        """Read the array back from disk. The result is a read-only view of the bytes read."""
        self.file.seek(0)
        return np.frombuffer(self.file.read(), dtype=self.dtype).reshape(self.shape)
class NodeOutputCache:
    def __init__(self):
        self._data: Dict[Tuple, List] = {}  # key -> cached output (arrays spilled to disk)
        self._bytes: Dict[Tuple, int] = {}  # key -> estimated size of the entry
        self._access_time: Dict[Tuple, float] = {}  # key -> time of last hit, for eviction
        CACHE_REGISTRY.append(self)

    @staticmethod
    def _args_to_key(args) -> Tuple:
        """Build a hashable cache key from a node's arguments."""
        key = []
        for arg in args:
            if isinstance(arg, (int, float, bool, str, bytes)):
                key.append(arg)
            elif arg is None:
                key.append(None)
            elif isinstance(arg, Enum):
                key.append(arg.value)
            elif isinstance(arg, np.ndarray):
                # Arrays are keyed by shape, dtype, and a content hash, so
                # equal arrays hit the same entry without the key itself
                # holding the data.
                key.append(tuple(arg.shape))
                key.append(arg.dtype.str)
                key.append(hashlib.sha256(arg.tobytes()).digest())
            elif hasattr(arg, "cache_key_func"):
                # Arbitrary objects can opt in by exposing a cache_key_func()
                # that returns a hashable key component.
                key.append(arg.__class__.__name__)
                key.append(arg.cache_key_func())
            else:
                raise RuntimeError(f"Unexpected argument type {arg.__class__.__name__}")
        return tuple(key)
    @staticmethod
    def _estimate_bytes(output) -> int:
        size = 0
        for out in output:
            if isinstance(out, np.ndarray):
                size += out.nbytes
            else:
                # Anything other than a numpy array is probably negligible, but
                # use an overestimate to handle pathological cases where someone
                # has a pipeline with a million math nodes.
                size += 1024  # 1 KiB
        return size
    def empty(self) -> bool:
        return len(self._data) == 0

    def oldest(self) -> Tuple[Tuple, float]:
        """Return the (key, access time) pair of the least recently used entry."""
        return min(self._access_time.items(), key=lambda x: x[1])

    def size(self) -> int:
        return sum(self._bytes.values())
    @staticmethod
    def _enforce_limits():
        """Evict the least recently used entries, across all caches, until the total size fits the budget."""
        while True:
            total_bytes = sum(cache.size() for cache in CACHE_REGISTRY)
            logger.debug(
                f"Cache size: {total_bytes} ({100 * total_bytes / CACHE_MAX_BYTES:0.1f}% of limit)"
            )
            if total_bytes <= CACHE_MAX_BYTES:
                return
            logger.debug("Dropping oldest cache key")
            oldest_keys = [
                (cache, cache.oldest()) for cache in CACHE_REGISTRY if not cache.empty()
            ]
            cache, (key, _) = min(oldest_keys, key=lambda x: x[1][1])
            cache.drop(key)
    @staticmethod
    def _write_arrays_to_disk(output: List) -> List:
        return [
            CachedNumpyArray(item) if isinstance(item, np.ndarray) else item
            for item in output
        ]

    @staticmethod
    def _read_arrays_from_disk(output: List) -> List:
        return [
            item.value() if isinstance(item, CachedNumpyArray) else item
            for item in output
        ]

    @staticmethod
    def _output_to_list(output) -> List:
        if isinstance(output, list):
            return output
        elif isinstance(output, tuple):
            return list(output)
        else:
            return [output]

    @staticmethod
    def _list_to_output(output: List):
        # Single-output nodes are cached as one-element lists; unwrap them.
        if len(output) == 1:
            return output[0]
        return output
    def get(self, args) -> Optional[List]:
        """Return the cached output for the given args, or None on a miss."""
        key = self._args_to_key(args)
        if key in self._data:
            logger.debug("Cache hit")
            self._access_time[key] = time.time()
            return self._list_to_output(self._read_arrays_from_disk(self._data[key]))
        logger.debug("Cache miss")
        return None
    def put(self, args, output):
        key = self._args_to_key(args)
        output_list = self._output_to_list(output)
        # Estimate the size from the list form, before the arrays are spilled
        # to disk, so that scalar outputs (which are not iterable) work too.
        self._bytes[key] = self._estimate_bytes(output_list)
        self._data[key] = self._write_arrays_to_disk(output_list)
        self._access_time[key] = time.time()
        self._enforce_limits()
    def drop(self, key):
        del self._data[key]
        del self._bytes[key]
        del self._access_time[key]
def cached(run):
    """Decorator that memoizes a node's `run` function in a NodeOutputCache."""
    cache = NodeOutputCache()

    @functools.wraps(run)
    def _run(*args):
        out = cache.get(args)
        if out is not None:
            return out
        # Note: a None result is indistinguishable from a cache miss, so
        # functions that return None are recomputed on every call.
        output = run(*args)
        cache.put(args, output)
        return output

    return _run
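
# A minimal usage sketch (illustrative only; `Model` and `upscale` are
# hypothetical names, not part of this module):
#
#     class Model:
#         def __init__(self, path: str):
#             self.path = path
#
#         def cache_key_func(self):
#             # Satisfies the opt-in key protocol checked in _args_to_key.
#             return self.path
#
#     @cached
#     def upscale(img: np.ndarray, model: Model, scale: int) -> np.ndarray:
#         ...
#
# The first upscale(img, model, 2) call runs the function; repeat calls with
# byte-identical arguments return the cached result, with arrays re-read from
# their temporary files.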