Spaces:
Runtime error
Runtime error
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, ClassVar, Dict, Optional, Set

from modules import shared
from scripts.logging import logger
from scripts.utils import ndarray_lru_cache
# Size setting for the preprocessor cache, read from the
# "controlnet_preprocessor_cache_size" attribute of shared.cmd_opts
# (presumably a command-line option); falls back to 0 when that attribute
# is not defined.
CACHE_SIZE = getattr(shared.cmd_opts, "controlnet_preprocessor_cache_size", 0)
@dataclass
class PreprocessorParameter:
    """
    Class representing a parameter for a preprocessor (the ``slider_*``
    settings of ``Preprocessor``).

    Attributes:
        label (str): The label for the parameter. Default is "EMPTY_LABEL".
        minimum (float): The minimum value of the parameter. Default is 0.0.
        maximum (float): The maximum value of the parameter. Default is 1.0.
        step (float): The step size for the parameter. Default is 0.01.
        value (float): The initial value of the parameter. Default is 0.5.
        visible (bool): Whether the parameter is visible or not. Default is True.
    """

    label: str = "EMPTY_LABEL"
    minimum: float = 0.0
    maximum: float = 1.0
    step: float = 0.01
    value: float = 0.5
    visible: bool = True

    # NOTE(review): the two accessors below are plain methods in this file;
    # confirm at the call sites whether they are expected to be @property.
    def gradio_update_kwargs(self) -> dict:
        """Return this parameter's settings as keyword arguments for a Gradio update.

        Returns:
            dict with keys ``minimum``, ``maximum``, ``step``, ``label``,
            ``value`` and ``visible``.
        """
        return dict(
            minimum=self.minimum,
            maximum=self.maximum,
            step=self.step,
            label=self.label,
            value=self.value,
            visible=self.visible,
        )

    def api_json(self) -> dict:
        """Return a plain-dict description of this parameter (presumably for API responses).

        Returns:
            dict with keys ``name`` (the label), ``value``, ``min``, ``max``
            and ``step``.
        """
        return dict(
            name=self.label,
            value=self.value,
            min=self.minimum,
            max=self.maximum,
            step=self.step,
        )
