Spaces:
Runtime error
Runtime error
File size: 5,324 Bytes
c19ca42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import os
import json
import importlib
from typing import Type, Tuple, Union, List, Dict, Any
import torch
import diffusers
import onnxruntime as ort
def extract_device(args: List, kwargs: Dict):
    """Extract a ``torch.device`` from a call's arguments.

    Checks ``kwargs["device"]`` first, then scans positional ``args`` for the
    first ``torch.device`` instance.

    Returns:
        The device found, or ``None`` when neither source provides one.
    """
    device = kwargs.get("device", None)
    if device is not None:
        return device
    for arg in args:
        if isinstance(arg, torch.device):
            # Fix: original kept scanning and returned the *last* device found;
            # return the first match and stop early.
            return arg
    return None
def move_inference_session(session: ort.InferenceSession, device: torch.device):
    """Rebuild an ONNX Runtime inference session so it targets *device*.

    Reloads the model from the session's on-disk path using the execution
    provider mapped from ``device.type``; if reloading fails, wraps the model
    in a ``TemporalModule`` instead.

    Args:
        session: the existing ORT inference session to move.
        device: the torch device the session should run on.

    Returns:
        A new session-like object, or the original *session* unchanged when
        torch itself is CPU-only.
    """
    from modules.devices import device as default_device
    if default_device.type == "cpu":  # torch is CPU-only here; attempting the transfer would misbehave, so keep the session as-is
        return session
    # Local imports to avoid a circular dependency at module load time.
    from . import DynamicSessionOptions, TemporalModule
    from .execution_providers import TORCH_DEVICE_TO_EP
    # ORT keeps the provider list and model path in private attributes.
    previous_provider = session._providers # pylint: disable=protected-access
    # Map the torch device type to an ORT execution provider; fall back to the
    # session's current providers when the device type has no known mapping.
    provider = TORCH_DEVICE_TO_EP[device.type] if device.type in TORCH_DEVICE_TO_EP else previous_provider
    path = session._model_path # pylint: disable=protected-access
    try:
        return diffusers.OnnxRuntimeModel.load_model(path, provider, DynamicSessionOptions.from_sess_options(session._sess_options)) # pylint: disable=protected-access
    except Exception:
        # NOTE(review): TemporalModule presumably defers session re-creation — confirm in the sibling module.
        return TemporalModule(previous_provider, path, session._sess_options) # pylint: disable=protected-access
def check_diffusers_cache(path: os.PathLike):
    """Return True when *path* points inside the configured diffusers cache directory."""
    from modules.shared import opts
    absolute_path = os.path.abspath(path)
    return opts.diffusers_dir in absolute_path
def check_pipeline_sdxl(cls: Type[diffusers.DiffusionPipeline]) -> bool:
    """Heuristic: a pipeline class is treated as SDXL when its class name contains 'XL'."""
    class_name = cls.__name__
    return "XL" in class_name
def check_cache_onnx(path: os.PathLike) -> bool:
    """Return True when *path* is a model directory whose ``model_index.json``
    references ``OnnxRuntimeModel`` (i.e. an ONNX-converted diffusers cache entry).
    """
    if not os.path.isdir(path):
        return False
    index_path = os.path.join(path, "model_index.json")
    if not os.path.isfile(index_path):
        return False
    with open(index_path, "r", encoding="utf-8") as index_file:
        index_contents = index_file.read()
    # A plain substring test is enough: any ONNX pipeline lists this class name.
    return "OnnxRuntimeModel" in index_contents
def load_init_dict(cls: Type[diffusers.DiffusionPipeline], path: os.PathLike):
    """Read the pipeline config at *path* and collect its submodel entries.

    ``extract_init_dict`` returns a tuple of dicts; they are merged, then only
    list-valued entries (``[library_name, attribute_name]`` pairs) whose key is
    a declared ``__init__`` parameter of *cls* are kept.

    Returns:
        Mapping of submodel name to its ``[library, attribute]`` pair.
    """
    config = diffusers.DiffusionPipeline.load_config(path)
    merged: Dict[str, Any] = {}
    for part in cls.extract_init_dict(config):
        merged.update(part)
    # Fix: original rebound `merged` to its own .items() view (type-confusing)
    # and annotated the result as Dict[str, Tuple[str]] although values are lists.
    submodels: Dict[str, List[str]] = {}
    for name, value in merged.items():
        # Only submodel specs (lists) are kept, and only when the pipeline's
        # constructor actually declares a parameter with that name.
        if isinstance(value, list) and name in cls.__init__.__annotations__:
            submodels[name] = value
    return submodels
