Spaces:
Runtime error
Runtime error
# This module defines an interface.
# It is important that it does not contain types that depend on ONNX.
from typing import Union | |
import re2 | |
# Shared RE2 options for scanning raw model bytes: ``dot_nl`` lets "." match
# newline bytes as well, and the LATIN-1 encoding makes RE2 read the input
# one byte per character instead of decoding it as UTF-8.
re2_options = re2.Options()
re2_options.dot_nl = True
re2_options.encoding = re2.Options.Encoding.LATIN1

# Byte signatures of the U2-Net model variants. Each pattern requires the
# listed names to appear in this order, separated by at least one arbitrary
# byte. Which slice of the model is searched is decided by the callers.
U2NET_STANDARD = re2.compile(
    b"1959.+1960.+1961.+1962.+1963.+1964.+1965",
    re2_options,
)
U2NET_CLOTH = re2.compile(
    b"output.+d1.+Concat_1876.+Concat_1896.+Concat_1916.+Concat_1936.+Concat_1956",
    re2_options,
)
U2NET_SILUETA = re2.compile(
    b"1808.+1827.+1828.+2296.+1831.+1850.+1958",
    re2_options,
)
U2NET_ISNET = re2.compile(
    b"/stage1/rebnconvin/conv_s1/Conv.+/stage1/rebnconvin/relu_s1/Relu",
    re2_options,
)
class OnnxGeneric:
    """ONNX model that matched none of the known rembg signatures.

    Args:
        model_as_bytes: the serialized ONNX model, stored as ``bytes``.

    ``sub_type`` is ``"Generic"``; ``scale_height`` and ``scale_width`` are
    left as ``None``.
    """

    def __init__(self, model_as_bytes: bytes):
        self.bytes: bytes = model_as_bytes
        self.sub_type = "Generic"
        # Chained assignment binds scale_height first, then scale_width.
        self.scale_height = self.scale_width = None
class OnnxRemBg:
    """ONNX model recognized as a rembg (U2-Net family) model.

    Args:
        model_as_bytes: the serialized ONNX model, stored as ``bytes``.
        scale_height: stored as ``scale_height``. ``load_onnx_model`` passes
            3 for models matching the U2NET_CLOTH signature and keeps the
            default of 1 otherwise.

    ``sub_type`` is ``"RemBg"`` and ``scale_width`` is always 1.
    """

    def __init__(self, model_as_bytes: bytes, scale_height: int = 1):
        self.bytes: bytes = model_as_bytes
        self.sub_type = "RemBg"
        # Tuple assignment binds left to right: scale_height, then scale_width.
        self.scale_height, self.scale_width = scale_height, 1
# Type alias for annotations: any model object produced by load_onnx_model.
OnnxModel = Union[OnnxGeneric, OnnxRemBg]
# The same concrete classes as a tuple, e.g. for isinstance() checks.
OnnxModels = (OnnxGeneric, OnnxRemBg)
def is_rembg_model(model_as_bytes: bytes) -> bool:
    """Return whether the serialized ONNX model matches a rembg signature.

    Each U2-Net variant is recognized by searching its byte signature in a
    fixed slice of the model:

    - standard and cloth: the last 1000 bytes;
    - silueta: the last 600 bytes;
    - ISNet: the first 10000 bytes.

    These are the same slices ``load_onnx_model`` uses, so this function
    returns True exactly when ``load_onnx_model`` would produce an
    ``OnnxRemBg``.

    Args:
        model_as_bytes: the serialized ONNX model.

    Returns:
        True if any rembg signature was found, otherwise False.
    """
    # Search the standard signature in the same 1000-byte tail that
    # load_onnx_model uses. A narrower 600-byte window would report False for
    # models that load_onnx_model still loads as OnnxRemBg.
    tail = model_as_bytes[-1000:]
    return (
        U2NET_STANDARD.search(tail) is not None
        or U2NET_CLOTH.search(tail) is not None
        or U2NET_SILUETA.search(model_as_bytes[-600:]) is not None
        or U2NET_ISNET.search(model_as_bytes[:10000]) is not None
    )
def load_onnx_model(model_as_bytes: bytes) -> OnnxModel:
    """Wrap serialized ONNX model bytes in the matching model class.

    Signatures are checked in this order:

    1. Standard U2-Net (last 1000 bytes), silueta (last 600 bytes) or ISNet
       (first 10000 bytes): an ``OnnxRemBg`` with the default scale.
    2. Cloth (last 1000 bytes): an ``OnnxRemBg`` with ``scale_height=3``.
    3. Anything else: an ``OnnxGeneric``.

    Args:
        model_as_bytes: the serialized ONNX model.

    Returns:
        The wrapped model; its ``bytes`` attribute is ``model_as_bytes``.
    """
    # (pattern, slice of the model to search) for the default-scale rembg
    # variants. any() stops at the first match, so each slice is taken only
    # when its pattern is actually tried.
    default_scale_signatures = (
        (U2NET_STANDARD, slice(-1000, None)),
        (U2NET_SILUETA, slice(-600, None)),
        (U2NET_ISNET, slice(None, 10000)),
    )
    if any(
        pattern.search(model_as_bytes[window]) is not None
        for pattern, window in default_scale_signatures
    ):
        return OnnxRemBg(model_as_bytes)

    if U2NET_CLOTH.search(model_as_bytes[-1000:]) is not None:
        return OnnxRemBg(model_as_bytes, scale_height=3)

    return OnnxGeneric(model_as_bytes)