# Hugging Face Space source view (Space status: Runtime error)
# File size: 952 Bytes, commit c19ca42
# Safe unpickler to prevent arbitrary code execution
import pickle
from types import SimpleNamespace
# Allow-list of (module, name) global lookups that RestrictedUnpickler will
# resolve; any pair not listed here is rejected during unpickling.
safe_list = {
    ("collections", "OrderedDict"),
    ("typing", "OrderedDict"),
    ("torch._utils", "_rebuild_tensor_v2"),
    # Typed tensor storage classes, all exposed directly on the ``torch`` module.
    *(
        ("torch", storage_cls_name)
        for storage_cls_name in (
            "BFloat16Storage",
            "FloatStorage",
            "HalfStorage",
            "IntStorage",
            "LongStorage",
            "DoubleStorage",
        )
    ),
}
class RestrictedUnpickler(pickle.Unpickler):
    """``pickle.Unpickler`` that refuses every global not listed in ``safe_list``.

    Every global (class or function) a pickle references is resolved through
    :meth:`find_class`. Restricting that hook to an explicit allow-list is the
    technique described in the stdlib ``pickle`` docs ("Restricting Globals").
    """

    def find_class(self, module, name):
        """Return ``module.name`` only when ``(module, name)`` is allow-listed.

        Raises:
            pickle.UnpicklingError: if the pair is not in ``safe_list``. The
                check happens before the base class is asked to import
                ``module``.
        """
        # Only allow required classes to load state dict
        if (module, name) in safe_list:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"Global '{module}.{name}' is forbidden")
def _restricted_load(*args, **kwargs):
    """Drop-in for ``pickle.load`` that unpickles through ``RestrictedUnpickler``.

    All arguments are forwarded unchanged to ``RestrictedUnpickler`` (the
    readable binary file first, then any ``pickle.Unpickler`` keyword options).
    """
    return RestrictedUnpickler(*args, **kwargs).load()


def _restricted_loads(data, **kwargs):
    """Drop-in for ``pickle.loads`` that unpickles through ``RestrictedUnpickler``.

    ``data`` is a bytes-like pickle payload. Keyword options are forwarded to
    ``RestrictedUnpickler`` exactly as ``pickle.loads`` would forward them.
    """
    import io  # local import keeps this helper self-contained

    return RestrictedUnpickler(io.BytesIO(data), **kwargs).load()


# Module-like stand-in for the stdlib ``pickle`` that only exposes the
# restricted loading API (``Unpickler``, ``load``, ``loads``) plus the
# exception type ``RestrictedUnpickler`` raises, so callers can catch it via
# this object.
# NOTE(review): presumably intended to be passed as a ``pickle_module``
# argument (e.g. to ``torch.load``); ``__name__="pickle"`` appears to make
# consumers that inspect the module name treat it as stdlib pickle — confirm.
RestrictedUnpickle = SimpleNamespace(
    Unpickler=RestrictedUnpickler,
    UnpicklingError=pickle.UnpicklingError,
    __name__="pickle",
    load=_restricted_load,
    loads=_restricted_loads,
)
|