File size: 1,386 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from typing import Type

import onnxruntime as ort

from ...impl.onnx.utils import get_input_shape
from .session_base import BaseSession
from .session_cloth import ClothSession
from .session_simple import SimpleSession


def new_session(session: ort.InferenceSession) -> BaseSession:
    """Wrap *session* in the session class that matches its model family.

    The family is guessed from the input width, taken as element 2 of
    ``get_input_shape(session)``. That width picks both the wrapper class
    and the mean/std normalization values passed to it:

    * 768  -> ``ClothSession`` (U2NET cloth model)
    * 1024 -> ``SimpleSession`` (ISNET)
    * 512  -> ``SimpleSession`` (models trained with the
      anime-segmentation repo)
    * anything else -> ``SimpleSession`` with the default mean/std

    The size handed to the wrapper is ``(width, width)``, or ``(320, 320)``
    when the width is ``None``.
    """
    width = get_input_shape(session)[2]

    if width is None:
        size = (320, 320)
    else:
        size = (width, width)

    # Keying the architecture off the input size is fragile, but naming and
    # outputs are too inconsistent across arches and repos to offer a better
    # signal at the moment. It holds up only because few models are
    # supported; if that list grows, another detection method may be needed.
    if width == 768:  # U2NET cloth model
        return ClothSession(session, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), size)

    if width == 1024:  # ISNET
        return SimpleSession(session, (0.5, 0.5, 0.5), (1, 1, 1), size)

    if width == 512:  # Models trained using anime-segmentation repo
        return SimpleSession(session, (0, 0, 0), (1, 1, 1), size)

    # Unrecognised (or unknown) width: default normalization values.
    return SimpleSession(
        session,
        (0.485, 0.456, 0.406),
        (0.229, 0.224, 0.225),
        size,
    )