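"""ONNX Runtime / Olive integration layer.

Wraps diffusers' ONNX models so they behave like torch modules (device moves,
dtype queries), registers the ONNX Stable Diffusion pipelines with diffusers'
auto-pipeline mappings, and bootstraps the olive-ai workflow runner.
"""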
import os
from typing import Any, Dict, Callable, Optional
import numpy as np
import torch
import diffusers
import onnxruntime as ort
import optimum.onnxruntime


initialized = False # set by initialize_onnx()
run_olive_workflow = None # set by initialize_olive()


class DynamicSessionOptions(ort.SessionOptions):
    config: Optional[Dict] = None

    def __init__(self):
        super().__init__()
        self.enable_mem_pattern = False

    @classmethod
    def from_sess_options(cls, sess_options: ort.SessionOptions):
        # A plain ort.SessionOptions cannot be copied attribute-by-attribute,
        # so fall back to fresh defaults unless we already have a
        # DynamicSessionOptions to copy.
        if isinstance(sess_options, DynamicSessionOptions):
            return sess_options.copy()
        return DynamicSessionOptions()

    def enable_static_dims(self, config: Dict):
        # Pin the UNet's free (dynamic) dimensions so the session runs with static shapes.
        self.config = config
        self.add_free_dimension_override_by_name("unet_sample_batch", config["hidden_batch_size"])
        self.add_free_dimension_override_by_name("unet_sample_channels", 4) # latent channels
        self.add_free_dimension_override_by_name("unet_sample_height", config["height"] // 8) # latent space is 1/8 of pixel space
        self.add_free_dimension_override_by_name("unet_sample_width", config["width"] // 8)
        self.add_free_dimension_override_by_name("unet_time_batch", 1)
        self.add_free_dimension_override_by_name("unet_hidden_batch", config["hidden_batch_size"])
        self.add_free_dimension_override_by_name("unet_hidden_sequence", 77) # CLIP sequence length
        if config["is_sdxl"] and not config["is_refiner"]:
            self.add_free_dimension_override_by_name("unet_text_embeds_batch", config["hidden_batch_size"])
            self.add_free_dimension_override_by_name("unet_text_embeds_size", 1280)
            self.add_free_dimension_override_by_name("unet_time_ids_batch", config["hidden_batch_size"])
            self.add_free_dimension_override_by_name("unet_time_ids_size", 6)

    def copy(self):
        sess_options = DynamicSessionOptions()
        if self.config is not None:
            sess_options.enable_static_dims(self.config)
        return sess_options
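# Usage sketch for DynamicSessionOptions (the 512x512 SD 1.x shapes and the
# model path below are illustrative, not taken from this module):
#
#   opts = DynamicSessionOptions()
#   opts.enable_static_dims({
#       "hidden_batch_size": 2, # batch doubled when classifier-free guidance is used
#       "height": 512,
#       "width": 512,
#       "is_sdxl": False,
#       "is_refiner": False,
#   })
#   session = ort.InferenceSession("unet/model.onnx", opts, providers=["CPUExecutionProvider"])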


class TorchCompatibleModule:
    """Minimal torch-module facade: exposes device/dtype and a no-op type() so ONNX wrappers can stand in for torch modules."""
    device = torch.device("cpu")
    dtype = torch.float32

    def to(self, *_, **__):
        raise NotImplementedError

    def type(self, *_, **__):
        return self


class TemporalModule(TorchCompatibleModule):
    """
    Lightweight stand-in for models that cannot be kept loaded on CPU: stores
    only the provider, model path and session options, and creates a real
    inference session when moved to a non-CPU device.
    """
    provider: Any
    path: str
    sess_options: ort.SessionOptions

    def __init__(self, provider: Any, path: str, sess_options: ort.SessionOptions):
        self.provider = provider
        self.path = path
        self.sess_options = sess_options

    def to(self, *args, **kwargs):
        from .utils import extract_device

        device = extract_device(args, kwargs)
        if device is not None and device.type != "cpu":
            from .execution_providers import TORCH_DEVICE_TO_EP
            provider = TORCH_DEVICE_TO_EP.get(device.type, self.provider)
            return OnnxRuntimeModel.load_model(self.path, provider, DynamicSessionOptions.from_sess_options(self.sess_options))
        return self
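# Usage sketch for TemporalModule (provider name and path are illustrative):
#
#   unet = TemporalModule("CPUExecutionProvider", "unet/model.onnx", DynamicSessionOptions())
#   unet = unet.to(torch.device("cuda")) # only now is a real session created, on the CUDA EP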


class OnnxRuntimeModel(TorchCompatibleModule, diffusers.OnnxRuntimeModel):
    config = {} # dummy; some diffusers code paths expect a config attribute

    def named_modules(self): # dummy; torch.nn.Module API compatibility
        return ()

    def to(self, *args, **kwargs):
        from .utils import extract_device, move_inference_session

        device = extract_device(args, kwargs)
        if device is not None:
            self.device = device
            self.model = move_inference_session(self.model, device)
        return self


class VAEConfig:
    DEFAULTS = { "scaling_factor": 0.18215 }
    config: Dict

    def __init__(self, config: Dict):
        self.config = config

    def __getattr__(self, key):
        if key in self.config:
            return self.config[key]
        if key in VAEConfig.DEFAULTS:
            return VAEConfig.DEFAULTS[key]
        raise AttributeError(key) # keep hasattr() semantics instead of leaking KeyError
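# Sketch: missing keys fall back to VAEConfig.DEFAULTS.
#
#   VAEConfig({}).scaling_factor # -> 0.18215
#   VAEConfig({"scaling_factor": 0.13025}).scaling_factor # -> 0.13025 (SDXL-style value)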


class VAE(TorchCompatibleModule):
    pipeline: Any

    def __init__(self, pipeline: Any):
        self.pipeline = pipeline

    @property
    def config(self):
        return VAEConfig(self.pipeline.vae_decoder.config)

    @property
    def device(self):
        return self.pipeline.vae_decoder.device

    def encode(self, sample: torch.Tensor, *_, **__):
        # Encode one sample at a time: the ONNX encoder session may be bound to
        # a fixed batch size, so the batch is split and re-concatenated.
        sample_np = sample.cpu().numpy()
        return [
            torch.from_numpy(np.concatenate(
                [self.pipeline.vae_encoder(sample=sample_np[i : i + 1])[0] for i in range(sample_np.shape[0])]
            )).to(sample.device)
        ]

    def decode(self, latent_sample: torch.Tensor, *_, **__):
        # Same per-sample loop as encode(), for the decoder session.
        latents_np = latent_sample.cpu().numpy()
        return [
            torch.from_numpy(np.concatenate(
                [self.pipeline.vae_decoder(latent_sample=latents_np[i : i + 1])[0] for i in range(latents_np.shape[0])]
            )).to(latent_sample.device)
        ]

    def to(self, *args, **kwargs):
        self.pipeline.vae_encoder = self.pipeline.vae_encoder.to(*args, **kwargs)
        self.pipeline.vae_decoder = self.pipeline.vae_decoder.to(*args, **kwargs)
        return self
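# Usage sketch for VAE (assumes `pipe` is an ONNX diffusers pipeline exposing
# vae_encoder/vae_decoder parts; the scaling-factor handling mirrors the usual
# diffusers convention and is illustrative):
#
#   vae = VAE(pipe)
#   latents = vae.encode(images)[0] * vae.config.scaling_factor
#   images_out = vae.decode(latents / vae.config.scaling_factor)[0]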


def check_parameters_changed(p, refiner_enabled: bool):
    from modules import shared, sd_models
    if shared.sd_model.__class__.__name__ == "OnnxRawPipeline" or not shared.sd_model.__class__.__name__.startswith("Onnx"):
        return shared.sd_model
    compile_height = p.height
    compile_width = p.width
    if (shared.compiled_model_state is None
            or shared.compiled_model_state.height != compile_height
            or shared.compiled_model_state.width != compile_width
            or shared.compiled_model_state.batch_size != p.batch_size):
        shared.log.info("Olive: Parameter change detected")
        shared.log.info("Olive: Recompiling base model")
        sd_models.unload_model_weights(op='model')
        sd_models.reload_model_weights(op='model')
        if refiner_enabled:
            shared.log.info("Olive: Recompiling refiner")
            sd_models.unload_model_weights(op='refiner')
            sd_models.reload_model_weights(op='refiner')
    shared.compiled_model_state.height = compile_height
    shared.compiled_model_state.width = compile_width
    shared.compiled_model_state.batch_size = p.batch_size
    return shared.sd_model


def preprocess_pipeline(p):
    from modules import shared, sd_models
    if "ONNX" not in shared.opts.diffusers_pipeline:
        shared.log.warning(f"Unsupported pipeline for 'olive-ai' compile backend: {shared.opts.diffusers_pipeline}. You should select one of the ONNX pipelines.")
        return shared.sd_model
    if hasattr(shared.sd_model, "preprocess"):
        shared.sd_model = shared.sd_model.preprocess(p)
    if hasattr(shared.sd_refiner, "preprocess"):
        if shared.opts.onnx_unload_base:
            sd_models.unload_model_weights(op='model')
        shared.sd_refiner = shared.sd_refiner.preprocess(p)
        if shared.opts.onnx_unload_base:
            sd_models.reload_model_weights(op='model')
            shared.sd_model = shared.sd_model.preprocess(p)
    return shared.sd_model


def ORTDiffusionModelPart_to(self, *args, **kwargs):
    self.parent_model = self.parent_model.to(*args, **kwargs)
    return self
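# Note: initialize_onnx() below installs this function onto optimum's
# _ORTDiffusionModelPart, so .to() on a pipeline part moves its parent model.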


def initialize_onnx():
    global initialized # pylint: disable=global-statement
    if initialized:
        return
    from installer import log, installed
    from modules import devices
    from modules.shared import opts
    if not installed('onnx', quiet=True):
        return
    try: # may fail on onnx import
        import onnx # pylint: disable=unused-import
        from .execution_providers import ExecutionProvider, TORCH_DEVICE_TO_EP, available_execution_providers
        if devices.backend == "rocm":
            TORCH_DEVICE_TO_EP["cuda"] = ExecutionProvider.ROCm
        from .pipelines.onnx_stable_diffusion_pipeline import OnnxStableDiffusionPipeline
        from .pipelines.onnx_stable_diffusion_img2img_pipeline import OnnxStableDiffusionImg2ImgPipeline
        from .pipelines.onnx_stable_diffusion_inpaint_pipeline import OnnxStableDiffusionInpaintPipeline
        from .pipelines.onnx_stable_diffusion_upscale_pipeline import OnnxStableDiffusionUpscalePipeline
        from .pipelines.onnx_stable_diffusion_xl_pipeline import OnnxStableDiffusionXLPipeline
        from .pipelines.onnx_stable_diffusion_xl_img2img_pipeline import OnnxStableDiffusionXLImg2ImgPipeline

        OnnxRuntimeModel.__module__ = 'diffusers' # hijack: masquerade as the original diffusers class
        diffusers.OnnxRuntimeModel = OnnxRuntimeModel

        diffusers.OnnxStableDiffusionPipeline = OnnxStableDiffusionPipeline
        diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["onnx-stable-diffusion"] = diffusers.OnnxStableDiffusionPipeline

        diffusers.OnnxStableDiffusionImg2ImgPipeline = OnnxStableDiffusionImg2ImgPipeline
        diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["onnx-stable-diffusion"] = diffusers.OnnxStableDiffusionImg2ImgPipeline

        diffusers.OnnxStableDiffusionInpaintPipeline = OnnxStableDiffusionInpaintPipeline
        diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["onnx-stable-diffusion"] = diffusers.OnnxStableDiffusionInpaintPipeline

        diffusers.OnnxStableDiffusionUpscalePipeline = OnnxStableDiffusionUpscalePipeline

        diffusers.OnnxStableDiffusionXLPipeline = OnnxStableDiffusionXLPipeline
        diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["onnx-stable-diffusion-xl"] = diffusers.OnnxStableDiffusionXLPipeline

        diffusers.OnnxStableDiffusionXLImg2ImgPipeline = OnnxStableDiffusionXLImg2ImgPipeline
        diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["onnx-stable-diffusion-xl"] = diffusers.OnnxStableDiffusionXLImg2ImgPipeline

        diffusers.ORTStableDiffusionXLPipeline = diffusers.OnnxStableDiffusionXLPipeline # Huggingface model compatibility
        diffusers.ORTStableDiffusionXLImg2ImgPipeline = diffusers.OnnxStableDiffusionXLImg2ImgPipeline

        optimum.onnxruntime.modeling_diffusion._ORTDiffusionModelPart.to = ORTDiffusionModelPart_to # pylint: disable=protected-access

        log.debug(f'ONNX: version={ort.__version__} provider={opts.onnx_execution_provider} available={available_execution_providers}')
    except Exception as e:
        log.error(f'ONNX failed to initialize: {e}')
    initialized = True


def initialize_olive():
    global run_olive_workflow # pylint: disable=global-statement
    from installer import installed, log
    if not installed('olive-ai', quiet=True) or not installed('onnx', quiet=True):
        return
    import sys
    import importlib
    orig_sys_path = sys.path
    venv_dir = os.environ.get("VENV_DIR", os.path.join(os.getcwd(), 'venv'))
    try:
        # Reorder sys.path so the venv copy of onnxruntime.transformers (and
        # then olive itself) is found before any shadowing entries.
        spec = importlib.util.find_spec('onnxruntime.transformers')
        sys.path = [d for d in spec.submodule_search_locations + sys.path if sys.path[1] not in d or venv_dir in d]
        from onnxruntime.transformers import convert_generation # pylint: disable=unused-import
        spec = importlib.util.find_spec('olive')
        sys.path = spec.submodule_search_locations + sys.path
        run_olive_workflow = importlib.import_module('olive.workflows').run
    except Exception as e:
        run_olive_workflow = None
        log.error(f'Olive: Failed to load olive-ai: {e}')
    sys.path = orig_sys_path


def install_olive():
    from installer import installed, install, log
    if installed("olive-ai"):
        return
    try:
        log.info('Installing Olive')
        install('onnx', 'onnx', ignore=True)
        install('olive-ai', 'olive-ai', ignore=True)
        import olive.workflows # pylint: disable=unused-import
    except Exception as e:
        log.error(f'Olive: Failed to install olive-ai: {e}')
    else:
        log.info('Olive: Please restart the webui session.')
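# Typical startup order (a sketch; the actual call sites live in the webui):
#
#   install_olive()    # one-time: pip-install onnx and olive-ai
#   initialize_onnx()  # hijack diffusers' ONNX classes and pipeline mappings
#   initialize_olive() # resolve olive.workflows.run into run_olive_workflow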