|
from dataclasses import dataclass |
|
from enum import Enum |
|
from typing import List, Optional, Union, Dict, TypedDict, Any |
|
import numpy as np |
|
from modules import shared |
|
from lib_controlnet.logging import logger |
|
from lib_controlnet.enums import InputMode, HiResFixOption |
|
from modules.api import api |
|
|
|
from lib_controlnet.enums import ( |
|
InputMode, |
|
HiResFixOption, |
|
) |
|
|
|
|
|
def get_api_version() -> int:
    """Return the integer version number of this external API."""
    api_version = 2
    return api_version
|
|
|
|
|
class ControlMode(Enum):
    """Control mode options (the improved guess mode).

    Each member's value is its human-readable label.
    """

    # NOTE: keep the declaration order stable; `control_mode_from_value`
    # resolves integer inputs by position in this enum.
    BALANCED = "Balanced"
    PROMPT = "My prompt is more important"
    CONTROL = "ControlNet is more important"
|
|
|
|
|
class BatchOption(Enum):
    """Options describing how ControlNet units are applied to a batch.

    Each member's value is its human-readable label.
    """

    DEFAULT = "All ControlNet units for all images in a batch"
    SEPARATE = "Each ControlNet unit for each image in a batch"
|
|
|
|
|
class ResizeMode(Enum):
    """Resize modes for ControlNet input images.

    Each member's value is its human-readable label.
    """

    RESIZE = "Just Resize"
    INNER_FIT = "Crop and Resize"
    OUTER_FIT = "Resize and Fill"

    def int_value(self):
        """Return the integer code of this mode.

        RESIZE maps to 0, INNER_FIT to 1 and OUTER_FIT to 2.
        """
        codes = {
            ResizeMode.RESIZE: 0,
            ResizeMode.INNER_FIT: 1,
            ResizeMode.OUTER_FIT: 2,
        }
        code = codes.get(self)
        # Every member is listed above, so the lookup cannot miss.
        assert code is not None, "NOTREACHED"
        return code
|
|
|
|
|
# Maps alternative resize-mode labels to the label used as the matching
# `ResizeMode` value. Consulted by `resize_mode_from_value` for string input.
resize_mode_aliases = dict((
    ('Inner Fit (Scale to Fit)', 'Crop and Resize'),
    ('Outer Fit (Shrink to Fit)', 'Resize and Fill'),
    ('Scale to Fit (Inner Fit)', 'Crop and Resize'),
    ('Envelope (Outer Fit)', 'Resize and Fill'),
))
|
|
|
|
|
def resize_mode_from_value(value: Union[str, int, ResizeMode]) -> ResizeMode:
    """Convert a string, integer, or `ResizeMode` into a `ResizeMode`.

    Accepted inputs:
      * a `ResizeMode` member, returned unchanged;
      * a string of the form ``"ResizeMode.<MEMBER_NAME>"``;
      * a string equal to a member value, or to a key of
        `resize_mode_aliases` (translated to the member value first);
      * a non-negative integer used as a position in declaration order.
        The value 3 maps to `ResizeMode.RESIZE`; any other out-of-range
        position prints a notice and falls back to `ResizeMode.RESIZE`.

    Args:
        value: The raw resize mode.

    Returns:
        The corresponding `ResizeMode` member.

    Raises:
        ValueError: If a string does not match any member value/alias, if a
            ``"ResizeMode.<name>"`` string has extra dots or names an
            attribute that is not a member, or if an integer is negative.
        AttributeError: If ``"ResizeMode.<name>"`` names no attribute at all.
        TypeError: If `value` is not a str, int, or `ResizeMode`.
    """
    if isinstance(value, ResizeMode):
        return value

    if isinstance(value, str):
        if value.startswith("ResizeMode."):
            _, field = value.split(".")
            member = getattr(ResizeMode, field)
            # getattr also resolves methods and dunder attributes (for example
            # "ResizeMode.int_value"); only real members are valid results.
            if not isinstance(member, ResizeMode):
                raise ValueError(f"{value!r} does not name a ResizeMode member")
            return member
        return ResizeMode(resize_mode_aliases.get(value, value))

    if isinstance(value, int):
        # Explicit check instead of `assert`: assertions are removed under -O,
        # which would let negative values index from the end of the member list.
        if value < 0:
            raise ValueError(f"ResizeMode int value must be non-negative, not {value}")
        if value == 3:
            return ResizeMode.RESIZE
        members = list(ResizeMode)
        if value < len(members):
            return members[value]
        print(f'Unrecognized ResizeMode int value {value}. Fall back to RESIZE.')
        return ResizeMode.RESIZE

    raise TypeError(f"ResizeMode value must be str, int, or ResizeMode, not {type(value)}")
