Spaces:
Sleeping
Sleeping
File size: 2,728 Bytes
6fc683c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import subprocess
import tempfile
class PlasmaArray(object):
    """
    Wrapper around numpy arrays that automatically moves the data to shared
    memory upon serialization. This is particularly helpful when passing numpy
    arrays through multiprocessing, so that data is not unnecessarily
    duplicated or pickled.

    Arrays smaller than 128MB, or environments where ``pyarrow.plasma`` cannot
    be imported, fall back to regular pickling of the wrapped array.
    """

    # Arrays below this many bytes (128MB) are pickled normally.
    _MIN_PLASMA_BYTES = 134217728

    def __init__(self, array):
        """Wrap *array*; only its ``nbytes`` attribute is read here."""
        super().__init__()
        self.array = array
        self.disable = array.nbytes < self._MIN_PLASMA_BYTES
        self.object_id = None
        self.path = None

        # variables with underscores shouldn't be pickled
        self._client = None
        self._server = None
        self._server_tmp = None
        self._plasma = None

    @property
    def plasma(self):
        """Return the ``pyarrow.plasma`` module, or None if disabled/unavailable."""
        if self._plasma is None and not self.disable:
            try:
                import pyarrow.plasma as plasma

                self._plasma = plasma
            except ImportError:
                self._plasma = None
        return self._plasma

    def start_server(self):
        """Launch a ``plasma_store`` subprocess for this array.

        Does nothing when plasma is unavailable or a server was already
        started by this instance. If the subprocess cannot be launched, the
        exception propagates and the instance is left unchanged.
        """
        if self.plasma is None or self._server is not None:
            return
        assert self.object_id is None
        assert self.path is None
        server_tmp = tempfile.NamedTemporaryFile()
        try:
            server = subprocess.Popen(
                [
                    "plasma_store",
                    "-m",
                    # 5% headroom over the raw array size
                    str(int(1.05 * self.array.nbytes)),
                    "-s",
                    server_tmp.name,
                ]
            )
        except Exception:
            # Launch failed: release the temp file rather than recording a
            # path to a store that does not exist.
            server_tmp.close()
            raise
        # Commit state only once the server process actually exists.
        self._server_tmp = server_tmp
        self.path = server_tmp.name
        self._server = server

    @property
    def client(self):
        """Return a lazily-created plasma client connected to ``self.path``."""
        if self._client is None:
            assert self.path is not None
            self._client = self.plasma.connect(self.path, num_retries=200)
        return self._client

    def __getstate__(self):
        """Return the pickled state.

        Without plasma the full ``__dict__`` (array included) is returned.
        With plasma the array is put into the store on first use and the
        state carries ``object_id``/``path`` instead of the array itself.
        """
        if self.plasma is None:
            return self.__dict__
        if self.object_id is None:
            self.start_server()
            self.object_id = self.client.put(self.array)
        state = self.__dict__.copy()
        del state["array"]
        # Process-local handles are dropped from the pickled state.
        state["_client"] = None
        state["_server"] = None
        state["_server_tmp"] = None
        state["_plasma"] = None
        return state

    def __setstate__(self, state):
        """Restore state; with plasma, fetch the array back from the store."""
        self.__dict__.update(state)
        if self.plasma is None:
            return
        self.array = self.client.get(self.object_id)

    def __del__(self):
        """Stop the plasma server started by this instance, if any."""
        # getattr: __del__ also runs on instances whose __init__ never
        # completed, where these attributes do not exist.
        server = getattr(self, "_server", None)
        if server is not None:
            server.kill()
            # Reap the killed child so it does not linger as a zombie.
            server.wait()
            self._server = None
        server_tmp = getattr(self, "_server_tmp", None)
        if server_tmp is not None:
            server_tmp.close()
            self._server_tmp = None
|