Spaces:
Runtime error
Runtime error
File size: 4,850 Bytes
c19ca42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import functools
import hashlib
import os
import tempfile
import time
from enum import Enum
from typing import Dict, List, Optional, Tuple
import numpy as np
from nodes.log import logger
# Upper bound, in bytes, on the combined estimated size of all registered caches.
# Read once at import time from the CACHE_MAX_BYTES environment variable.
CACHE_MAX_BYTES = int(os.environ.get("CACHE_MAX_BYTES", 1024**3)) # default 1 GiB
# Every NodeOutputCache appends itself to this list when constructed, so that the
# size limit can be enforced across all caches together rather than per cache.
CACHE_REGISTRY: List["NodeOutputCache"] = []
class CachedNumpyArray:
    """Disk-backed copy of a NumPy array.

    The raw bytes of the array are written to an anonymous temporary file;
    only the shape and dtype stay in memory. ``value()`` rebuilds an
    equivalent array from the file on demand.
    """

    def __init__(self, arr: np.ndarray):
        """Write *arr* to a temporary file and remember its shape and dtype."""
        self.file = tempfile.TemporaryFile()
        self.shape = arr.shape
        self.dtype = arr.dtype
        try:
            # tobytes() serialises in C order whatever the memory layout of
            # arr is, which is the order reshape() assumes in value().
            data = arr.tobytes()
            self.file.write(data)
        except BaseException:
            # Do not leave the temporary file open if it could not be filled.
            self.file.close()
            raise
        self._nbytes = len(data)

    def value(self) -> np.ndarray:
        """Return a new, writable array with the stored contents.

        Every call reads the file again and returns an independent array, so
        modifying the result never affects the cached data.
        """
        # Read into a mutable buffer: np.frombuffer() over an immutable bytes
        # object yields a read-only array, which would make cache hits behave
        # differently from freshly computed (writable) outputs.
        buffer = bytearray(self._nbytes)
        self.file.seek(0)
        self.file.readinto(buffer)
        return np.frombuffer(buffer, dtype=self.dtype).reshape(self.shape)
