bilegentile's picture
Upload folder using huggingface_hub
c19ca42 verified
import os
import json
import importlib
from typing import Type, Tuple, Union, List, Dict, Any
import torch
import diffusers
import onnxruntime as ort
def extract_device(args: List, kwargs: Dict):
    """Pick the target device out of a ``.to()``-style argument set.

    An explicit ``device`` keyword takes precedence. Otherwise the positional
    arguments are scanned for ``torch.device`` instances and the last one
    found is returned. ``None`` is returned when no device is present.
    """
    explicit = kwargs.get("device", None)
    if explicit is not None:
        return explicit

    # Keep every positional torch.device; the final one is the winner.
    positional = [candidate for candidate in args if isinstance(candidate, torch.device)]
    return positional[-1] if positional else None
def move_inference_session(session: ort.InferenceSession, device: torch.device):
    """Reload *session* with the execution provider mapped to *device*.

    The session is returned untouched when the application's default device
    is the CPU. Otherwise the model file behind the session is loaded again
    using the provider registered for ``device.type`` (or the session's
    current providers when that device type has no mapping). If reloading
    raises, a ``TemporalModule`` built from the previous providers, the model
    path and the original session options is returned instead.
    """
    from modules.devices import device as default_device

    # CPU-only torch with no external backend overriding the default device:
    # attempting the transfer would be a mistake, so keep the session as is.
    if default_device.type == "cpu":
        return session

    from . import DynamicSessionOptions, TemporalModule
    from .execution_providers import TORCH_DEVICE_TO_EP

    current_providers = session._providers  # pylint: disable=protected-access
    if device.type in TORCH_DEVICE_TO_EP:
        target_provider = TORCH_DEVICE_TO_EP[device.type]
    else:
        # Unknown device type: stay on whatever the session already uses.
        target_provider = current_providers
    model_path = session._model_path  # pylint: disable=protected-access

    try:
        options = DynamicSessionOptions.from_sess_options(session._sess_options)  # pylint: disable=protected-access
        return diffusers.OnnxRuntimeModel.load_model(model_path, target_provider, options)
    except Exception:
        # Best-effort fallback when the session cannot be recreated.
        return TemporalModule(current_providers, model_path, session._sess_options)  # pylint: disable=protected-access
def check_diffusers_cache(path: os.PathLike):
    """Tell whether *path* refers to something inside the diffusers cache.

    This is a plain substring test: it returns ``True`` when the configured
    ``opts.diffusers_dir`` string occurs anywhere in the absolute form of
    *path*.
    """
    from modules.shared import opts

    cache_root = opts.diffusers_dir
    absolute_path = os.path.abspath(path)
    return cache_root in absolute_path
def check_pipeline_sdxl(cls: Type[diffusers.DiffusionPipeline]) -> bool:
    """Return ``True`` when the pipeline class name contains ``XL``."""
    pipeline_name = cls.__name__
    return 'XL' in pipeline_name
def check_cache_onnx(path: os.PathLike) -> bool:
    """Report whether *path* is a pipeline directory that uses ONNX models.

    The directory qualifies when it contains a ``model_index.json`` file
    whose text mentions ``OnnxRuntimeModel``. Anything else (not a directory,
    no index file, no mention) yields ``False``.
    """
    if not os.path.isdir(path):
        return False

    index_path = os.path.join(path, "model_index.json")
    if os.path.isfile(index_path):
        with open(index_path, "r", encoding="utf-8") as index_file:
            # Raw text search; the JSON is intentionally not parsed.
            return "OnnxRuntimeModel" in index_file.read()
    return False
def load_init_dict(cls: Type[diffusers.DiffusionPipeline], path: os.PathLike):
    """Build the constructor argument description for pipeline *cls* at *path*.

    The pipeline config is loaded from *path*, passed through
    ``cls.extract_init_dict`` and every mapping it returns is merged into one
    dictionary (later mappings override earlier ones). List-valued entries
    are kept only when their key is an annotated parameter of
    ``cls.__init__``; all other entries are kept unconditionally.
    """
    merged: Dict[str, Any] = {}
    for part in cls.extract_init_dict(diffusers.DiffusionPipeline.load_config(path)):
        merged.update(part)

    # Keep an entry unless it is a list whose key the constructor does not
    # declare in its annotations.
    return {
        name: value
        for name, value in merged.items()
        if not isinstance(value, list) or name in cls.__init__.__annotations__
    }
def load_submodel(path: os.PathLike, is_sdxl: bool, submodel_name: str, item: List[Union[str, None]], **kwargs_ort):
    """Instantiate a single pipeline component stored under *path*.

    *item* is a ``[module name, class name]`` pair. When either part is
    ``None`` there is nothing to load and ``None`` is returned. Otherwise the
    class is imported and loaded from ``<path>/<submodel_name>``:

    * subclasses of ``diffusers.OnnxRuntimeModel`` are loaded with
      ``OnnxRuntimeModel.load_model`` on ``model.onnx`` for SDXL pipelines and
      with ``OnnxRuntimeModel.from_pretrained`` otherwise, forwarding
      ``kwargs_ort`` in both cases;
    * any other class is loaded via its own ``from_pretrained`` without
      ``kwargs_ort``.
    """
    module_name, class_name = item
    if module_name is None or class_name is None:
        return None

    component_cls = getattr(importlib.import_module(module_name), class_name)
    submodel_dir = os.path.join(path, submodel_name)

    if not issubclass(component_cls, diffusers.OnnxRuntimeModel):
        return component_cls.from_pretrained(submodel_dir)

    if is_sdxl:
        return diffusers.OnnxRuntimeModel.load_model(
            os.path.join(submodel_dir, "model.onnx"),
            **kwargs_ort,
        )
    return diffusers.OnnxRuntimeModel.from_pretrained(
        submodel_dir,
        **kwargs_ort,
    )
