Spaces:
Runtime error
Runtime error
File size: 5,696 Bytes
c19ca42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import os
import time
from typing import Union
import numpy as np
from PIL import Image
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from modules.shared import log, opts, listdir
from modules import errors
from modules.control.units.lite_model import ControlNetLLLite
# Display name used in every log message emitted by this module.
what = 'ControlLLLite'

# Verbose tracing is enabled only when SD_CONTROL_DEBUG is present in the environment;
# otherwise `debug` is a no-op that accepts and ignores any arguments.
if os.environ.get('SD_CONTROL_DEBUG', None) is None:
    debug = lambda *args, **kwargs: None  # pylint: disable=unnecessary-lambda-assignment
else:
    debug = log.trace
debug('Trace: CONTROL')

# Built-in model ids for SD 1.5 base models (none are defined at the moment).
predefined_sd15 = {}

# Built-in model ids for SDXL base models, as '<hf repo id>/<file name without .safetensors>'.
predefined_sdxl = {
    'Canny XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_canny',
    'Canny anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_canny_anime',
    'Depth anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01008016e_sdxl_depth_anime',
    'Blur anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01016032e_sdxl_blur_anime_beta',
    'Pose anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_pose_anime',
    'Replicate anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_replicate_anime_v2',
}

# Cached result of list_models(); empty until the first listing.
models = {}

# Registry of every known model name -> path/id; find_models() adds local files to it.
all_models = {**predefined_sd15, **predefined_sdxl}

# Download cache for models fetched from the Hugging Face Hub.
cache_dir = 'models/control/lite'
def find_models():
    """Scan ``<opts.control_dir>/lite`` for ``*.safetensors`` files and register them.

    Each file is keyed by its path relative to the scan folder, without the
    extension, and mapped to its location inside that folder. The found entries
    are merged into the module-level ``all_models`` registry and also returned.

    Returns:
        dict: model name -> file path for every locally found model.
    """
    path = os.path.join(opts.control_dir, 'lite')
    downloaded_models = {}
    for f in listdir(path):
        if not f.endswith('.safetensors'):
            continue
        # presumably listdir() returns entries that already include `path` (the relpath
        # call relies on that); re-joining the relative name avoids prefixing `path`
        # a second time when `path` itself is relative
        relative = os.path.relpath(f, path)
        downloaded_models[os.path.splitext(relative)[0]] = os.path.join(path, relative)
    all_models.update(downloaded_models)
    return downloaded_models
def list_models(refresh=False):
    """Return the selectable ControlLLLite model names for the active base model type.

    The list is cached in the module-level ``models``; a cached non-empty value is
    returned as-is unless ``refresh`` is true. The list always starts with ``'None'``,
    followed by the sorted predefined names for the model type and the sorted names
    of locally found model files.
    """
    import modules.shared
    global models # pylint: disable=global-statement
    if not refresh and len(models) > 0:
        return models
    models = {}
    model_type = modules.shared.sd_model_type
    if model_type == 'none':
        models = ['None']
    else:
        if model_type == 'sdxl':
            known = sorted(predefined_sdxl)
        elif model_type == 'sd':
            known = sorted(predefined_sd15)
        else:
            # unrecognized base model: offer every predefined entry
            log.warning(f'Control {what} model list failed: unknown model type')
            known = sorted(predefined_sd15) + sorted(predefined_sdxl)
        models = ['None'] + known + sorted(find_models())
    debug(f'Control list {what}: path={cache_dir} models={models}')
    return models
class ControlLLLite():
    """Loader and holder for a single ControlNet-LLLite model.

    Model ids are resolved through ``all_models``. A path ending in ``.safetensors``
    is loaded directly from disk; any other value is treated as
    ``<repo_id>/<file name>`` and downloaded from the Hugging Face Hub first.
    """

    def __init__(self, model_id: str = None, device = None, dtype = None, load_config = None):
        self.model: ControlNetLLLite = None
        self.model_id: str = model_id
        self.device = device
        self.dtype = dtype
        # module defaults first, so caller-supplied entries override them
        self.load_config = { 'cache_dir': cache_dir }
        if load_config is not None:
            self.load_config.update(load_config)
        if model_id is not None:
            self.load()

    def reset(self):
        """Drop the loaded model and forget its id."""
        if self.model is not None:
            debug(f'Control {what} model unloaded')
        self.model = None
        self.model_id = None

    def load(self, model_id: str = None) -> Union[str, None]:
        """Load ``model_id`` (or the current ``self.model_id`` when omitted).

        Returns:
            A status message on success or failure. Returns ``None`` in three cases:
            the id is ``None`` or ``'None'`` (the model is reset), or the id maps
            to an empty path.
        """
        try:
            t0 = time.time()
            model_id = model_id or self.model_id
            if model_id is None or model_id == 'None':
                self.reset()
                return None
            # .get() so that an unknown id reaches the explicit error below instead of a KeyError
            model_path = all_models.get(model_id)
            if model_path == '':
                return None
            if model_path is None:
                log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id')
                return f'{what} failed to load model: {model_id}'
            log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}" {self.load_config}')
            if not model_path.endswith('.safetensors'):
                import huggingface_hub as hf
                folder, filename = os.path.split(model_path)
                # honour a cache_dir supplied through load_config; fall back to the module default
                download_dir = self.load_config.get('cache_dir', cache_dir)
                model_path = hf.hf_hub_download(repo_id=folder, filename=f'{filename}.safetensors', cache_dir=download_dir)
            self.model = ControlNetLLLite(model_path)
            if self.device is not None:
                self.model.to(self.device)
            if self.dtype is not None:
                self.model.to(self.dtype)
            t1 = time.time()
            self.model_id = model_id
            log.debug(f'Control {what} model loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}')
            return f'{what} loaded model: {model_id}'
        except Exception as e:
            log.error(f'Control {what} model load failed: id="{model_id}" error={e}')
            errors.display(e, f'Control {what} load')
            return f'{what} failed to load model: {model_id}'
class ControlLLitePipeline():
    """Apply ControlNet-LLLite modules to a diffusers pipeline and remove them again."""

    def __init__(self, pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline]):
        self.pipeline = pipeline
        self.nets = []  # modules passed to the most recent apply() call

    def apply(self, controlnet: Union[ControlNetLLLite, list[ControlNetLLLite]], image, conditioning):
        """Attach each LLLite module to the pipeline with its control image and weight.

        Args:
            controlnet: a single ControlNetLLLite or a list of them.
            image: a PIL image or a list of PIL images; each is converted to RGB.
                Nothing is applied when it is None or an empty list.
            conditioning: a numeric weight or a list of per-module weights.

        Images and weights are cycled (index modulo length) when there are
        fewer of them than modules.
        """
        if image is None:
            return
        self.nets = [controlnet] if isinstance(controlnet, ControlNetLLLite) else controlnet
        debug(f'Control {what} apply: models={len(self.nets)} image={image} conditioning={conditioning}')
        # accept ints as well as floats; a bare int previously failed on len(weight)
        weight = [conditioning] if isinstance(conditioning, (int, float)) else conditioning
        images = [image] if isinstance(image, Image.Image) else image
        images = [i.convert('RGB') for i in images]
        if not images:
            # nothing to condition on; also avoids modulo-by-zero below
            return
        for i, cn in enumerate(self.nets):
            cn.apply(pipe=self.pipeline, cond=np.asarray(images[i % len(images)]), weight=weight[i % len(weight)])

    def restore(self):
        """Remove all LLLite modules from the pipeline via clear_all_lllite and forget them."""
        from modules.control.units.lite_model import clear_all_lllite
        clear_all_lllite()
        self.nets = []
|