# File size: 5,324 Bytes
# c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import os
import json
import importlib
from typing import Type, Tuple, Union, List, Dict, Any
import torch
import diffusers
import onnxruntime as ort


def extract_device(args: List, kwargs: Dict):
    """Find the device carried by a set of positional and keyword arguments.

    Args:
        args: Positional arguments to scan for ``torch.device`` instances.
        kwargs: Keyword arguments; a non-``None`` ``"device"`` entry wins.

    Returns:
        ``kwargs["device"]`` when present and not ``None``; otherwise the last
        ``torch.device`` instance found in ``args``; ``None`` if there is none.
    """
    explicit = kwargs.get("device")
    if explicit is not None:
        return explicit

    # When several devices are passed positionally, the last one is used.
    candidates = [arg for arg in args if isinstance(arg, torch.device)]
    return candidates[-1] if candidates else None


def move_inference_session(session: ort.InferenceSession, device: torch.device):
    """Re-create an inference session for the provider that matches ``device``.

    Args:
        session: Existing session. Its model path, providers and session
            options are read from private attributes.
        device: Target torch device; ``device.type`` is looked up in
            ``TORCH_DEVICE_TO_EP`` to choose the execution provider.

    Returns:
        ``session`` itself when the default device type is ``"cpu"``.
        Otherwise the result of ``diffusers.OnnxRuntimeModel.load_model`` for
        the selected provider, or a ``TemporalModule`` built from the previous
        providers if that load raises.
    """
    from modules.devices import device as default_device

    # Default device is CPU (per the original note: CPU-only torch with no
    # external ops overriding it). Moving the session in that setup is
    # error-prone, so the session is handed back untouched.
    if default_device.type == "cpu":
        return session

    from . import DynamicSessionOptions, TemporalModule
    from .execution_providers import TORCH_DEVICE_TO_EP

    # pylint: disable=protected-access
    fallback_providers = session._providers
    if device.type in TORCH_DEVICE_TO_EP:
        target_provider = TORCH_DEVICE_TO_EP[device.type]
    else:
        # Unknown device type: keep whatever the session was already using.
        target_provider = fallback_providers
    model_path = session._model_path

    try:
        session_options = DynamicSessionOptions.from_sess_options(session._sess_options)
        return diffusers.OnnxRuntimeModel.load_model(model_path, target_provider, session_options)
    except Exception:
        # Loading with the new provider failed; fall back to a TemporalModule
        # that carries the previous providers, path and options.
        return TemporalModule(fallback_providers, model_path, session._sess_options)


def check_diffusers_cache(path: os.PathLike):
    """Report whether ``path`` points into the configured diffusers directory.

    The check is a plain substring test: ``opts.diffusers_dir`` must occur
    inside the absolute form of ``path``.
    """
    from modules.shared import opts

    cache_dir = opts.diffusers_dir
    resolved = os.path.abspath(path)
    return cache_dir in resolved


def check_pipeline_sdxl(cls: Type[diffusers.DiffusionPipeline]) -> bool:
    """Return ``True`` when the pipeline class name contains ``'XL'`` (case-sensitive)."""
    class_name = cls.__name__
    return class_name.find('XL') != -1


def check_cache_onnx(path: os.PathLike) -> bool:
    """Tell whether ``path`` is a model directory that references ONNX models.

    Args:
        path: Candidate model directory.

    Returns:
        ``True`` only if ``path`` is a directory, contains a
        ``model_index.json`` file, and that file's text includes the string
        ``OnnxRuntimeModel``. The file is searched as raw text, not parsed.
    """
    if not os.path.isdir(path):
        return False

    index_path = os.path.join(path, "model_index.json")
    if not os.path.isfile(index_path):
        return False

    with open(index_path, "r", encoding="utf-8") as handle:
        return "OnnxRuntimeModel" in handle.read()


