|
""" |
|
Data Cache Utils |
|
|
|
Author: Xiaoyang Wu ([email protected]) |
|
Please cite our work if the code is helpful to you. |
|
""" |
|
|
|
import os |
|
import SharedArray |
|
|
|
try: |
|
from multiprocessing.shared_memory import ShareableList |
|
except ImportError: |
|
import warnings |
|
|
|
warnings.warn("Please update python version >= 3.8 to enable shared_memory") |
|
import numpy as np |
|
|
|
|
|
def shared_array(name, var=None):
    """Publish a numpy array as a SharedArray segment, or load one.

    The segment is addressed as ``shm://{name}``.

    Args:
        name: Segment name, without the ``shm://`` prefix.
        var: Array to publish, or ``None`` to load.
            - If a segment ``name`` does not exist yet, a new segment with
              ``var``'s shape and dtype is created. It is filled with
              ``var`` and marked read-only.
            - If the segment already exists, it is attached and returned
              as-is, and ``var`` is ignored.
            - If ``None``, the existing segment is attached and a private
              copy is returned.

    Returns:
        The shared-memory-backed array when publishing. A private copy of
        the segment's contents when loading.
    """
    if var is None:
        # Reader path: hand back a copy that is independent of the segment.
        return SharedArray.attach(f"shm://{name}").copy()

    # Fast path: the segment was already published.
    if os.path.exists(f"/dev/shm/{name}"):
        return SharedArray.attach(f"shm://{name}")

    try:
        data = SharedArray.create(f"shm://{name}", var.shape, dtype=var.dtype)
    except FileExistsError:
        # Another process created the segment between the check above and
        # this call. Or this system has no /dev/shm, so the check above
        # cannot see existing segments. Either way, reuse the existing
        # segment instead of failing.
        return SharedArray.attach(f"shm://{name}")

    # NOTE(review): a concurrent reader attaching before this fill completes
    # may observe an incompletely filled array.
    data[...] = var[...]
    data.flags.writeable = False
    return data
|
|
|
|
|
def shared_dict(name, var=None):
    """Publish a dict of numpy arrays to shared memory, or load one back.

    The key list is stored in a ``ShareableList`` named ``"<name>.keys"``.
    Each array value is stored through ``shared_array`` under
    ``"<name>.<key>"``. Requires ``multiprocessing.shared_memory``
    (Python >= 3.8).

    Args:
        name: Dict name. Converted with ``str()``; must not contain ``"."``.
        var: Dict to publish, or ``None`` to load.
            - When a dict is given, only ``np.ndarray`` values are shared;
              all other entries are dropped. If ``"<name>.keys"`` already
              exists, it is reused rather than recreated.
            - When ``None``, the previously published key list is read and
              each array is loaded.

    Returns:
        A dict mapping each shared key to the array returned by
        ``shared_array``.

    Raises:
        AssertionError: If ``name`` contains ``"."`` or ``var`` is neither
            ``None`` nor a dict.
    """
    name = str(name)
    # "." separates the dict name from the "<name>.keys" / "<name>.<key>" suffixes.
    assert "." not in name
    keys_name = name + ".keys"

    if var is None:
        keys = list(ShareableList(name=keys_name))
        return {key: shared_array(name=f"{name}.{key}") for key in keys}

    assert isinstance(var, dict)
    # Only numpy arrays are shared. Keys must also be ShareableList-storable
    # (e.g. str/int) — presumably str in practice; verify against callers.
    keys = [key for key, value in var.items() if isinstance(value, np.ndarray)]

    try:
        # NOTE(review): the ShareableList handle is discarded immediately. On
        # POSIX, multiprocessing's resource tracker may unlink segments
        # created here when this process exits — confirm that is acceptable.
        ShareableList(sequence=keys, name=keys_name)
    except FileExistsError:
        # Already published (e.g. by another worker). This mirrors
        # shared_array, which attaches to existing segments instead of failing.
        pass

    return {key: shared_array(name=f"{name}.{key}", var=var[key]) for key in keys}
|
|