Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import subprocess | |
import tempfile | |
class PlasmaArray(object):
    """
    Wrapper around numpy arrays that automatically moves the data to shared
    memory upon serialization. This is particularly helpful when passing numpy
    arrays through multiprocessing, so that data is not unnecessarily
    duplicated or pickled.

    Arrays below ``_MIN_SHARED_NBYTES`` bytes, or environments where
    ``pyarrow.plasma`` cannot be imported, fall back to ordinary pickling of
    the wrapped array.
    """

    # Arrays smaller than this many bytes (128MB) are pickled normally.
    _MIN_SHARED_NBYTES = 134217728

    def __init__(self, array):
        """Wrap *array*; only its ``nbytes`` attribute is read here."""
        super().__init__()
        self.array = array
        self.disable = array.nbytes < self._MIN_SHARED_NBYTES
        self.object_id = None
        self.path = None

        # variables with underscores shouldn't be pickled
        self._client = None
        self._server = None
        self._server_tmp = None
        self._plasma = None

    @property
    def plasma(self):
        """Return the ``pyarrow.plasma`` module, or None if unavailable.

        The import is attempted lazily and only when shared memory is not
        disabled for this array. It must be a property: the rest of the class
        tests ``self.plasma is None`` and calls ``self.plasma.connect``.
        """
        if self._plasma is None and not self.disable:
            try:
                import pyarrow.plasma as plasma

                self._plasma = plasma
            except ImportError:
                self._plasma = None
        return self._plasma

    def start_server(self):
        """Launch a ``plasma_store`` subprocess for this array, once.

        Does nothing when plasma is unavailable or a server is already
        running. The store is given 1.05x the array's byte size via ``-m`` and
        the path of a fresh named temporary file via ``-s``.

        Raises:
            OSError: if the ``plasma_store`` executable cannot be started; the
                temporary file is released and ``path`` reset before
                re-raising so a later call can retry.
        """
        if self.plasma is None or self._server is not None:
            return
        assert self.object_id is None
        assert self.path is None
        self._server_tmp = tempfile.NamedTemporaryFile()
        self.path = self._server_tmp.name
        try:
            self._server = subprocess.Popen(
                [
                    "plasma_store",
                    "-m",
                    str(int(1.05 * self.array.nbytes)),
                    "-s",
                    self.path,
                ]
            )
        except OSError:
            # Undo the partial setup so the instance is not left half-started.
            self._server_tmp.close()
            self._server_tmp = None
            self.path = None
            raise

    @property
    def client(self):
        """Return a plasma client connected to ``self.path`` (cached)."""
        if self._client is None:
            assert self.path is not None
            self._client = self.plasma.connect(self.path, num_retries=200)
        return self._client

    def __getstate__(self):
        """Return the picklable state.

        Without plasma the instance ``__dict__`` (including the array) is
        returned as-is. With plasma, the array is put into the store on first
        use and the state carries ``object_id``/``path`` instead of the array;
        the process-local handles are blanked out.
        """
        if self.plasma is None:
            return self.__dict__
        if self.object_id is None:
            self.start_server()
            self.object_id = self.client.put(self.array)
        state = self.__dict__.copy()
        del state["array"]
        state["_client"] = None
        state["_server"] = None
        state["_server_tmp"] = None
        state["_plasma"] = None
        return state

    def __setstate__(self, state):
        """Restore from *state*, fetching the array from plasma if enabled."""
        self.__dict__.update(state)
        if self.plasma is None:
            return
        self.array = self.client.get(self.object_id)

    def __del__(self):
        """Kill the owned server process and close its temporary file."""
        # getattr: __del__ also runs on instances whose __init__ did not
        # complete, where these attributes may not exist yet.
        server = getattr(self, "_server", None)
        if server is not None:
            server.kill()
            self._server = None
        server_tmp = getattr(self, "_server_tmp", None)
        if server_tmp is not None:
            server_tmp.close()
            self._server_tmp = None