def load_submodel(path: os.PathLike, is_sdxl: bool, submodel_name: str, item: List[Union[str, None]], **kwargs_ort):
    """Instantiate one pipeline submodel from its ``[library, attribute]`` spec.

    Returns None when the spec is incomplete. ONNX submodels are loaded through
    ORT (raw ``model.onnx`` for SDXL, ``from_pretrained`` otherwise); everything
    else goes through the attribute's own ``from_pretrained``.
    """
    library_name, attribute_name = item
    if library_name is None or attribute_name is None:
        return None
    library = importlib.import_module(library_name)
    attribute = getattr(library, attribute_name)
    submodel_path = os.path.join(path, submodel_name)
    if not issubclass(attribute, diffusers.OnnxRuntimeModel):
        return attribute.from_pretrained(submodel_path)
    if is_sdxl:
        return diffusers.OnnxRuntimeModel.load_model(
            os.path.join(submodel_path, "model.onnx"),
            **kwargs_ort,
        )
    return diffusers.OnnxRuntimeModel.from_pretrained(
        submodel_path,
        **kwargs_ort,
    )
def load_submodels(path: os.PathLike, is_sdxl: bool, init_dict: Dict[str, Type], **kwargs_ort):
    """Load every submodel listed in *init_dict*.

    Non-list entries are passed through unchanged; list entries are resolved
    via ``load_submodel``. Loading is best-effort: a submodel that fails to
    load is simply omitted from the result.
    """
    submodels = {}
    for name, spec in init_dict.items():
        if not isinstance(spec, list):
            # Plain config values (not [library, attribute] specs) pass through.
            submodels[name] = spec
            continue
        try:
            submodels[name] = load_submodel(path, is_sdxl, name, spec, **kwargs_ort)
        except Exception:
            # Deliberate best-effort: skip submodels that cannot be loaded.
            pass
    return submodels
def load_pipeline(cls: Type[diffusers.DiffusionPipeline], path: os.PathLike, **kwargs_ort) -> diffusers.DiffusionPipeline:
    """Construct pipeline *cls* from *path*.

    A directory is treated as a diffusers layout (submodels loaded
    individually); any other path is handed to ``from_single_file``.
    """
    if not os.path.isdir(path):
        return cls.from_single_file(path)
    is_sdxl = check_pipeline_sdxl(cls)
    init_dict = load_init_dict(cls, path)
    submodels = load_submodels(path, is_sdxl, init_dict, **kwargs_ort)
    return cls(**patch_kwargs(cls, submodels))
def patch_kwargs(cls: Type[diffusers.DiffusionPipeline], kwargs: Dict) -> Dict:
    """Patch constructor kwargs in place with the extras each ONNX pipeline class requires."""
    sd_classes = (
        diffusers.OnnxStableDiffusionPipeline,
        diffusers.OnnxStableDiffusionImg2ImgPipeline,
        diffusers.OnnxStableDiffusionInpaintPipeline,
    )
    sdxl_classes = (
        diffusers.OnnxStableDiffusionXLPipeline,
        diffusers.OnnxStableDiffusionXLImg2ImgPipeline,
    )
    if cls in sd_classes:
        # SD 1.x/2.x ONNX pipelines expect an explicit (disabled) safety checker.
        kwargs["safety_checker"] = None
        kwargs["requires_safety_checker"] = False
    if cls in sdxl_classes:
        kwargs["config"] = {}
    return kwargs
def get_base_constructor(cls: Type[diffusers.DiffusionPipeline], is_refiner: bool):
    """Map a task-specific ONNX pipeline class to its base text-to-image constructor.

    Img2img/inpaint SD classes collapse to the plain SD pipeline; the SDXL
    img2img class collapses to the SDXL base pipeline unless it is a refiner.
    """
    if cls in (diffusers.OnnxStableDiffusionImg2ImgPipeline, diffusers.OnnxStableDiffusionInpaintPipeline):
        return diffusers.OnnxStableDiffusionPipeline
    if cls == diffusers.OnnxStableDiffusionXLImg2ImgPipeline and not is_refiner:
        return diffusers.OnnxStableDiffusionXLPipeline
    return cls
def get_io_config(submodel: str, is_sdxl: bool):
    """Load the Olive io_config for *submodel*, converting dynamic-axis keys to int.

    JSON object keys are always strings, but the axis indices must be integers
    for downstream consumers, so each ``dynamic_axes`` mapping is rebuilt.
    """
    from modules.paths import sd_configs_path
    model_kind = 'sdxl' if is_sdxl else 'sd'
    config_path = os.path.join(sd_configs_path, "olive", model_kind, f"{submodel}.json")
    with open(config_path, "r", encoding="utf-8") as config_file:
        full_config = json.load(config_file)
    io_config: Dict[str, Any] = full_config["input_model"]["config"]["io_config"]
    dynamic_axes = io_config["dynamic_axes"]
    for axis_name in dynamic_axes:
        dynamic_axes[axis_name] = {int(index): label for index, label in dynamic_axes[axis_name].items()}
    return io_config
|