import os.path
import stat
from collections import OrderedDict
from typing import Dict, Tuple, List

from modules import shared, scripts, sd_models
from modules.paths import models_path
from scripts.enums import StableDiffusionVersion
from scripts.supported_preprocessor import Preprocessor


# File extensions recognised as ControlNet model weights.
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin"]
cn_models_dir = os.path.join(models_path, "ControlNet")
# Legacy location: the ``models`` folder inside this extension's directory.
cn_models_dir_old = os.path.join(scripts.basedir(), "models")
# Display name -> file path, e.g. "my_model [<hash>]" -> "C:/path/to/my_model.safetensors".
# After update_cn_models() the first entry is always "None" -> None.
cn_models = OrderedDict()
# Lower-cased file stem -> display name, e.g. "my_model" -> "my_model [<hash>]".
cn_models_names = {}

default_detectedmap_dir = os.path.join("detected_maps")
script_dir = scripts.basedir()

os.makedirs(cn_models_dir, exist_ok=True)


def traverse_all_files(curr_path, model_list):
    """Recursively collect model files found under ``curr_path``.

    Appends a ``(path, os.stat_result)`` tuple to ``model_list`` for every
    file whose extension is in ``CN_MODEL_EXTS`` and returns ``model_list``.

    A ``curr_path`` that does not exist or is not a directory contributes
    nothing. Entries that cannot be stat'ed (e.g. dangling symlinks) are
    skipped so one bad entry does not abort the whole scan.
    """
    # Check before scandir: os.scandir raises on a missing directory, and
    # optional locations such as cn_models_dir_old may legitimately be absent.
    if not os.path.isdir(curr_path):
        return model_list

    entries = []
    # Read and close this directory before recursing, so deep trees do not
    # hold one open directory handle per level.
    with os.scandir(curr_path) as it:
        for entry in it:
            try:
                entries.append((os.path.join(curr_path, entry.name), entry.stat()))
            except OSError:
                continue

    for f_info in entries:
        fname, fstat = f_info
        # Test for directories first so a folder named e.g. "x.pt" is
        # descended into rather than reported as a model file.
        if stat.S_ISDIR(fstat.st_mode):
            model_list = traverse_all_files(fname, model_list)
        elif os.path.splitext(fname)[1] in CN_MODEL_EXTS:
            model_list.append(f_info)
    return model_list


def get_all_models(sort_by, filter_by, path):
    """Return an ordered mapping of display name -> file path for ``path``.

    Args:
        sort_by: ``"name"`` (file basename), ``"date"`` (newest mtime first)
            or ``"path name"`` (full path); any other value keeps the order
            in which files were discovered.
        filter_by: case-insensitive substring the file basename must
            contain; blank or ``None`` disables filtering.
        path: root directory scanned recursively.

    Display names have the form ``"<file stem> [<sd_models.model_hash>]"``.
    """
    res = OrderedDict()
    fileinfos = traverse_all_files(path, [])

    filter_by = (filter_by or "").strip(" ")
    if filter_by:
        needle = filter_by.lower()
        fileinfos = [x for x in fileinfos if needle in os.path.basename(x[0]).lower()]

    if sort_by == "name":
        fileinfos = sorted(fileinfos, key=lambda x: os.path.basename(x[0]))
    elif sort_by == "date":
        fileinfos = sorted(fileinfos, key=lambda x: -x[1].st_mtime)
    elif sort_by == "path name":
        # Sort on the path alone; stat results are not meaningfully ordered.
        fileinfos = sorted(fileinfos, key=lambda x: x[0])

    for filename, _ in fileinfos:
        name = os.path.splitext(os.path.basename(filename))[0]
        # Prevent a hypothetical "None.pt" from being listed.
        if name != "None":
            res[name + f" [{sd_models.model_hash(filename)}]"] = filename
    return res


