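"""Disk-backed, memory-bounded caching of node outputs.

Cached numpy arrays are spilled to anonymous temporary files so that cached
results do not keep large buffers alive in process memory. Every cache
registers itself in a global registry, and a shared byte budget
(CACHE_MAX_BYTES, 1 GiB by default) is enforced by evicting the least
recently used entry across all caches until the total fits.
"""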
import functools
import hashlib
import os
import tempfile
import time
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from nodes.log import logger

CACHE_MAX_BYTES = int(os.environ.get("CACHE_MAX_BYTES", 1024**3))  # default 1 GiB
CACHE_REGISTRY: List["NodeOutputCache"] = []


class CachedNumpyArray:
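    """A numpy array whose contents are stored in an anonymous temporary file.

    The file is deleted automatically once this object is garbage collected.
    """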
    def __init__(self, arr: np.ndarray):
        self.file = tempfile.TemporaryFile()
        self.file.write(arr.tobytes())

        self.shape = arr.shape
        self.dtype = arr.dtype

    def value(self) -> np.ndarray:
        self.file.seek(0)
        # frombuffer wraps the immutable bytes object, so the returned array is
        # read-only; callers that need to mutate it must copy it first.
        return np.frombuffer(self.file.read(), dtype=self.dtype).reshape(self.shape)


class NodeOutputCache:
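    """Maps a key derived from a node's arguments to that node's outputs.

    Alongside the cached data, each entry's byte size and last access time are
    tracked so that _enforce_limits can evict LRU entries across all caches.
    """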
    def __init__(self):
        self._data: Dict[Tuple, List] = {}
        self._bytes: Dict[Tuple, int] = {}
        self._access_time: Dict[Tuple, float] = {}

        CACHE_REGISTRY.append(self)

    @staticmethod
    def _args_to_key(args) -> Tuple:
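        """Build a hashable cache key from a node's positional arguments.

        Arrays are keyed by shape, dtype, and a SHA-256 digest of their bytes;
        other arguments must be primitives, None, enums, or objects exposing a
        cache_key_func() method.
        """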
        key = []
        for arg in args:
            if isinstance(arg, (int, float, bool, str, bytes)):
                key.append(arg)
            elif arg is None:
                key.append(None)
            elif isinstance(arg, Enum):
                key.append(arg.value)
            elif isinstance(arg, np.ndarray):
                key.append(tuple(arg.shape))
                key.append(arg.dtype.str)
                key.append(hashlib.sha256(arg.tobytes()).digest())
            elif hasattr(arg, "cache_key_func"):
                key.append(arg.__class__.__name__)
                key.append(arg.cache_key_func())
            else:
                raise RuntimeError(f"Unexpected argument type {arg.__class__.__name__}")
        return tuple(key)

    @staticmethod
    def _estimate_bytes(output) -> int:
        size = 0
        for out in output:
            if isinstance(out, np.ndarray):
                size += out.nbytes
            else:
                # Anything that is not a numpy array is probably negligible, but
                # use a generous overestimate to handle pathological cases where a
                # pipeline contains a million tiny outputs (e.g. math nodes)
                size += 1024  # 1 KiB
        return size

    def empty(self):
        return len(self._data) == 0

    def oldest(self) -> Tuple[Tuple, float]:
        return min(self._access_time.items(), key=lambda x: x[1])

    def size(self):
        return sum(self._bytes.values())

    @staticmethod
    def _enforce_limits():
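        """Evict the least recently used entry across all registered caches,
        one at a time, until their combined size fits within CACHE_MAX_BYTES."""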
        while True:
            total_bytes = sum(cache.size() for cache in CACHE_REGISTRY)
            logger.debug(
                f"Cache size: {total_bytes} ({100*total_bytes/CACHE_MAX_BYTES:0.1f}% of limit)"
            )
            if total_bytes <= CACHE_MAX_BYTES:
                return
            logger.debug("Dropping oldest cache key")

            oldest_keys = [
                (cache, cache.oldest()) for cache in CACHE_REGISTRY if not cache.empty()
            ]

            cache, (key, _) = min(oldest_keys, key=lambda x: x[1][1])
            cache.drop(key)

    @staticmethod
    def _write_arrays_to_disk(output: List) -> List:
        return [
            CachedNumpyArray(item) if isinstance(item, np.ndarray) else item
            for item in output
        ]

    @staticmethod
    def _read_arrays_from_disk(output: List) -> List:
        return [
            item.value() if isinstance(item, CachedNumpyArray) else item
            for item in output
        ]

    @staticmethod
    def _output_to_list(output) -> List:
        if isinstance(output, list):
            return output
        elif isinstance(output, tuple):
            return list(output)
        else:
            return [output]

    @staticmethod
    def _list_to_output(output: List):
        # A single-element list is unwrapped to a bare value; this assumes
        # nodes return either one bare value or multiple outputs, so a
        # one-element list or tuple will not round-trip through the cache.
        if len(output) == 1:
            return output[0]
        return output

    def get(self, args) -> Optional[Any]:
        key = self._args_to_key(args)
        if key in self._data:
            logger.debug("Cache hit")
            self._access_time[key] = time.time()
            return self._list_to_output(self._read_arrays_from_disk(self._data[key]))
        logger.debug("Cache miss")
        return None

    def put(self, args, output):
        key = self._args_to_key(args)
        output_list = self._output_to_list(output)
        self._data[key] = self._write_arrays_to_disk(output_list)
        # Estimate the size of the normalized list: passing a bare scalar output
        # straight to _estimate_bytes would fail, since it iterates its argument.
        self._bytes[key] = self._estimate_bytes(output_list)
        self._access_time[key] = time.time()
        self._enforce_limits()

    def drop(self, key):
        del self._data[key]
        del self._bytes[key]
        del self._access_time[key]


def cached(run):
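    """Memoize a node's run function in its own NodeOutputCache instance."""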
    cache = NodeOutputCache()

    @functools.wraps(run)
    def _run(*args):
        out = cache.get(args)
        if out is not None:
            return out
        # Note: a node whose run function legitimately returns None is
        # indistinguishable from a cache miss above, so it is re-run every time.
        output = run(*args)
        cache.put(args, output)
        return output

    return _run
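

# Example usage (a minimal sketch; `blur` and its parameters are hypothetical,
# not part of this module):
#
#     @cached
#     def blur(img: np.ndarray, radius: float) -> np.ndarray:
#         ...  # some expensive computation
#
# The first call with a given (img, radius) pair runs the function and caches
# the result; later calls with byte-identical arguments read the cached array
# back from its temporary file instead of recomputing.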