File size: 12,749 Bytes
ecc4278
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union, Dict, TypedDict, Any
import numpy as np
from modules import shared
from lib_controlnet.logging import logger
from lib_controlnet.enums import InputMode, HiResFixOption
from modules.api import api

from lib_controlnet.enums import (
    InputMode,
    HiResFixOption,
)


def get_api_version() -> int:
    """Return the version number of the external ControlNet API schema."""
    return 2


class ControlMode(Enum):
    """
    The improved guess mode.

    Selects how prompt conditioning and ControlNet guidance are balanced,
    as described by each member's display string.
    """

    # Even balance between the prompt and the ControlNet guidance.
    BALANCED = "Balanced"
    # "My prompt is more important": favor the text prompt.
    PROMPT = "My prompt is more important"
    # "ControlNet is more important": favor the ControlNet guidance.
    CONTROL = "ControlNet is more important"


class BatchOption(Enum):
    """How ControlNet units are applied across the images of a batch."""

    # Apply all ControlNet units to all images in the batch.
    DEFAULT = "All ControlNet units for all images in a batch"
    # Apply each ControlNet unit to each image in the batch (see value text).
    SEPARATE = "Each ControlNet unit for each image in a batch"


class ResizeMode(Enum):
    """
    Resize modes for ControlNet input images.
    """

    RESIZE = "Just Resize"
    INNER_FIT = "Crop and Resize"
    OUTER_FIT = "Resize and Fill"

    def int_value(self) -> int:
        """Return the legacy integer code for this resize mode (0/1/2)."""
        # Declaration order matches the legacy integer codes.
        codes = {
            ResizeMode.RESIZE: 0,
            ResizeMode.INNER_FIT: 1,
            ResizeMode.OUTER_FIT: 2,
        }
        return codes[self]


# Legacy display names for resize modes mapped to the current ResizeMode
# values, so old infotext/API payloads still resolve to a valid mode.
resize_mode_aliases = {
    'Inner Fit (Scale to Fit)': 'Crop and Resize',
    'Outer Fit (Shrink to Fit)': 'Resize and Fill',
    'Scale to Fit (Inner Fit)': 'Crop and Resize',
    'Envelope (Outer Fit)': 'Resize and Fill',
}


def resize_mode_from_value(value: Union[str, int, ResizeMode]) -> ResizeMode:
    """Normalize a resize-mode value from UI/API/infotext into a ResizeMode.

    Accepts:
    - str: a ResizeMode display value, a legacy alias (see
      ``resize_mode_aliases``), or a serialized enum repr such as
      "ResizeMode.RESIZE".
    - int: the legacy integer code. 3 (the old 'Just Resize (Latent
      upscale)' mode) maps to RESIZE; out-of-range values fall back to
      RESIZE with a warning instead of raising.
    - ResizeMode: returned unchanged.

    Raises:
        TypeError: if ``value`` is none of the accepted types.
    """
    if isinstance(value, str):
        if value.startswith("ResizeMode."):
            # Serialized enum repr, e.g. "ResizeMode.RESIZE".
            _, field = value.split(".")
            return getattr(ResizeMode, field)
        return ResizeMode(resize_mode_aliases.get(value, value))
    elif isinstance(value, int):
        assert value >= 0
        if value == 3:  # 'Just Resize (Latent upscale)'
            return ResizeMode.RESIZE
        try:
            return list(ResizeMode)[value]
        except IndexError:
            # Use the module logger (consistent with the rest of this file)
            # instead of print, so the message respects logging configuration.
            logger.warning(f'Unrecognized ResizeMode int value {value}. Fall back to RESIZE.')
            return ResizeMode.RESIZE
    elif isinstance(value, ResizeMode):
        return value
    else:
        raise TypeError(f"ResizeMode value must be str, int, or ResizeMode, not {type(value)}")


