Spaces:
Runtime error
Runtime error
import os | |
import time | |
from typing import Union | |
import numpy as np | |
from PIL import Image | |
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline | |
from modules.shared import log, opts, listdir | |
from modules import errors | |
from modules.control.units.lite_model import ControlNetLLLite | |
# Short label used as a prefix in every log message emitted by this module.
what = 'ControlLLLite'


def _debug_noop(*args, **kwargs):  # pylint: disable=unused-argument
    """Swallow debug output when SD_CONTROL_DEBUG is not set in the environment."""
    return None


# Tracing is switched on by the mere presence of SD_CONTROL_DEBUG (its value is ignored).
debug = log.trace if 'SD_CONTROL_DEBUG' in os.environ else _debug_noop
debug('Trace: CONTROL')
# Built-in SD 1.5 ControlLLLite models; currently none are predefined.
predefined_sd15 = {}

# Built-in SDXL ControlLLLite models: display name -> "<hf repo id>/<file name without .safetensors>".
predefined_sdxl = {
    'Canny XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_canny',
    'Canny anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_canny_anime',
    'Depth anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01008016e_sdxl_depth_anime',
    'Blur anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01016032e_sdxl_blur_anime_beta',
    'Pose anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_pose_anime',
    'Replicate anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_replicate_anime_v2',
}

# Cached result of list_models(); empty until the first listing is built.
models = {}

# Every known model id -> path: predefined entries first; find_models() adds local files later.
all_models = {**predefined_sd15, **predefined_sdxl}

# Cache directory handed to the Hugging Face hub downloader.
cache_dir = 'models/control/lite'
def find_models():
    """Scan ``<opts.control_dir>/lite`` for local ``.safetensors`` ControlLLLite models.

    Each file found is registered in the module-level ``all_models`` registry,
    keyed by its path relative to the scan directory without the extension.

    Returns:
        dict: ``{name: file_path}`` for the models found by this scan only.
    """
    path = os.path.join(opts.control_dir, 'lite')
    downloaded_models = {}
    for f in listdir(path):
        if not f.endswith('.safetensors'):
            continue
        # Entries may already carry the directory prefix (the relpath call assumes so).
        # Only join bare file names; joining a prefixed relative entry again would
        # produce "<dir>/lite/<dir>/lite/x.safetensors" and break the later local load.
        file_path = f if os.path.dirname(f) else os.path.join(path, f)
        basename = os.path.splitext(os.path.relpath(file_path, path))[0]
        downloaded_models[basename] = file_path
    all_models.update(downloaded_models)
    return downloaded_models
def list_models(refresh=False):
    """Return the ControlLLLite model names offered for the current base model type.

    The returned list always starts with ``'None'``. It is cached in the
    module-level ``models`` global and rebuilt only when the cache is empty or
    ``refresh`` is true. For ``sd``/``sdxl``/unknown model types the rebuild
    rescans local files via ``find_models()``; for ``none`` it does not.

    NOTE(review): the cache is not keyed on the model type, so switching base
    models keeps returning the old list until ``refresh=True`` is passed.
    """
    import modules.shared
    global models  # pylint: disable=global-statement
    have_cache = len(models) > 0
    if have_cache and not refresh:
        return models
    models = {}
    model_type = modules.shared.sd_model_type
    if model_type == 'none':
        available = []
    elif model_type == 'sdxl':
        available = sorted(predefined_sdxl) + sorted(find_models())
    elif model_type == 'sd':
        available = sorted(predefined_sd15) + sorted(find_models())
    else:
        log.warning(f'Control {what} model list failed: unknown model type')
        available = sorted(predefined_sd15) + sorted(predefined_sdxl) + sorted(find_models())
    models = ['None'] + available
    debug(f'Control list {what}: path={cache_dir} models={models}')
    return models