class NodeOutputCache:
    """Cache of node outputs keyed by the arguments the node was called with.

    NumPy arrays inside an output are moved to temporary files through
    ``CachedNumpyArray``; all other values are kept in memory unchanged. Each
    instance registers itself in ``CACHE_REGISTRY``. After every ``put`` the
    least recently accessed entries (across all registered caches) are
    dropped until the combined estimated size fits ``CACHE_MAX_BYTES``.
    """

    def __init__(self):
        self._data: Dict[Tuple, List] = {}
        self._bytes: Dict[Tuple, int] = {}
        self._access_time: Dict[Tuple, float] = {}
        # `list` or `tuple` when the stored output was such a container,
        # None when it was a single bare value.
        self._container: Dict[Tuple, Optional[type]] = {}
        CACHE_REGISTRY.append(self)

    @staticmethod
    def _args_to_key(args) -> Tuple:
        """Build a hashable key with one element per argument.

        Raises:
            RuntimeError: if an argument is of a type that cannot be keyed.
        """
        key = []
        for arg in args:
            if isinstance(arg, (int, float, bool, str, bytes)):
                # 1, 1.0 and True compare and hash equal, so the bare value
                # alone would let them share a cache entry; tag it by type.
                key.append((type(arg).__name__, arg))
            elif arg is None:
                key.append(None)
            elif isinstance(arg, Enum):
                # Tagged with the class so equal values of different enums
                # (or a plain value equal to the member value) stay distinct.
                key.append((type(arg).__name__, arg.value))
            elif isinstance(arg, np.ndarray):
                key.append(
                    (
                        "ndarray",
                        tuple(arg.shape),
                        arg.dtype.str,
                        hashlib.sha256(arg.tobytes()).digest(),
                    )
                )
            elif hasattr(arg, "cache_key_func"):
                key.append((arg.__class__.__name__, arg.cache_key_func()))
            else:
                raise RuntimeError(f"Unexpected argument type {arg.__class__.__name__}")
        return tuple(key)

    @staticmethod
    def _estimate_bytes(output) -> int:
        """Estimate the size in bytes of an iterable of output values."""
        size = 0
        for out in output:
            if isinstance(out, np.ndarray):
                size += out.nbytes
            else:
                # any other type but numpy arrays is probably negligible, but here's an overestimate to handle
                # pathological cases where someone has a pipeline with a million math nodes
                size += 1024  # 1 KiB
        return size

    def empty(self):
        """Return True when the cache holds no entries."""
        return not self._data

    def oldest(self) -> Tuple[Tuple, float]:
        """Return ``(key, access_time)`` of the least recently accessed entry.

        Raises ``ValueError`` when the cache is empty.
        """
        return min(self._access_time.items(), key=lambda x: x[1])

    def size(self):
        """Return the summed size estimate of all entries in this cache."""
        return sum(self._bytes.values())

    @staticmethod
    def _enforce_limits():
        """Drop least recently accessed entries until the global limit is met."""
        while True:
            total_bytes = sum(cache.size() for cache in CACHE_REGISTRY)
            if CACHE_MAX_BYTES > 0:
                logger.debug(
                    f"Cache size: {total_bytes} ({100*total_bytes/CACHE_MAX_BYTES:0.1f}% of limit)"
                )
            else:
                # A zero limit would otherwise divide by zero in the message.
                logger.debug(f"Cache size: {total_bytes} (limit: {CACHE_MAX_BYTES})")
            if total_bytes <= CACHE_MAX_BYTES:
                return
            oldest_keys = [
                (cache, cache.oldest()) for cache in CACHE_REGISTRY if not cache.empty()
            ]
            if not oldest_keys:
                # Nothing left to evict (only reachable with a negative limit).
                return
            logger.debug("Dropping oldest cache key")
            cache, (key, _) = min(oldest_keys, key=lambda x: x[1][1])
            cache.drop(key)

    @staticmethod
    def _write_arrays_to_disk(output: List) -> List:
        """Return a copy of *output* with every array wrapped in CachedNumpyArray."""
        return [
            CachedNumpyArray(item) if isinstance(item, np.ndarray) else item
            for item in output
        ]

    @staticmethod
    def _read_arrays_from_disk(output: List) -> List:
        """Return a copy of *output* with every CachedNumpyArray read back."""
        return [
            item.value() if isinstance(item, CachedNumpyArray) else item
            for item in output
        ]

    @staticmethod
    def _output_to_list(output) -> List:
        """Normalise a node output (bare value, list or tuple) to a list."""
        if isinstance(output, list):
            return output
        if isinstance(output, tuple):
            return list(output)
        return [output]

    @staticmethod
    def _list_to_output(output: List, container: Optional[type] = None):
        """Turn a stored list back into a node output.

        With ``container`` set to ``list`` or ``tuple`` the items are rebuilt
        into that container. Without it, a one-element list is unwrapped to
        its only item and any other list is returned unchanged.
        """
        if container is not None:
            return container(output)
        if len(output) == 1:
            return output[0]
        return output

    def get(self, args) -> Optional[List]:
        """Return the output cached for *args*, or None when there is none.

        A hit refreshes the entry's access time and returns the output in the
        same form (bare value, list or tuple) it was stored in.
        """
        key = self._args_to_key(args)
        if key in self._data:
            logger.debug("Cache hit")
            self._access_time[key] = time.time()
            items = self._read_arrays_from_disk(self._data[key])
            return self._list_to_output(items, self._container.get(key))
        logger.debug("Cache miss")
        return None

    def put(self, args, output):
        """Store *output* under *args*, then enforce the global size limit."""
        key = self._args_to_key(args)
        if isinstance(output, list):
            container = list
        elif isinstance(output, tuple):
            container = tuple
        else:
            container = None
        items = self._output_to_list(output)
        self._data[key] = self._write_arrays_to_disk(items)
        # Estimate from the normalised list: iterating a bare output would
        # fail for scalars and walk the elements of a single array.
        self._bytes[key] = self._estimate_bytes(items)
        self._container[key] = container
        self._access_time[key] = time.time()
        self._enforce_limits()

    def drop(self, key):
        """Remove the entry stored under *key*; raises KeyError if absent."""
        del self._data[key]
        del self._bytes[key]
        del self._access_time[key]
        self._container.pop(key, None)
def cached(run):
    """Decorator that memoises *run* in its own ``NodeOutputCache``.

    The wrapper accepts positional arguments only. A ``None`` coming back
    from the cache is treated as a miss, in which case *run* is executed and
    its result is stored before being returned.
    """
    store = NodeOutputCache()

    @functools.wraps(run)
    def wrapper(*args):
        hit = store.get(args)
        if hit is None:
            fresh = run(*args)
            store.put(args, fresh)
            return fresh
        return hit

    return wrapper
|