def control_mode_from_value(value: Union[str, int, ControlMode]) -> ControlMode:
    """Normalize a control-mode value from UI/API/infotext into a ControlMode.

    Accepts the enum display string, the legacy integer index, or a
    ControlMode instance. Unrecognized strings/ints fall back to BALANCED
    with a warning rather than raising.

    Raises:
        TypeError: if ``value`` is none of the accepted types.
    """
    if isinstance(value, str):
        try:
            return ControlMode(value)
        except ValueError:
            # Use the module logger (consistent with the rest of this file)
            # instead of print, so the message respects logging configuration.
            logger.warning(f'Unrecognized ControlMode string value "{value}". Fall back to BALANCED.')
            return ControlMode.BALANCED
    elif isinstance(value, int):
        try:
            # `list(ControlMode)` is the idiomatic form of the original
            # `[e for e in ControlMode]` copy.
            # NOTE(review): negative ints index from the end (e.g. -1 ->
            # CONTROL) — confirm callers never pass negatives.
            return list(ControlMode)[value]
        except IndexError:
            logger.warning(f'Unrecognized ControlMode int value {value}. Fall back to BALANCED.')
            return ControlMode.BALANCED
    elif isinstance(value, ControlMode):
        return value
    else:
        raise TypeError(f"ControlMode value must be str, int, or ControlMode, not {type(value)}")


def visualize_inpaint_mask(img):
    """Make an inpaint mask visible for preview.

    For an RGBA array, remaps the alpha channel via
    ``alpha = 255 - alpha // 2`` (compressing [0, 255] into [128, 255]) so
    masked regions remain partially visible. Any other input is returned
    unchanged.

    Args:
        img (np.ndarray): Image array; only (H, W, 4) inputs are remapped.

    Returns:
        np.ndarray: A new contiguous array for RGBA input, otherwise ``img``
        itself (no copy).
    """
    if img.ndim == 3 and img.shape[2] == 4:
        result = img.copy()
        alpha = result[:, :, 3]
        result[:, :, 3] = 255 - alpha // 2
        # `result` is already a fresh copy; the original's extra
        # `result.copy()` inside ascontiguousarray was redundant.
        return np.ascontiguousarray(result)
    return img


def pixel_perfect_resolution(
        image: np.ndarray,
        target_H: int,
        target_W: int,
        resize_mode: ResizeMode,
) -> int:
    """Estimate the preprocessor resolution matching the target size.

    Per-axis scale factors are computed against the target dimensions.
    OUTER_FIT uses the smaller factor, so the whole image fits inside the
    target (possibly leaving empty space); every other mode uses the larger
    factor, so the target is fully covered (possibly cropping). The chosen
    factor is applied to the image's shorter side and rounded.

    Args:
        image: Source image as an (H, W, C) array.
        target_H: Target height in pixels.
        target_W: Target width in pixels.
        resize_mode: How the image will be fitted to the target.

    Returns:
        The estimated resolution after resizing, as an int.
    """
    raw_H, raw_W, _ = image.shape

    k0 = float(target_H) / float(raw_H)
    k1 = float(target_W) / float(raw_W)

    # OUTER_FIT shrinks to fit inside the target; all other modes grow to
    # cover it. Either way the factor scales the shorter source side.
    pick = min if resize_mode == ResizeMode.OUTER_FIT else max
    estimation = pick(k0, k1) * float(min(raw_H, raw_W))

    logger.debug(f"Pixel Perfect Computation:")
    logger.debug(f"resize_mode = {resize_mode}")
    logger.debug(f"raw_H = {raw_H}")
    logger.debug(f"raw_W = {raw_W}")
    logger.debug(f"target_H = {target_H}")
    logger.debug(f"target_W = {target_W}")
    logger.debug(f"estimation = {estimation}")

    return int(np.round(estimation))


class GradioImageMaskPair(TypedDict):
    """Represents the dict object from Gradio's image component if `tool="sketch"`
    is specified.
    {
        "image": np.ndarray,
        "mask": np.ndarray,
    }
    """
    # Pixel data of the image itself.
    image: np.ndarray
    # Mask drawn over the image (same spatial layout as `image` per the
    # Gradio sketch-tool contract — TODO confirm shape against Gradio docs).
    mask: np.ndarray