class ControlLLLite():
    """Own one ControlNet-LLLite model: resolve its id, load it, and move it to device/dtype."""

    def __init__(self, model_id: str = None, device = None, dtype = None, load_config = None):
        """Create the holder and, if ``model_id`` is given, load that model immediately.

        Args:
            model_id: key into the module-level ``all_models`` registry.
            device: optional target passed to ``model.to()`` after loading.
            dtype: optional dtype passed to ``model.to()`` after loading.
            load_config: optional overrides merged over ``{'cache_dir': cache_dir}``;
                its ``cache_dir`` is used as the hub download cache.
        """
        self.model: ControlNetLLLite = None
        self.model_id: str = model_id
        self.device = device
        self.dtype = dtype
        self.load_config = {'cache_dir': cache_dir}
        if load_config is not None:
            self.load_config.update(load_config)
        if model_id is not None:
            self.load()

    def reset(self):
        """Drop the loaded model (if any) and forget its id."""
        if self.model is not None:
            debug(f'Control {what} model unloaded')
        self.model = None
        self.model_id = None

    def load(self, model_id: str = None) -> Union[str, None]:
        """Load ``model_id`` (or the current ``self.model_id``) into ``self.model``.

        A local ``.safetensors`` path is loaded directly. Any other registry value
        is treated as ``"<hf repo id>/<file name>"`` and downloaded from the hub first.

        Returns:
            A status message on success or failure. Returns ``None`` when the id is
            ``None``/``'None'`` (the model is reset) or maps to an empty path.
        """
        try:
            t0 = time.time()
            model_id = model_id or self.model_id
            if model_id is None or model_id == 'None':
                self.reset()
                return None
            # .get() so an unknown id reaches the explicit error below instead of a KeyError
            model_path = all_models.get(model_id)
            if model_path == '':
                return None
            if model_path is None:
                log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id')
                return f'{what} failed to load model: {model_id}'
            log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}" {self.load_config}')
            if model_path.endswith('.safetensors'):
                self.model = ControlNetLLLite(model_path)
            else:
                import huggingface_hub as hf
                folder, filename = os.path.split(model_path)
                # honor a caller-supplied cache_dir override from load_config
                download_dir = self.load_config.get('cache_dir', cache_dir)
                model_path = hf.hf_hub_download(repo_id=folder, filename=f'{filename}.safetensors', cache_dir=download_dir)
                self.model = ControlNetLLLite(model_path)
            if self.device is not None:
                self.model.to(self.device)
            if self.dtype is not None:
                self.model.to(self.dtype)
            t1 = time.time()
            self.model_id = model_id
            log.debug(f'Control {what} model loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}')
            return f'{what} loaded model: {model_id}'
        except Exception as e:
            log.error(f'Control {what} model load failed: id="{model_id}" error={e}')
            errors.display(e, f'Control {what} load')
            return f'{what} failed to load model: {model_id}'
class ControlLLitePipeline():
    """Attach ControlNet-LLLite modules to a diffusers pipeline and detach them again."""

    def __init__(self, pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline]):
        self.pipeline = pipeline
        self.nets = []

    def apply(self, controlnet: Union[ControlNetLLLite, list[ControlNetLLLite]], image, conditioning):
        """Apply each LLLite net to the pipeline with its control image and weight.

        Args:
            controlnet: a single net or a list of nets.
            image: a PIL image or a list of them; images are reused cyclically across nets.
            conditioning: a scalar weight or a list of weights, reused cyclically across nets.
        """
        if image is None:
            return
        self.nets = [controlnet] if isinstance(controlnet, ControlNetLLLite) else controlnet
        debug(f'Control {what} apply: models={len(self.nets)} image={image} conditioning={conditioning}')
        # Any real scalar (int, float, numpy number) is a single weight, not only Python floats.
        weight = [conditioning] if isinstance(conditioning, (int, float, np.number)) else conditioning
        images = [image] if isinstance(image, Image.Image) else image
        images = [i.convert('RGB') for i in images]
        if self.nets and (len(images) == 0 or len(weight) == 0):
            # nothing to cycle over; indexing with i % 0 would raise ZeroDivisionError
            log.warning(f'Control {what} apply skipped: images={len(images)} weights={len(weight)}')
            return
        for i, cn in enumerate(self.nets):
            cn.apply(pipe=self.pipeline, cond=np.asarray(images[i % len(images)]), weight=weight[i % len(weight)])

    def restore(self):
        """Undo apply(): call clear_all_lllite() and forget the applied nets."""
        # presumably removes the LLLite hooks from the pipeline modules — confirm in lite_model
        from modules.control.units.lite_model import clear_all_lllite
        clear_all_lllite()
        self.nets = []