File size: 1,601 Bytes
57746f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""
Data Cache Utils

Author: Xiaoyang Wu ([email protected])
Please cite our work if the code is helpful to you.
"""

import os
import SharedArray

try:
    from multiprocessing.shared_memory import ShareableList
except ImportError:
    import warnings

    warnings.warn("Please update python version >= 3.8 to enable shared_memory")
import numpy as np


def shared_array(name, var=None):
    """Create, attach to, or read back a SharedArray segment called ``name``.

    Args:
        name: Segment identifier; used as ``shm://{name}`` for SharedArray and
            looked up as ``/dev/shm/{name}`` for the existence check.
        var: Array to publish. When ``None`` the existing segment is read.

    Returns:
        * ``var is None``: a private copy of the attached segment.
        * ``var`` given and ``/dev/shm/{name}`` already exists: the attached
          segment itself (no copy is made and ``var`` is not written to it).
        * ``var`` given and no segment yet: the newly created segment, filled
          from ``var`` and flagged non-writeable.
    """
    uri = f"shm://{name}"

    # Read path: hand the caller its own copy of the cached array.
    if var is None:
        return SharedArray.attach(uri).copy()

    # Publish path, segment already present: reuse it as-is.
    if os.path.exists(f"/dev/shm/{name}"):
        return SharedArray.attach(uri)

    # Publish path, first time: allocate, fill, then lock against writes
    # through this handle.
    segment = SharedArray.create(uri, var.shape, dtype=var.dtype)
    segment[...] = var[...]
    segment.flags.writeable = False
    return segment


def shared_dict(name, var=None):
    """Publish or load a dict of NumPy arrays through named shared memory.

    The cached keys are stored in a ``ShareableList`` named ``"<name>.keys"``
    and every array is delegated to ``shared_array`` under ``"<name>.<key>"``.

    Args:
        name: Cache identifier. It is converted with ``str`` and must not
            contain ``"."`` because that character separates the cache name
            from the key in the segment names.
        var: Dict to publish. Only entries whose value is an ``np.ndarray``
            are cached; all other entries are skipped. When ``None``, the
            cache previously published under ``name`` is loaded instead.

    Returns:
        A dict mapping each cached key to whatever ``shared_array`` returns
        for it (empty if nothing is cacheable).

    Raises:
        AssertionError: if ``name`` contains ``"."`` or ``var`` is neither
            ``None`` nor a dict.
        FileNotFoundError: when loading and no key list exists for ``name``.
    """
    name = str(name)
    assert "." not in name  # '.' is used as sep flag
    keys_name = name + ".keys"

    # Load path: recover the key list, then fetch every array by its key.
    if var is None:
        keys = list(ShareableList(name=keys_name))
        return {key: shared_array(name=f"{name}.{key}") for key in keys}

    assert isinstance(var, dict)
    # current version only cache np.array
    keys = [key for key, value in var.items() if isinstance(value, np.ndarray)]

    try:
        ShareableList(sequence=keys, name=keys_name)
    except FileExistsError:
        # The key list was already published (e.g. by another worker or an
        # earlier call). shared_array tolerates existing segments, so do the
        # same here instead of failing; the existing list is left untouched
        # and is assumed to describe the same dict.
        pass

    return {key: shared_array(name=f"{name}.{key}", var=var[key]) for key in keys}