File size: 11,067 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
import os
import time
from typing import Union
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, ControlNetModel, StableDiffusionControlNetPipeline, StableDiffusionXLControlNetPipeline
from modules.control.units import detect
from modules.shared import log, opts, listdir
from modules import errors, sd_models


# label interpolated into every log message emitted by this module
what = 'ControlNet'
# trace logging is a no-op unless the SD_CONTROL_DEBUG environment variable is present (any value)
debug = (lambda *args, **kwargs: None) if os.environ.get('SD_CONTROL_DEBUG', None) is None else log.trace
debug('Trace: CONTROL')
# Built-in ControlNet models offered when the loaded base model type is 'sd'.
# Maps display name -> model reference: either a hub repo id ("owner/repo") or
# "owner/repo/path/to/file.safetensors" for single-file checkpoints.
# Insertion order is the order in which list_models() presents the entries.
# NOTE(review): the 'MLDS' key looks like a transposition of 'MLSD' (its value ends in "_mlsd");
# renaming it would change the name exposed to callers, so it is left as-is here.
predefined_sd15 = {
    'Canny': "lllyasviel/control_v11p_sd15_canny",
    'Depth': "lllyasviel/control_v11f1p_sd15_depth",
    'HED': "lllyasviel/sd-controlnet-hed",
    'IP2P': "lllyasviel/control_v11e_sd15_ip2p",
    'LineArt': "lllyasviel/control_v11p_sd15_lineart",
    'LineArt Anime': "lllyasviel/control_v11p_sd15s2_lineart_anime",
    'MLDS': "lllyasviel/control_v11p_sd15_mlsd",
    'NormalBae': "lllyasviel/control_v11p_sd15_normalbae",
    'OpenPose': "lllyasviel/control_v11p_sd15_openpose",
    'Scribble': "lllyasviel/control_v11p_sd15_scribble",
    'Segment': "lllyasviel/control_v11p_sd15_seg",
    'Shuffle': "lllyasviel/control_v11e_sd15_shuffle",
    'SoftEdge': "lllyasviel/control_v11p_sd15_softedge",
    'Tile': "lllyasviel/control_v11f1e_sd15_tile",
    'Depth Anything': 'vladmandic/depth-anything',
    'Canny FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_canny.safetensors',
    'Inpaint FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_inpaint.safetensors',
    'LineArt Anime FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_animeline.safetensors',
    'LineArt FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_lineart.safetensors',
    'MLSD FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_mlsd.safetensors',
    'NormalBae FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_normal.safetensors',
    'OpenPose FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_openpose.safetensors',
    'Pix2Pix FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_pix2pix.safetensors',
    'Scribble FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_scribble.safetensors',
    'Segment FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_seg.safetensors',
    'Shuffle FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_shuffle.safetensors',
    'SoftEdge FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_softedge.safetensors',
    'Tile FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_tileE.safetensors',
    'CiaraRowles TemporalNet': "CiaraRowles/TemporalNet",
    'Ciaochaos Recolor': 'ioclab/control_v1p_sd15_brightness',
    'Ciaochaos Illumination': 'ioclab/control_v1u_sd15_illumination/illumination20000.safetensors',
}
# Built-in ControlNet models offered when the loaded base model type is 'sdxl'.
# Maps display name -> hub repo id; insertion order is the order list_models() presents them in.
predefined_sdxl = {
    'Canny Small XL': 'diffusers/controlnet-canny-sdxl-1.0-small',
    'Canny Mid XL': 'diffusers/controlnet-canny-sdxl-1.0-mid',
    'Canny XL': 'diffusers/controlnet-canny-sdxl-1.0',
    'Depth Zoe XL': 'diffusers/controlnet-zoe-depth-sdxl-1.0',
    'Depth Mid XL': 'diffusers/controlnet-depth-sdxl-1.0-mid',
    'OpenPose XL': 'thibaud/controlnet-openpose-sdxl-1.0',
    # the StabilityAI control-lora entries below are currently disabled
    # 'StabilityAI Canny R128': 'stabilityai/control-lora/control-LoRAs-rank128/control-lora-canny-rank128.safetensors',
    # 'StabilityAI Depth R128': 'stabilityai/control-lora/control-LoRAs-rank128/control-lora-depth-rank128.safetensors',
    # 'StabilityAI Recolor R128': 'stabilityai/control-lora/control-LoRAs-rank128/control-lora-recolor-rank128.safetensors',
    # 'StabilityAI Sketch R128': 'stabilityai/control-lora/control-LoRAs-rank128/control-lora-sketch-rank128-metadata.safetensors',
    # 'StabilityAI Canny R256': 'stabilityai/control-lora/control-LoRAs-rank256/control-lora-canny-rank256.safetensors',
    # 'StabilityAI Depth R256': 'stabilityai/control-lora/control-LoRAs-rank256/control-lora-depth-rank256.safetensors',
    # 'StabilityAI Recolor R256': 'stabilityai/control-lora/control-LoRAs-rank256/control-lora-recolor-rank256.safetensors',
    # 'StabilityAI Sketch R256': 'stabilityai/control-lora/control-LoRAs-rank256/control-lora-sketch-rank256.safetensors',
}
# cached result of list_models(); starts empty and is replaced on first call or refresh
models = {}
# name -> path/repo lookup used when loading; starts with both predefined tables (sd15 first, sdxl second)
# and is extended in place by find_models() with locally discovered files
all_models = {**predefined_sd15, **predefined_sdxl}
# default download/cache folder handed to the loaders
cache_dir = 'models/control/controlnet'