def load_init_dict(cls: Type[diffusers.DiffusionPipeline], path: os.PathLike):
    """Collect the list-valued config entries that ``cls.__init__`` declares.

    The pipeline config at ``path`` is loaded, passed through
    ``cls.extract_init_dict`` and every mapping it yields is merged (later
    ones override earlier ones).

    Args:
        cls: Pipeline class whose ``extract_init_dict`` and ``__init__``
            annotations are consulted.
        path: Location handed to ``diffusers.DiffusionPipeline.load_config``.

    Returns:
        A dict holding only the merged entries whose value is a ``list`` and
        whose key appears in ``cls.__init__.__annotations__``.
    """
    config = diffusers.DiffusionPipeline.load_config(path)

    combined: Dict[str, Any] = {}
    for part in cls.extract_init_dict(config):
        combined.update(part)

    # The annotations are looked up lazily, only for list-valued entries.
    return {
        name: spec
        for name, spec in combined.items()
        if isinstance(spec, list) and name in cls.__init__.__annotations__
    }


def load_submodel(path: os.PathLike, is_sdxl: bool, submodel_name: str, item: List[Union[str, None]], **kwargs_ort):
    """Load one pipeline component from its sub-directory.

    Args:
        path: Root directory of the pipeline.
        is_sdxl: Selects how ONNX components are loaded (see Returns).
        submodel_name: Name of the component; also its sub-directory name.
        item: Two-element ``[module name, attribute name]`` pair naming the
            class to load. Either element may be ``None``.
        **kwargs_ort: Extra keyword arguments forwarded to the ONNX loaders only.

    Returns:
        ``None`` if either element of ``item`` is ``None``. For subclasses of
        ``diffusers.OnnxRuntimeModel``: the result of ``load_model`` on
        ``<path>/<submodel_name>/model.onnx`` when ``is_sdxl`` is true, else
        the result of ``from_pretrained`` on the sub-directory. For any other
        class: ``attribute.from_pretrained(<sub-directory>)``.
    """
    module_name, class_name = item
    if any(part is None for part in (module_name, class_name)):
        return None

    component_cls = getattr(importlib.import_module(module_name), class_name)
    component_dir = os.path.join(path, submodel_name)

    if not issubclass(component_cls, diffusers.OnnxRuntimeModel):
        return component_cls.from_pretrained(component_dir)

    if is_sdxl:
        return diffusers.OnnxRuntimeModel.load_model(
            os.path.join(component_dir, "model.onnx"),
            **kwargs_ort,
        )

    return diffusers.OnnxRuntimeModel.from_pretrained(
        component_dir,
        **kwargs_ort,
    )


def load_submodels(path: os.PathLike, is_sdxl: bool, init_dict: Dict[str, Type], **kwargs_ort):
    """Load every component described in ``init_dict`` on a best-effort basis.

    Args:
        path: Root directory of the pipeline.
        is_sdxl: Forwarded to ``load_submodel``.
        init_dict: Mapping of component name to either a ``list`` spec (loaded
            through ``load_submodel``) or any other value (copied as-is).
        **kwargs_ort: Extra keyword arguments forwarded to ``load_submodel``.

    Returns:
        A dict of the successfully resolved entries. A component whose load
        raises is left out of the result; the failure is logged at debug level
        instead of being dropped silently.
    """
    import logging

    loaded = {}

    for name, spec in init_dict.items():
        if not isinstance(spec, list):
            loaded[name] = spec
            continue
        try:
            loaded[name] = load_submodel(path, is_sdxl, name, spec, **kwargs_ort)
        except Exception as err:
            # Deliberately best-effort: skip the component, but leave a trace
            # so a later "missing argument" error can be tracked back to here.
            logging.getLogger(__name__).debug(
                "Failed to load submodel '%s' from '%s': %s", name, path, err
            )

    return loaded