def update_cn_models():
    """Rescan all model directories and refresh ``cn_models``/``cn_models_names``.

    Directories are scanned in priority order: ``cn_models_dir``,
    ``cn_models_dir_old``, then the ``control_net_models_path`` setting and
    the ``controlnet_dir`` command-line option (each only when set and
    existing). When a display name appears in several directories the first
    one wins. ``"None"`` is always the first key of ``cn_models``. Both
    module-level dicts are updated in place so existing references stay valid.
    """
    ext_dirs = (
        shared.opts.data.get("control_net_models_path", None),
        getattr(shared.cmd_opts, 'controlnet_dir', None),
    )
    extra_cn_paths = [p for p in ext_dirs if p is not None and os.path.exists(p)]
    paths = [cn_models_dir, cn_models_dir_old, *extra_cn_paths]

    # The options do not change during a scan; read them once.
    sort_by = shared.opts.data.get("control_net_models_sort_models_by", "name")
    filter_by = shared.opts.data.get("control_net_models_name_filter", "")

    # Build the new mapping locally so readers of the shared dict never
    # observe an empty or half-populated state while the scan is running.
    new_models = OrderedDict([("None", None)])
    for path in paths:
        for name_and_hash, filename in get_all_models(sort_by, filter_by, path).items():
            # Earlier directories take precedence over later ones.
            new_models.setdefault(name_and_hash, filename)

    cn_models.clear()
    cn_models.update(new_models)

    cn_models_names.clear()
    for name_and_hash, filename in cn_models.items():
        if filename is None:
            continue
        name = os.path.splitext(os.path.basename(filename))[0].lower()
        cn_models_names[name] = name_and_hash


def get_sd_version() -> StableDiffusionVersion:
    """Return the Stable Diffusion version of the currently loaded checkpoint."""
    if hasattr(shared.sd_model, 'is_sdxl'):
        if shared.sd_model.is_sdxl:
            return StableDiffusionVersion.SDXL
        elif shared.sd_model.is_sd2:
            return StableDiffusionVersion.SD2x
        elif shared.sd_model.is_sd1:
            return StableDiffusionVersion.SD1x
        else:
            return StableDiffusionVersion.UNKNOWN

    # Backward compatibility for webui < 1.5.0, whose models lack the
    # is_sdxl/is_sd2/is_sd1 flags: infer the version from model structure.
    else:
        if hasattr(shared.sd_model, 'conditioner'):
            return StableDiffusionVersion.SDXL
        elif hasattr(shared.sd_model.cond_stage_model, 'model'):
            return StableDiffusionVersion.SD2x
        else:
            return StableDiffusionVersion.SD1x


def select_control_type(
    control_type: str,
    sd_version: StableDiffusionVersion = StableDiffusionVersion.UNKNOWN,
    cn_models: Dict = cn_models,  # Override or testing
) -> Tuple[List[str], List[str], str, str]:
    """Return the preprocessors and models to offer for ``control_type``.

    Args:
        control_type: control type name; ``"all"`` (any case) disables
            filtering.
        sd_version: version of the loaded checkpoint; models whose
            name-detected version is incompatible with it are hidden.
        cn_models: mapping whose keys are model display names. Defaults to
            the module-level ``cn_models``; pass another mapping for tests.

    Returns:
        ``(preprocessor_labels, model_names, default_preprocessor,
        default_model)``. ``"None"`` entries are always kept in
        ``model_names``. The default model is the first real model whose
        name (before the ``[hash]`` suffix) contains ``"11"`` (presumably to
        prefer ControlNet 1.1 checkpoints), else the first real model, else
        ``"None"``.
    """
    pattern = control_type.lower()
    all_models = list(cn_models.keys())

    if pattern == "all":
        return (
            [p.label for p in Preprocessor.get_sorted_preprocessors()],
            all_models,
            'none',  # default option
            "None",  # default model
        )

    # Loop-invariant: substrings that identify models of this control type.
    tag_filters = Preprocessor.tag_to_filters(control_type)

    def _is_candidate(model: str) -> bool:
        """True for "None" and for compatible models matching the pattern."""
        model_lower = model.lower()
        if model_lower == "none":
            return True
        if not (pattern in model_lower or any(a in model_lower for a in tag_filters)):
            return False
        return sd_version.is_compatible_with(
            StableDiffusionVersion.detect_from_model_name(model)
        )

    filtered_model_list = [model for model in all_models if _is_candidate(model)]
    assert len(filtered_model_list) > 0, "'None' model should always be available."

    # Pick the default among real models, independent of where "None" sits.
    real_models = [m for m in filtered_model_list if m.lower() != "none"]
    default_model = next(
        (m for m in real_models if "11" in m.split("[")[0]),
        real_models[0] if real_models else "None",
    )

    return (
        [p.label for p in Preprocessor.get_filtered_preprocessors(control_type)],
        filtered_model_list,
        Preprocessor.get_default_preprocessor(control_type).label,
        default_model,
    )