def find_models():
    """Scan the local controlnet folder for `.safetensors` files.

    The folder is `<opts.control_dir>/controlnet`. Each match is keyed by its
    path relative to that folder with the extension stripped. The discovered
    entries are also merged into the module-level `all_models` lookup.

    Returns:
        dict mapping model name to the joined file path of every match.
    """
    root = os.path.join(opts.control_dir, 'controlnet')
    found = {
        os.path.splitext(os.path.relpath(entry, root))[0]: os.path.join(root, entry)
        for entry in listdir(root)
        if entry.endswith('.safetensors')
    }
    all_models.update(found) # make local files resolvable by name when loading
    return found


def list_models(refresh=False):
    """Return the list of selectable ControlNet model names for the current base model type.

    The result is cached in the module-level `models`; the cache is reused
    unless it is empty or `refresh` is truthy. The list always starts with
    'None', followed by the predefined names matching
    `modules.shared.sd_model_type` and the sorted names of local files.
    For an unrecognized model type a warning is logged and every known
    model is listed.
    """
    import modules.shared
    global models # pylint: disable=global-statement
    if models and not refresh:
        return models
    models = {}
    model_type = modules.shared.sd_model_type
    predefined = { 'sdxl': predefined_sdxl, 'sd': predefined_sd15 }
    if model_type == 'none':
        # no base model loaded: nothing to offer and no local scan
        models = ['None']
    elif model_type in predefined:
        models = ['None'] + list(predefined[model_type]) + sorted(find_models())
    else:
        log.warning(f'Control {what} model list failed: unknown model type')
        models = ['None'] + sorted(predefined_sd15) + sorted(predefined_sdxl) + sorted(find_models())
    debug(f'Control list {what}: path={cache_dir} models={models}')
    return models


class ControlNet():
    """Holder for a single `ControlNetModel` together with the settings used to load it.

    Attributes (set in `__init__`):
        model: the loaded `ControlNetModel`, or None when nothing is loaded.
        model_id: name of the loaded model (a key of `all_models`), or None.
        device / dtype: optional targets the model is moved to after loading.
        load_config: keyword arguments forwarded to the diffusers loader;
            always contains `cache_dir` unless overridden by the caller.
    """

    def __init__(self, model_id: str = None, device = None, dtype = None, load_config = None):
        """Store the settings and, when `model_id` is given, load that model immediately."""
        self.model: ControlNetModel = None
        self.model_id: str = model_id
        self.device = device
        self.dtype = dtype
        self.load_config = { 'cache_dir': cache_dir }
        if load_config is not None:
            self.load_config.update(load_config)
        if model_id is not None:
            self.load()

    def reset(self):
        """Drop the reference to the loaded model and forget its id."""
        if self.model is not None:
            debug(f'Control {what} model unloaded')
        self.model = None
        self.model_id = None

    def load_safetensors(self, model_path):
        """Load a single-file `.safetensors` checkpoint into `self.model`.

        Args:
            model_path: either an existing local file, or a hub reference of
                the form `owner/repo/path/to/file.safetensors` which is
                downloaded first.

        A config file with the same base name (`.yaml` preferred, then
        `.json`) is looked up next to the checkpoint - on the hub for
        downloaded files, on disk for local ones - and, when found, handed
        to the loader as `original_config_file`.

        Raises:
            ValueError: if `model_path` is not a local file and is not a
                well-formed `owner/repo/file` hub reference.
        """
        name = os.path.splitext(model_path)[0]
        config_path = None
        if not os.path.exists(model_path):
            import huggingface_hub as hf
            parts = model_path.split('/')
            if len(parts) < 3:
                raise ValueError(f'file not found and not a valid hub reference: {model_path}')
            repo_id = f'{parts[0]}/{parts[1]}'
            filename = os.path.splitext('/'.join(parts[2:]))[0]
            model_path = hf.hf_hub_download(repo_id=repo_id, filename=f'{filename}.safetensors', cache_dir=cache_dir)
            for ext in ('yaml', 'json'):
                try:
                    config_path = hf.hf_hub_download(repo_id=repo_id, filename=f'{filename}.{ext}', cache_dir=cache_dir)
                    break
                except Exception:
                    # best-effort: the config file is optional, try the next extension
                    debug(f'Control {what} model config not found: repo="{repo_id}" file="{filename}.{ext}"')
        elif os.path.exists(name + '.yaml'):
            config_path = f'{name}.yaml'
        elif os.path.exists(name + '.json'):
            config_path = f'{name}.json'
        # per-call copy: the config path belongs to this checkpoint only and must not
        # leak into self.load_config where it would be reused by later loads
        load_config = dict(self.load_config)
        if config_path is not None:
            load_config['original_config_file'] = config_path
        self.model = ControlNetModel.from_single_file(model_path, **load_config)

    def load(self, model_id: str = None) -> Union[str, None]:
        """Load the model registered under `model_id` (defaults to `self.model_id`).

        A missing id or the literal name 'None' unloads the current model.
        Names ending in `.safetensors` are loaded as single files, everything
        else through `ControlNetModel.from_pretrained`. After loading, the
        model is moved to `self.device` / `self.dtype` when those are set.

        Returns:
            A status message on success or on a load error; None when the
            model was unloaded, the id is unknown or its path is empty.
            Exceptions are logged and reported through the message, not raised.
        """
        try:
            t0 = time.time()
            model_id = model_id or self.model_id
            if model_id is None or model_id == 'None':
                self.reset()
                return None
            # .get so an unregistered id reaches the explicit error below instead of a KeyError
            model_path = all_models.get(model_id)
            if model_path == '':
                return None
            if model_path is None:
                log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id')
                return None
            log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}"')
            if model_path.endswith('.safetensors'):
                self.load_safetensors(model_path)
            else:
                self.model = ControlNetModel.from_pretrained(model_path, **self.load_config)
            if self.device is not None:
                self.model.to(self.device)
            if self.dtype is not None:
                self.model.to(self.dtype)
            t1 = time.time()
            self.model_id = model_id
            log.debug(f'Control {what} model loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}')
            return f'{what} loaded model: {model_id}'
        except Exception as e:
            log.error(f'Control {what} model load failed: id="{model_id}" error={e}')
            errors.display(e, f'Control {what} load')
            return f'{what} failed to load model: {model_id}'