def load_pipeline(cls: Type[diffusers.DiffusionPipeline], path: os.PathLike, **kwargs_ort) -> diffusers.DiffusionPipeline:
    """Build a pipeline of type ``cls`` from a directory or a single file.

    Args:
        cls: Pipeline class to instantiate.
        path: Model directory, or anything else (handled as a single file).
        **kwargs_ort: Forwarded to ``load_submodels``; not used on the
            single-file path.

    Returns:
        ``cls.from_single_file(path)`` when ``path`` is not a directory.
        Otherwise ``cls`` constructed from the loaded components after they
        pass through ``patch_kwargs``.
    """
    if not os.path.isdir(path):
        return cls.from_single_file(path)

    is_sdxl = check_pipeline_sdxl(cls)
    init_dict = load_init_dict(cls, path)
    components = load_submodels(path, is_sdxl, init_dict, **kwargs_ort)
    return cls(**patch_kwargs(cls, components))


def patch_kwargs(cls: Type[diffusers.DiffusionPipeline], kwargs: Dict) -> Dict:
    """Adjust constructor arguments for specific ONNX pipeline classes.

    ``kwargs`` is modified in place and also returned.

    * ONNX Stable Diffusion txt2img / img2img / inpaint pipelines get
      ``safety_checker=None`` and ``requires_safety_checker=False``.
    * ONNX Stable Diffusion XL txt2img / img2img pipelines get an empty
      ``config`` dict.
    """
    is_sd_pipeline = (
        cls == diffusers.OnnxStableDiffusionPipeline
        or cls == diffusers.OnnxStableDiffusionImg2ImgPipeline
        or cls == diffusers.OnnxStableDiffusionInpaintPipeline
    )
    if is_sd_pipeline:
        kwargs.update({"safety_checker": None, "requires_safety_checker": False})

    is_sdxl_pipeline = (
        cls == diffusers.OnnxStableDiffusionXLPipeline
        or cls == diffusers.OnnxStableDiffusionXLImg2ImgPipeline
    )
    if is_sdxl_pipeline:
        kwargs["config"] = {}

    return kwargs


def get_base_constructor(cls: Type[diffusers.DiffusionPipeline], is_refiner: bool):
    """Map a pipeline class to the base class used to construct it.

    Args:
        cls: Requested pipeline class.
        is_refiner: When true, the SDXL img2img pipeline is kept as-is.

    Returns:
        ``OnnxStableDiffusionPipeline`` for the SD img2img and inpaint
        pipelines; ``OnnxStableDiffusionXLPipeline`` for the SDXL img2img
        pipeline unless ``is_refiner`` is true; ``cls`` itself otherwise.
    """
    base = cls

    if cls == diffusers.OnnxStableDiffusionImg2ImgPipeline or cls == diffusers.OnnxStableDiffusionInpaintPipeline:
        base = diffusers.OnnxStableDiffusionPipeline
    elif cls == diffusers.OnnxStableDiffusionXLImg2ImgPipeline and not is_refiner:
        base = diffusers.OnnxStableDiffusionXLPipeline

    return base


def get_io_config(submodel: str, is_sdxl: bool):
    """Read the ``io_config`` section of a submodel's olive JSON config.

    The file is ``<sd_configs_path>/olive/<sdxl|sd>/<submodel>.json`` and the
    section is taken from ``input_model -> config -> io_config``.

    Args:
        submodel: Submodel name; used as the JSON file stem.
        is_sdxl: Chooses the ``sdxl`` config folder instead of ``sd``.

    Returns:
        The ``io_config`` dict, with the keys of every ``dynamic_axes`` entry
        converted to ``int``.
    """
    from modules.paths import sd_configs_path

    family = 'sdxl' if is_sdxl else 'sd'
    config_path = os.path.join(sd_configs_path, "olive", family, f"{submodel}.json")

    with open(config_path, "r", encoding="utf-8") as config_file:
        olive_config = json.load(config_file)
    io_config: Dict[str, Any] = olive_config["input_model"]["config"]["io_config"]

    # JSON object keys are always strings; turn the axis indices back into ints.
    dynamic_axes = io_config["dynamic_axes"]
    for name, axes in list(dynamic_axes.items()):
        dynamic_axes[name] = {int(index): label for index, label in axes.items()}

    return io_config