@dataclass
class ControlNetUnit:
    """Represents an entire ControlNet processing unit.

    To add a new field to this class
    ## If the new field can be specified on UI, you need to
    - Add a new field of the same name in constructor of `ControlNetUiGroup`
    - Initialize the new `ControlNetUiGroup` field in `ControlNetUiGroup.render`
      as a Gradio `IOComponent`.
    - Add the new `ControlNetUiGroup` field to `unit_args` in
      `ControlNetUiGroup.render`. The order of parameters matters.

    ## If the new field needs to appear in infotext, you need to
    - Add a new item in `ControlNetUnit.infotext_fields`.
    API-only fields cannot appear in infotext.
    """
    # Following fields should only be used in the UI.
    # ====== Start of UI only fields ======
    # Specifies the input mode for the unit, defaulting to a simple mode.
    input_mode: InputMode = InputMode.SIMPLE
    # Determines whether to use the preview image as input; defaults to False.
    use_preview_as_input: bool = False
    # Directory path for batch processing of images.
    batch_image_dir: str = ''
    # Directory path for batch processing of masks.
    batch_mask_dir: str = ''
    # Optional list of gallery images for batch input; defaults to None.
    batch_input_gallery: Optional[List[str]] = None
    # Optional list of gallery masks for batch processing; defaults to None.
    batch_mask_gallery: Optional[List[str]] = None
    # Optional list of gallery images for multi-inputs; defaults to None.
    multi_inputs_gallery: Optional[List[str]] = None
    # Holds the preview image as a NumPy array; defaults to None.
    generated_image: Optional[np.ndarray] = None
    # ====== End of UI only fields ======

    # Following fields are used in both the API and the UI.
    # Holds the mask image; defaults to None.
    mask_image: Optional[GradioImageMaskPair] = None
    # Specifies how this unit should be applied in each pass of high-resolution fix.
    # Ignored if high-resolution fix is not enabled.
    hr_option: HiResFixOption = HiResFixOption.BOTH
    # Indicates whether the unit is enabled; defaults to True.
    enabled: bool = True
    # Name of the module being used; defaults to "None".
    module: str = "None"
    # Name of the model being used; defaults to "None".
    model: str = "None"
    # Weight of the unit in the overall processing; defaults to 1.0.
    weight: float = 1.0
    # Optional image for input; defaults to None.
    image: Optional[GradioImageMaskPair] = None
    # Specifies the mode of image resizing; defaults to inner fit.
    resize_mode: ResizeMode = ResizeMode.INNER_FIT
    # Resolution for processing by the unit; defaults to -1 (unspecified).
    processor_res: int = -1
    # Threshold A for processing; defaults to -1 (unspecified).
    threshold_a: float = -1
    # Threshold B for processing; defaults to -1 (unspecified).
    threshold_b: float = -1
    # Start value for guidance; defaults to 0.0.
    guidance_start: float = 0.0
    # End value for guidance; defaults to 1.0.
    guidance_end: float = 1.0
    # Enables pixel-perfect processing; defaults to False.
    pixel_perfect: bool = False
    # Control mode for the unit; defaults to balanced.
    control_mode: ControlMode = ControlMode.BALANCED
    # Weight for each layer of ControlNet params.
    # For ControlNet:
    # - SD1.5: 13 weights (4 encoder block * 3 + 1 middle block)
    # - SDXL: 10 weights (3 encoder block * 3 + 1 middle block)
    # For T2IAdapter
    # - SD1.5: 5 weights (4 encoder block + 1 middle block)
    # - SDXL: 4 weights (3 encoder block + 1 middle block)
    # For IPAdapter
    # - SD15: 16 (6 input blocks + 9 output blocks + 1 middle block)
    # - SDXL: 11 weights (4 input blocks + 6 output blocks + 1 middle block)
    # Note1: Setting advanced weighting will disable `soft_injection`, i.e.
    # It is recommended to set ControlMode = BALANCED when using `advanced_weighting`.
    # Note2: The field `weight` is still used in some places, e.g. reference_only,
    # even advanced_weighting is set.
    advanced_weighting: Optional[List[float]] = None

    # Following fields should only be used in the API.
    # ====== Start of API only fields ======
    # Whether to save the detected map for this unit; defaults to True.
    save_detected_map: bool = True
    # ====== End of API only fields ======

    @staticmethod
    def infotext_fields():
        """Fields that should be included in infotext.
        You should define a Gradio element with exact same name in ControlNetUiGroup
        as well, so that infotext can wire the value to correct field when pasting
        infotext.
        """
        return (
            "module",
            "model",
            "weight",
            "resize_mode",
            "processor_res",
            "threshold_a",
            "threshold_b",
            "guidance_start",
            "guidance_end",
            "pixel_perfect",
            "control_mode",
            "hr_option",
        )

    @staticmethod
    def from_dict(d: Dict) -> "ControlNetUnit":
        """Create ControlNetUnit from dict. This is primarily used to convert
        API json dict to ControlNetUnit.

        Unknown keys are silently dropped; base64-encoded image strings are
        decoded to numpy arrays and string/int enum values are normalized.
        """
        # Filter on actual dataclass fields. The previous
        # `k in vars(ControlNetUnit)` check also matched methods and other
        # class attributes, so a payload key such as "from_dict" would be
        # forwarded as a bogus constructor kwarg and raise TypeError.
        field_names = ControlNetUnit.__dataclass_fields__
        unit = ControlNetUnit(
            **{k: v for k, v in d.items() if k in field_names}
        )
        if isinstance(unit.image, str):
            img = np.array(api.decode_base64_to_image(unit.image)).astype('uint8')
            # API sends a bare image; wrap it with an empty mask to match the
            # GradioImageMaskPair shape used by the UI.
            unit.image = {
                "image": img,
                "mask": np.zeros_like(img),
            }
        if isinstance(unit.mask_image, str):
            mask = np.array(api.decode_base64_to_image(unit.mask_image)).astype('uint8')
            if unit.image is not None:
                # Attach mask on image if ControlNet has input image.
                assert isinstance(unit.image, dict)
                unit.image["mask"] = mask
                unit.mask_image = None
            else:
                # Otherwise, wire to standalone mask.
                # This happens in img2img when using A1111 img2img input.
                unit.mask_image = {
                    "image": mask,
                    "mask": np.zeros_like(mask),
                }
        # Convert strings to enums.
        unit.input_mode = InputMode(unit.input_mode)
        unit.hr_option = HiResFixOption.from_value(unit.hr_option)
        unit.resize_mode = resize_mode_from_value(unit.resize_mode)
        unit.control_mode = control_mode_from_value(unit.control_mode)
        return unit


# Backward Compatible
# Historical alias kept so external code that still imports `UiControlNetUnit`
# continues to work; both names refer to the same class.
UiControlNetUnit = ControlNetUnit


def to_base64_nparray(encoding: str):
    """
    Convert a base64 image into the image type the extension uses
    (a uint8 numpy array).
    """
    decoded = api.decode_base64_to_image(encoding)
    return np.array(decoded).astype('uint8')


def get_max_models_num():
    """
    Fetch the maximum number of allowed ControlNet models.
    """
    # Read the configured unit count, defaulting to 3 when unset.
    return shared.opts.data.get("control_net_unit_count", 3)

def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNetUnit:
    """
    Convert different types to processing unit.
    Backward Compatible
    """
    # Already a unit: pass through unchanged.
    if not isinstance(unit, dict):
        return unit
    # Dict payloads (e.g. from the API) are parsed field-by-field.
    return ControlNetUnit.from_dict(unit)