|
|
|
|
|
def control_mode_from_value(value: Union[str, int, ControlMode]) -> ControlMode:
    """Convert a string, integer, or `ControlMode` into a `ControlMode`.

    Accepted inputs:
      * a `ControlMode` member, returned unchanged;
      * a string equal to a member value;
      * an integer used as a position in declaration order.

    Unrecognized strings and out-of-range integers (including negative ones)
    print a notice and fall back to `ControlMode.BALANCED`.

    Args:
        value: The raw control mode.

    Returns:
        The corresponding `ControlMode` member.

    Raises:
        TypeError: If `value` is not a str, int, or `ControlMode`.
    """
    if isinstance(value, ControlMode):
        return value

    if isinstance(value, str):
        try:
            return ControlMode(value)
        except ValueError:
            print(f'Unrecognized ControlMode string value "{value}". Fall back to BALANCED.')
            return ControlMode.BALANCED

    if isinstance(value, int):
        modes = list(ControlMode)
        # Bounds are checked on both sides: plain list indexing would let a
        # negative value wrap around (e.g. -1 selecting the last member)
        # instead of being reported as unrecognized.
        if 0 <= value < len(modes):
            return modes[value]
        print(f'Unrecognized ControlMode int value {value}. Fall back to BALANCED.')
        return ControlMode.BALANCED

    raise TypeError(f"ControlMode value must be str, int, or ControlMode, not {type(value)}")
|
|
|
|
|
def visualize_inpaint_mask(img):
    """Return a copy of a 4-channel image with its last channel remapped.

    When `img` has three dimensions and exactly four channels, the channel at
    index 3 is replaced by ``255 - channel // 2`` on a copy, and a contiguous
    array is returned. Any other input is returned unchanged (same object).
    """
    is_four_channel = img.ndim == 3 and img.shape[2] == 4
    if not is_four_channel:
        return img

    preview = img.copy()
    preview[:, :, 3] = 255 - preview[:, :, 3] // 2
    return np.ascontiguousarray(preview.copy())
|
|
|
|
|
def pixel_perfect_resolution(
    image: np.ndarray,
    target_H: int,
    target_W: int,
    resize_mode: ResizeMode,
) -> int:
    """Estimate the resolution for resizing `image` toward a target size.

    Two scale factors are computed: target height over image height, and
    target width over image width. One of them is chosen according to
    `resize_mode` and multiplied by the shorter side of the image:

    * `ResizeMode.OUTER_FIT` uses the smaller factor, so the whole image fits
      within the target dimensions, potentially leaving some empty space.
    * Every other mode uses the larger factor, so the target dimensions are
      fully filled, potentially cropping the image.

    The inputs and the estimate are emitted through `logger.debug`.

    Args:
        image (np.ndarray): A 3D array laid out as [height, width, channels].
        target_H (int): The target height for the image.
        target_W (int): The target width for the image.
        resize_mode (ResizeMode): The mode for resizing.

    Returns:
        int: The estimated resolution, rounded to the nearest integer.
    """
    src_h, src_w, _channels = image.shape

    scale_h = float(target_H) / float(src_h)
    scale_w = float(target_W) / float(src_w)

    # OUTER_FIT keeps the whole image visible (smaller factor); the other
    # modes fill the target completely (larger factor).
    choose_scale = min if resize_mode == ResizeMode.OUTER_FIT else max
    estimation = choose_scale(scale_h, scale_w) * float(min(src_h, src_w))

    logger.debug(f"Pixel Perfect Computation:")
    logger.debug(f"resize_mode = {resize_mode}")
    logger.debug(f"raw_H = {src_h}")
    logger.debug(f"raw_W = {src_w}")
    logger.debug(f"target_H = {target_H}")
    logger.debug(f"target_W = {target_W}")
    logger.debug(f"estimation = {estimation}")

    return int(np.round(estimation))
|
|
|
|
|
class GradioImageMaskPair(TypedDict):
    """Dict produced by Gradio's image component when `tool="sketch"` is set.

    Layout:
        {
            "image": np.ndarray,
            "mask": np.ndarray,
        }
    """

    image: np.ndarray
    mask: np.ndarray