def load_submodels(path: os.PathLike, is_sdxl: bool, init_dict: Dict[str, Type], **kwargs_ort):
    """Load every component described by *init_dict*.

    Entries whose value is not a list are copied to the result unchanged.
    List entries are treated as component descriptions and loaded with
    ``load_submodel``. Loading is best-effort: a component that fails to load
    is left out of the result, and the failure is logged as a warning so the
    cause is visible when the pipeline constructor later lacks that argument.

    Returns:
        A dictionary mapping each key to its pass-through value or loaded
        component.
    """
    import logging  # function-scope import, matching the local imports used elsewhere in this file

    logger = logging.getLogger(__name__)
    loaded = {}
    for name, spec in init_dict.items():
        if not isinstance(spec, list):
            loaded[name] = spec
            continue
        try:
            loaded[name] = load_submodel(path, is_sdxl, name, spec, **kwargs_ort)
        except Exception as exc:
            # Still best-effort (the key is skipped), but no longer silent.
            logger.warning("ONNX: failed to load submodel '%s' from '%s': %s", name, path, exc)
    return loaded
def load_pipeline(cls: Type[diffusers.DiffusionPipeline], path: os.PathLike, **kwargs_ort) -> diffusers.DiffusionPipeline:
    """Create a pipeline of type *cls* from *path*.

    A directory is treated as a pipeline folder: its components are loaded
    (forwarding ``kwargs_ort``), the resulting keyword arguments are patched
    for *cls*, and the pipeline is constructed from them. Any other path is
    handed to ``cls.from_single_file``.
    """
    if not os.path.isdir(path):
        return cls.from_single_file(path)

    is_sdxl = check_pipeline_sdxl(cls)
    init_dict = load_init_dict(cls, path)
    components = load_submodels(path, is_sdxl, init_dict, **kwargs_ort)
    constructor_kwargs = patch_kwargs(cls, components)
    return cls(**constructor_kwargs)
def patch_kwargs(cls: Type[diffusers.DiffusionPipeline], kwargs: Dict) -> Dict:
    """Adjust constructor keyword arguments for specific ONNX pipelines.

    *kwargs* is modified in place and also returned. The ONNX Stable
    Diffusion pipelines (txt2img, img2img, inpaint) get the safety checker
    disabled; the ONNX Stable Diffusion XL pipelines (txt2img, img2img) get
    an empty ``config``.
    """
    sd_pipelines = (
        diffusers.OnnxStableDiffusionPipeline,
        diffusers.OnnxStableDiffusionImg2ImgPipeline,
        diffusers.OnnxStableDiffusionInpaintPipeline,
    )
    if cls in sd_pipelines:
        kwargs.update(safety_checker=None, requires_safety_checker=False)

    sdxl_pipelines = (
        diffusers.OnnxStableDiffusionXLPipeline,
        diffusers.OnnxStableDiffusionXLImg2ImgPipeline,
    )
    if cls in sdxl_pipelines:
        kwargs["config"] = {}

    return kwargs
def get_base_constructor(cls: Type[diffusers.DiffusionPipeline], is_refiner: bool):
    """Map a pipeline class to the class used as its base constructor.

    The ONNX img2img and inpaint pipelines map to the ONNX txt2img pipeline.
    The ONNX SDXL img2img pipeline maps to the ONNX SDXL txt2img pipeline
    unless *is_refiner* is set. Every other class maps to itself.
    """
    base = cls
    if cls in (diffusers.OnnxStableDiffusionImg2ImgPipeline, diffusers.OnnxStableDiffusionInpaintPipeline):
        base = diffusers.OnnxStableDiffusionPipeline
    elif cls == diffusers.OnnxStableDiffusionXLImg2ImgPipeline and not is_refiner:
        base = diffusers.OnnxStableDiffusionXLPipeline
    return base
def get_io_config(submodel: str, is_sdxl: bool):
    """Read the ``io_config`` section of a submodel's Olive configuration.

    The file ``<sd_configs_path>/olive/<sd|sdxl>/<submodel>.json`` is parsed
    and ``["input_model"]["config"]["io_config"]`` is returned. The keys of
    each ``dynamic_axes`` entry are converted to ``int``, since JSON object
    keys are always loaded as strings.
    """
    from modules.paths import sd_configs_path

    family_dir = 'sdxl' if is_sdxl else 'sd'
    config_path = os.path.join(sd_configs_path, "olive", family_dir, f"{submodel}.json")
    with open(config_path, "r", encoding="utf-8") as config_file:
        olive_config = json.load(config_file)
    io_config: Dict[str, Any] = olive_config["input_model"]["config"]["io_config"]

    dynamic_axes = io_config["dynamic_axes"]
    for tensor_name in dynamic_axes:
        # Only existing keys are reassigned, so iterating the dict is safe.
        dynamic_axes[tensor_name] = {int(index): label for index, label in dynamic_axes[tensor_name].items()}
    return io_config