Spaces:
Runtime error
Runtime error
import os | |
import time | |
from typing import Union | |
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, ControlNetModel, StableDiffusionControlNetPipeline, StableDiffusionXLControlNetPipeline | |
from modules.control.units import detect | |
from modules.shared import log, opts, listdir | |
from modules import errors, sd_models | |
what = 'ControlNet' | |
debug = log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None | |
debug('Trace: CONTROL') | |
predefined_sd15 = { | |
'Canny': "lllyasviel/control_v11p_sd15_canny", | |
'Depth': "lllyasviel/control_v11f1p_sd15_depth", | |
'HED': "lllyasviel/sd-controlnet-hed", | |
'IP2P': "lllyasviel/control_v11e_sd15_ip2p", | |
'LineArt': "lllyasviel/control_v11p_sd15_lineart", | |
'LineArt Anime': "lllyasviel/control_v11p_sd15s2_lineart_anime", | |
'MLDS': "lllyasviel/control_v11p_sd15_mlsd", | |
'NormalBae': "lllyasviel/control_v11p_sd15_normalbae", | |
'OpenPose': "lllyasviel/control_v11p_sd15_openpose", | |
'Scribble': "lllyasviel/control_v11p_sd15_scribble", | |
'Segment': "lllyasviel/control_v11p_sd15_seg", | |
'Shuffle': "lllyasviel/control_v11e_sd15_shuffle", | |
'SoftEdge': "lllyasviel/control_v11p_sd15_softedge", | |
'Tile': "lllyasviel/control_v11f1e_sd15_tile", | |
'Depth Anything': 'vladmandic/depth-anything', | |
'Canny FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_canny.safetensors', | |
'Inpaint FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_inpaint.safetensors', | |
'LineArt Anime FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_animeline.safetensors', | |
'LineArt FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_lineart.safetensors', | |
'MLSD FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_mlsd.safetensors', | |
'NormalBae FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_normal.safetensors', | |
'OpenPose FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_openpose.safetensors', | |
'Pix2Pix FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_pix2pix.safetensors', | |
'Scribble FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_scribble.safetensors', | |
'Segment FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_seg.safetensors', | |
'Shuffle FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_shuffle.safetensors', | |
'SoftEdge FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_softedge.safetensors', | |
'Tile FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_tileE.safetensors', | |
'CiaraRowles TemporalNet': "CiaraRowles/TemporalNet", | |
'Ciaochaos Recolor': 'ioclab/control_v1p_sd15_brightness', | |
'Ciaochaos Illumination': 'ioclab/control_v1u_sd15_illumination/illumination20000.safetensors', | |
} | |
predefined_sdxl = { | |
'Canny Small XL': 'diffusers/controlnet-canny-sdxl-1.0-small', | |
'Canny Mid XL': 'diffusers/controlnet-canny-sdxl-1.0-mid', | |
'Canny XL': 'diffusers/controlnet-canny-sdxl-1.0', | |
'Depth Zoe XL': 'diffusers/controlnet-zoe-depth-sdxl-1.0', | |
'Depth Mid XL': 'diffusers/controlnet-depth-sdxl-1.0-mid', | |
'OpenPose XL': 'thibaud/controlnet-openpose-sdxl-1.0', | |
# 'StabilityAI Canny R128': 'stabilityai/control-lora/control-LoRAs-rank128/control-lora-canny-rank128.safetensors', | |
# 'StabilityAI Depth R128': 'stabilityai/control-lora/control-LoRAs-rank128/control-lora-depth-rank128.safetensors', | |
# 'StabilityAI Recolor R128': 'stabilityai/control-lora/control-LoRAs-rank128/control-lora-recolor-rank128.safetensors', | |
# 'StabilityAI Sketch R128': 'stabilityai/control-lora/control-LoRAs-rank128/control-lora-sketch-rank128-metadata.safetensors', | |
# 'StabilityAI Canny R256': 'stabilityai/control-lora/control-LoRAs-rank256/control-lora-canny-rank256.safetensors', | |
# 'StabilityAI Depth R256': 'stabilityai/control-lora/control-LoRAs-rank256/control-lora-depth-rank256.safetensors', | |
# 'StabilityAI Recolor R256': 'stabilityai/control-lora/control-LoRAs-rank256/control-lora-recolor-rank256.safetensors', | |
# 'StabilityAI Sketch R256': 'stabilityai/control-lora/control-LoRAs-rank256/control-lora-sketch-rank256.safetensors', | |
} | |
models = {} | |
all_models = {} | |
all_models.update(predefined_sd15) | |
all_models.update(predefined_sdxl) | |
cache_dir = 'models/control/controlnet' | |
def find_models(): | |
path = os.path.join(opts.control_dir, 'controlnet') | |
files = listdir(path) | |
files = [f for f in files if f.endswith('.safetensors')] | |
downloaded_models = {} | |
for f in files: | |
basename = os.path.splitext(os.path.relpath(f, path))[0] | |
downloaded_models[basename] = os.path.join(path, f) | |
all_models.update(downloaded_models) | |
return downloaded_models | |
def list_models(refresh=False): | |
import modules.shared | |
global models # pylint: disable=global-statement | |
if not refresh and len(models) > 0: | |
return models | |
models = {} | |
if modules.shared.sd_model_type == 'none': | |
models = ['None'] | |
elif modules.shared.sd_model_type == 'sdxl': | |
models = ['None'] + list(predefined_sdxl) + sorted(find_models()) | |
elif modules.shared.sd_model_type == 'sd': | |
models = ['None'] + list(predefined_sd15) + sorted(find_models()) | |
else: | |
log.warning(f'Control {what} model list failed: unknown model type') | |
models = ['None'] + sorted(predefined_sd15) + sorted(predefined_sdxl) + sorted(find_models()) | |
debug(f'Control list {what}: path={cache_dir} models={models}') | |
return models | |
class ControlNet(): | |
def __init__(self, model_id: str = None, device = None, dtype = None, load_config = None): | |
self.model: ControlNetModel = None | |
self.model_id: str = model_id | |
self.device = device | |
self.dtype = dtype | |
self.load_config = { 'cache_dir': cache_dir } | |
if load_config is not None: | |
self.load_config.update(load_config) | |
if model_id is not None: | |
self.load() | |
def reset(self): | |
if self.model is not None: | |
debug(f'Control {what} model unloaded') | |
self.model = None | |
self.model_id = None | |
def load_safetensors(self, model_path): | |
name = os.path.splitext(model_path)[0] | |
config_path = None | |
if not os.path.exists(model_path): | |
import huggingface_hub as hf | |
parts = model_path.split('/') | |
repo_id = f'{parts[0]}/{parts[1]}' | |
filename = os.path.splitext('/'.join(parts[2:]))[0] | |
model_path = hf.hf_hub_download(repo_id=repo_id, filename=f'{filename}.safetensors', cache_dir=cache_dir) | |
if config_path is None: | |
try: | |
config_path = hf.hf_hub_download(repo_id=repo_id, filename=f'{filename}.yaml', cache_dir=cache_dir) | |
except Exception: | |
pass # no yaml file | |
if config_path is None: | |
try: | |
config_path = hf.hf_hub_download(repo_id=repo_id, filename=f'{filename}.json', cache_dir=cache_dir) | |
except Exception: | |
pass # no yaml file | |
elif os.path.exists(name + '.yaml'): | |
config_path = f'{name}.yaml' | |
elif os.path.exists(name + '.json'): | |
config_path = f'{name}.json' | |
if config_path is not None: | |
self.load_config['original_config_file '] = config_path | |
self.model = ControlNetModel.from_single_file(model_path, **self.load_config) | |
def load(self, model_id: str = None) -> str: | |
try: | |
t0 = time.time() | |
model_id = model_id or self.model_id | |
if model_id is None or model_id == 'None': | |
self.reset() | |
return | |
model_path = all_models[model_id] | |
if model_path == '': | |
return | |
if model_path is None: | |
log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id') | |
return | |
log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}"') | |
if model_path.endswith('.safetensors'): | |
self.load_safetensors(model_path) | |
else: | |
self.model = ControlNetModel.from_pretrained(model_path, **self.load_config) | |
if self.device is not None: | |
self.model.to(self.device) | |
if self.dtype is not None: | |
self.model.to(self.dtype) | |
t1 = time.time() | |
self.model_id = model_id | |
log.debug(f'Control {what} model loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}') | |
return f'{what} loaded model: {model_id}' | |
except Exception as e: | |
log.error(f'Control {what} model load failed: id="{model_id}" error={e}') | |
errors.display(e, f'Control {what} load') | |
return f'{what} failed to load model: {model_id}' | |
class ControlNetPipeline(): | |
def __init__(self, controlnet: Union[ControlNetModel, list[ControlNetModel]], pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline], dtype = None): | |
t0 = time.time() | |
self.orig_pipeline = pipeline | |
self.pipeline = None | |
if pipeline is None: | |
log.error('Control model pipeline: model not loaded') | |
return | |
elif detect.is_sdxl(pipeline): | |
self.pipeline = StableDiffusionXLControlNetPipeline( | |
vae=pipeline.vae, | |
text_encoder=pipeline.text_encoder, | |
text_encoder_2=pipeline.text_encoder_2, | |
tokenizer=pipeline.tokenizer, | |
tokenizer_2=pipeline.tokenizer_2, | |
unet=pipeline.unet, | |
scheduler=pipeline.scheduler, | |
feature_extractor=getattr(pipeline, 'feature_extractor', None), | |
controlnet=controlnet, # can be a list | |
) | |
sd_models.move_model(self.pipeline, pipeline.device) | |
elif detect.is_sd15(pipeline): | |
self.pipeline = StableDiffusionControlNetPipeline( | |
vae=pipeline.vae, | |
text_encoder=pipeline.text_encoder, | |
tokenizer=pipeline.tokenizer, | |
unet=pipeline.unet, | |
scheduler=pipeline.scheduler, | |
feature_extractor=getattr(pipeline, 'feature_extractor', None), | |
requires_safety_checker=False, | |
safety_checker=None, | |
controlnet=controlnet, # can be a list | |
) | |
sd_models.move_model(self.pipeline, pipeline.device) | |
else: | |
log.error(f'Control {what} pipeline: class={pipeline.__class__.__name__} unsupported model type') | |
return | |
if dtype is not None and self.pipeline is not None: | |
self.pipeline = self.pipeline.to(dtype) | |
t1 = time.time() | |
if self.pipeline is not None: | |
log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}') | |
else: | |
log.error(f'Control {what} pipeline: not initialized') | |
def restore(self): | |
self.pipeline = None | |
return self.orig_pipeline | |