Spaces:
Runtime error
Runtime error
from typing import Type | |
import onnxruntime as ort | |
from ...impl.onnx.utils import get_input_shape | |
from .session_base import BaseSession | |
from .session_cloth import ClothSession | |
from .session_simple import SimpleSession | |
def new_session(session: ort.InferenceSession) -> BaseSession: | |
session_class: Type[BaseSession] | |
input_width = get_input_shape(session)[2] | |
# Using size to determine session type and norm parameters is fragile, | |
# but at the moment I don't know a better way to detect architecture due | |
# to the lack of consistency in naming and outputs across arches and repos. | |
# It works right now because of the limited number of models supported, | |
# but if that expands, it may become necessary to find an alternative. | |
mean = (0.485, 0.456, 0.406) | |
std = (0.229, 0.224, 0.225) | |
size = (input_width, input_width) if input_width is not None else (320, 320) | |
if input_width == 768: # U2NET cloth model | |
session_class = ClothSession | |
mean = (0.5, 0.5, 0.5) | |
std = (0.5, 0.5, 0.5) | |
else: | |
session_class = SimpleSession | |
if input_width == 1024: # ISNET | |
mean = (0.5, 0.5, 0.5) | |
std = (1, 1, 1) | |
elif input_width == 512: # Models trained using anime-segmentation repo | |
mean = (0, 0, 0) | |
std = (1, 1, 1) | |
return session_class(session, mean, std, size) | |