|
|
|
|
|
@dataclass
class ControlNetUnit:
    """Represents an entire ControlNet processing unit.

    To add a new field to this class
    ## If the new field can be specified on UI, you need to
    - Add a new field of the same name in constructor of `ControlNetUiGroup`
    - Initialize the new `ControlNetUiGroup` field in `ControlNetUiGroup.render`
      as a Gradio `IOComponent`.
    - Add the new `ControlNetUiGroup` field to `unit_args` in
      `ControlNetUiGroup.render`. The order of parameters matters.

    ## If the new field needs to appear in infotext, you need to
    - Add a new item in `ControlNetUnit.infotext_fields`.
    API-only fields cannot appear in infotext.
    """

    # NOTE: the declaration order below defines the positional order of the
    # generated constructor; do not reorder existing fields.
    input_mode: InputMode = InputMode.SIMPLE
    use_preview_as_input: bool = False
    batch_image_dir: str = ''
    batch_mask_dir: str = ''
    batch_input_gallery: Optional[List[str]] = None
    batch_mask_gallery: Optional[List[str]] = None
    multi_inputs_gallery: Optional[List[str]] = None
    generated_image: Optional[np.ndarray] = None

    # `from_dict` also accepts a base64 string here and decodes it.
    mask_image: Optional[GradioImageMaskPair] = None

    hr_option: HiResFixOption = HiResFixOption.BOTH
    enabled: bool = True
    module: str = "None"
    model: str = "None"
    weight: float = 1.0

    # `from_dict` also accepts a base64 string here and decodes it.
    image: Optional[GradioImageMaskPair] = None

    resize_mode: ResizeMode = ResizeMode.INNER_FIT
    processor_res: int = -1
    threshold_a: float = -1
    threshold_b: float = -1
    guidance_start: float = 0.0
    guidance_end: float = 1.0
    pixel_perfect: bool = False
    control_mode: ControlMode = ControlMode.BALANCED

    advanced_weighting: Optional[List[float]] = None

    save_detected_map: bool = True

    @staticmethod
    def infotext_fields():
        """Fields that should be included in infotext.
        You should define a Gradio element with exact same name in ControlNetUiGroup
        as well, so that infotext can wire the value to correct field when pasting
        infotext.
        """
        return (
            "module",
            "model",
            "weight",
            "resize_mode",
            "processor_res",
            "threshold_a",
            "threshold_b",
            "guidance_start",
            "guidance_end",
            "pixel_perfect",
            "control_mode",
            "hr_option",
        )

    @staticmethod
    def from_dict(d: Dict) -> "ControlNetUnit":
        """Create ControlNetUnit from dict. This is primarily used to convert
        API json dict to ControlNetUnit.

        Only keys that name a dataclass field are used; every other key is
        ignored. After construction:

        - A string `image` is decoded from base64 into a
          ``{"image": ..., "mask": ...}`` dict with an all-zero mask.
        - A string `mask_image` is decoded from base64. If an `image` is
          present, the decoded mask replaces ``image["mask"]`` and
          `mask_image` is cleared; otherwise `mask_image` becomes a
          ``{"image": mask, "mask": zeros}`` dict.
        - `input_mode`, `hr_option`, `resize_mode` and `control_mode` are
          converted from their raw values into enum members.

        Args:
            d: Mapping of field names to raw values.

        Returns:
            The populated `ControlNetUnit`.
        """
        from dataclasses import fields

        # Filter on the declared dataclass fields. `vars(ControlNetUnit)` also
        # contains methods and dunder attributes (e.g. "from_dict", "__doc__"),
        # so a dict carrying such a key would be forwarded to the constructor
        # and fail with an unexpected-keyword TypeError.
        field_names = {f.name for f in fields(ControlNetUnit)}
        unit = ControlNetUnit(
            **{k: v for k, v in d.items() if k in field_names}
        )

        if isinstance(unit.image, str):
            img = np.array(api.decode_base64_to_image(unit.image)).astype('uint8')
            unit.image = {
                "image": img,
                "mask": np.zeros_like(img),
            }

        if isinstance(unit.mask_image, str):
            mask = np.array(api.decode_base64_to_image(unit.mask_image)).astype('uint8')
            if unit.image is not None:
                # Attach the decoded mask to the existing image pair.
                assert isinstance(unit.image, dict)
                unit.image["mask"] = mask
                unit.mask_image = None
            else:
                # No image supplied: keep the mask as its own image/mask pair.
                unit.mask_image = {
                    "image": mask,
                    "mask": np.zeros_like(mask),
                }

        # Normalize raw (string/int) values into their enum types.
        unit.input_mode = InputMode(unit.input_mode)
        unit.hr_option = HiResFixOption.from_value(unit.hr_option)
        unit.resize_mode = resize_mode_from_value(unit.resize_mode)
        unit.control_mode = control_mode_from_value(unit.control_mode)
        return unit
|
|
|
|
|
|
|
# Alias binding: `UiControlNetUnit` refers to the very same class object as
# `ControlNetUnit`.
UiControlNetUnit = ControlNetUnit
|
|
|
|
|
def to_base64_nparray(encoding: str):
    """Decode a base64 image string into the array type the extension uses.

    The string is decoded with `api.decode_base64_to_image`, converted to a
    NumPy array, and cast to uint8.
    """
    decoded_image = api.decode_base64_to_image(encoding)
    pixels = np.array(decoded_image)
    return pixels.astype('uint8')
|
|
|
|
|
def get_max_models_num():
    """Fetch the maximum number of allowed ControlNet models.

    Reads the "control_net_unit_count" entry from `shared.opts.data`,
    defaulting to 3 when the entry is absent.
    """
    return shared.opts.data.get("control_net_unit_count", 3)
|
|
|
def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNetUnit:
    """Convert different types to processing unit (backward compatible).

    A dict is converted with `ControlNetUnit.from_dict`; any other value is
    returned unchanged.
    """
    return ControlNetUnit.from_dict(unit) if isinstance(unit, dict) else unit
|
|