class ControlNetPipeline():
    """ControlNet-enabled pipeline assembled from the components of an already loaded base pipeline.

    After construction `self.pipeline` holds the new ControlNet pipeline, or
    None when the base pipeline is missing or of an unsupported type (an
    error is logged in both cases). `self.orig_pipeline` always keeps the
    pipeline that was passed in.
    """

    def __init__(self, controlnet: Union[ControlNetModel, list[ControlNetModel]], pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline], dtype = None):
        """Build the ControlNet pipeline.

        Args:
            controlnet: a single ControlNet model or a list of them.
            pipeline: the loaded base pipeline whose components are reused.
            dtype: optional dtype the resulting pipeline is converted to.
        """
        started = time.time()
        self.orig_pipeline = pipeline
        self.pipeline = None
        if pipeline is None:
            log.error('Control model pipeline: model not loaded')
            return
        if detect.is_sdxl(pipeline):
            components = {
                'vae': pipeline.vae,
                'text_encoder': pipeline.text_encoder,
                'text_encoder_2': pipeline.text_encoder_2,
                'tokenizer': pipeline.tokenizer,
                'tokenizer_2': pipeline.tokenizer_2,
                'unet': pipeline.unet,
                'scheduler': pipeline.scheduler,
                'feature_extractor': getattr(pipeline, 'feature_extractor', None),
                'controlnet': controlnet, # can be a list
            }
            self.pipeline = StableDiffusionXLControlNetPipeline(**components)
        elif detect.is_sd15(pipeline):
            components = {
                'vae': pipeline.vae,
                'text_encoder': pipeline.text_encoder,
                'tokenizer': pipeline.tokenizer,
                'unet': pipeline.unet,
                'scheduler': pipeline.scheduler,
                'feature_extractor': getattr(pipeline, 'feature_extractor', None),
                'requires_safety_checker': False,
                'safety_checker': None,
                'controlnet': controlnet, # can be a list
            }
            self.pipeline = StableDiffusionControlNetPipeline(**components)
        else:
            log.error(f'Control {what} pipeline: class={pipeline.__class__.__name__} unsupported model type')
            return
        # place the new pipeline on the same device as the base pipeline
        sd_models.move_model(self.pipeline, pipeline.device)
        if self.pipeline is not None and dtype is not None:
            self.pipeline = self.pipeline.to(dtype)
        elapsed = time.time() - started
        if self.pipeline is None:
            log.error(f'Control {what} pipeline: not initialized')
        else:
            log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={elapsed:.2f}')

    def restore(self):
        """Release the ControlNet pipeline and hand back the original base pipeline."""
        self.pipeline = None
        return self.orig_pipeline