@dataclass
class Preprocessor(ABC):
    """
    Abstract base class for a preprocessor, plus a class-level registry of
    every registered preprocessor.

    Attributes:
        name (str): The unique name of the preprocessor (enforced on registration).
        _label (Optional[str]): Display name for the UI; ``label`` falls back to ``name`` when None.
        tags (List[str]): The tags associated with the preprocessor, used for filtering.
        slider_resolution (PreprocessorParameter): The parameter controlling the resolution.
        slider_1 (PreprocessorParameter): The first optional slider parameter (hidden by default).
        slider_2 (PreprocessorParameter): The second optional slider parameter (hidden by default).
        slider_3 (PreprocessorParameter): The third optional slider parameter (hidden by default).
        returns_image (bool): Whether the preprocessor returns an image.
        show_control_mode (bool): Whether to show the control mode or not.
        do_not_need_model (bool): Whether the preprocessor can be used without a model.
        sorting_priority (int): Sorting priority; higher values are listed first.
        corp_image_with_a1111_mask_when_in_img2img_inpaint_tab (bool): Whether to crop the
            image with the A1111 mask when in the img2img inpaint tab.
        fill_mask_with_one_when_resize_and_fill (bool): Whether to fill the mask with ones
            when resizing and filling.
        use_soft_projection_in_hr_fix (bool): Whether to use soft projection in hires fix.
        expand_mask_when_resize_and_fill (bool): Whether to expand the mask when resizing
            and filling.
        all_processors (ClassVar[Dict[str, Preprocessor]]): Registry keyed by ``label``.
        all_processors_by_name (ClassVar[Dict[str, Preprocessor]]): Registry keyed by ``name``.
    """

    name: str
    _label: Optional[str] = None
    tags: List[str] = field(default_factory=list)
    # Unannotated class attributes are not dataclass fields: they are shared
    # by all instances and meant to be overridden by subclasses.
    slider_resolution = PreprocessorParameter(
        label="Resolution",
        minimum=64,
        maximum=2048,
        value=512,
        step=8,
        visible=True,
    )
    slider_1 = PreprocessorParameter(visible=False)
    slider_2 = PreprocessorParameter(visible=False)
    slider_3 = PreprocessorParameter(visible=False)
    returns_image: bool = True
    show_control_mode = True
    do_not_need_model = False
    sorting_priority = 0  # higher goes to top in the list
    corp_image_with_a1111_mask_when_in_img2img_inpaint_tab = True
    fill_mask_with_one_when_resize_and_fill = False
    use_soft_projection_in_hr_fix = False
    expand_mask_when_resize_and_fill = False

    all_processors: ClassVar[Dict[str, "Preprocessor"]] = {}
    all_processors_by_name: ClassVar[Dict[str, "Preprocessor"]] = {}

    @property
    def label(self) -> str:
        """Display name on UI; falls back to ``name`` when no explicit label is set."""
        return self._label if self._label is not None else self.name

    @classmethod
    def add_supported_preprocessor(cls, p: "Preprocessor"):
        """Register ``p`` under both its label and its name.

        Raises:
            AssertionError: if the label or the name is already registered.
        """
        # Validate both keys before touching either registry so a rejected
        # preprocessor never leaves a half-registered entry behind.
        assert p.label not in cls.all_processors, f"{p.label} already registered!"
        assert p.name not in cls.all_processors_by_name, f"{p.name} already registered!"
        cls.all_processors[p.label] = p
        cls.all_processors_by_name[p.name] = p
        logger.debug(f"{p.name} registered. Total preprocessors ({len(cls.all_processors)}).")

    @classmethod
    def get_preprocessor(cls, name: str) -> Optional["Preprocessor"]:
        """Look up a preprocessor by label first, then by name; None if neither matches."""
        return cls.all_processors.get(name, cls.all_processors_by_name.get(name, None))

    @classmethod
    def get_sorted_preprocessors(cls) -> List["Preprocessor"]:
        """Return every preprocessor: "none" first, then by descending priority.

        Ties on ``sorting_priority`` are ordered by label, reverse-lexicographically.

        Raises:
            KeyError: if no preprocessor labelled "none" is registered.
        """
        others = [p for label, p in cls.all_processors.items() if label != "none"]
        # A (priority, label) tuple compares priorities numerically. The former
        # zero-padded string key gave the same order for small non-negative
        # priorities but mis-ordered negative ones and ones wider than 8 digits.
        others.sort(key=lambda p: (p.sorting_priority, p.label), reverse=True)
        return [cls.all_processors["none"]] + others

    @classmethod
    def get_all_preprocessor_tags(cls):
        """Return "All" followed by the sorted set of tags used by registered preprocessors."""
        tags = {tag for p in cls.all_processors.values() for tag in p.tags}
        return ["All"] + sorted(tags)

    @classmethod
    def get_filtered_preprocessors(cls, tag: str) -> List["Preprocessor"]:
        """Return the sorted preprocessors carrying ``tag``, always including "none".

        ``tag == "All"`` returns every registered preprocessor in the same order.
        """
        sorted_ps = cls.get_sorted_preprocessors()
        if tag == "All":
            # Return a list (not the label-keyed registry dict) so callers such
            # as get_default_preprocessor can index the result positionally.
            return sorted_ps
        return [p for p in sorted_ps if tag in p.tags or p.label == "none"]

    @classmethod
    def get_default_preprocessor(cls, tag: str) -> "Preprocessor":
        """Return the default preprocessor for ``tag``.

        The filtered list starts with "none"; if anything else matches, the
        first such (highest-priority) entry is returned, otherwise "none".
        """
        ps = cls.get_filtered_preprocessors(tag)
        assert len(ps) > 0
        return ps[0] if len(ps) == 1 else ps[1]

    @classmethod
    def tag_to_filters(cls, tag: str) -> Set[str]:
        """Return the lower-cased ``tag`` together with its known aliases."""
        filters_aliases = {
            "instructp2p": ["ip2p"],
            "segmentation": ["seg"],
            "normalmap": ["normal"],
            "t2i-adapter": ["t2i_adapter", "t2iadapter", "t2ia"],
            "ip-adapter": ["ip_adapter", "ipadapter"],
            "openpose": ["openpose", "densepose"],
            "instant-id": ["instant_id", "instantid"],
            "scribble": ["sketch"],
            "tile": ["blur"],
        }
        tag = tag.lower()
        return {tag, *filters_aliases.get(tag, [])}

    # NOTE(review): the method name, the "outside of cache" log message and the
    # module-level CACHE_SIZE / ndarray_lru_cache names (unused in the visible
    # code) suggest this is meant to be wrapped by ndarray_lru_cache. No
    # decorator is applied here; confirm the helper's signature before adding one.
    def cached_call(self, *args, **kwargs):
        """Invoke the preprocessor with the given arguments (forwards to ``__call__``)."""
        logger.debug(f"Calling preprocessor {self.name} outside of cache.")
        return self(*args, **kwargs)

    def __hash__(self):
        # Hash by name, consistent with __eq__ below.
        return hash(self.name)

    def __eq__(self, other):
        # Compare names directly. Comparing hashes made a bare string equal to
        # self.name compare equal, could collide, and raised TypeError for
        # unhashable operands.
        if not isinstance(other, Preprocessor):
            return NotImplemented
        return self.name == other.name

    @abstractmethod
    def __call__(
        self,
        input_image,
        resolution,
        slider_1=None,
        slider_2=None,
        slider_3=None,
        input_mask=None,
        **kwargs,
    ):
        """Run the preprocessor; concrete subclasses must implement this.

        NOTE(review): presumably ``resolution`` and ``slider_1``..``slider_3``
        receive the values of the corresponding PreprocessorParameter settings
        — confirm against subclasses.
        """
        pass