bilegentile's picture
Upload folder using huggingface_hub
c19ca42 verified
from abc import ABC, abstractmethod
from typing import List, ClassVar, Dict, Optional, Set
from dataclasses import dataclass, field
from modules import shared
from scripts.logging import logger
from scripts.utils import ndarray_lru_cache
# Value passed as max_size to ndarray_lru_cache for Preprocessor.cached_call.
# Read from shared.cmd_opts.controlnet_preprocessor_cache_size; falls back to 0
# when that attribute is not present on cmd_opts.
# NOTE(review): presumably 0 disables caching - confirm against ndarray_lru_cache.
CACHE_SIZE = getattr(shared.cmd_opts, "controlnet_preprocessor_cache_size", 0)
@dataclass
class PreprocessorParameter:
    """
    Numeric parameter (slider description) attached to a preprocessor.

    Attributes:
        label (str): Display label of the parameter. Default is "EMPTY_LABEL".
        minimum (float): Smallest allowed value. Default is 0.0.
        maximum (float): Largest allowed value. Default is 1.0.
        step (float): Increment between allowed values. Default is 0.01.
        value (float): Initial value. Default is 0.5.
        visible (bool): Whether the parameter is shown. Default is True.
    """

    label: str = "EMPTY_LABEL"
    minimum: float = 0.0
    maximum: float = 1.0
    step: float = 0.01
    value: float = 0.5
    visible: bool = True

    @property
    def gradio_update_kwargs(self) -> dict:
        """Return every field as a keyword-argument dict.

        NOTE(review): going by the property name, the dict is presumably
        splatted into a gradio component update - confirm at the call site.
        """
        return {
            "minimum": self.minimum,
            "maximum": self.maximum,
            "step": self.step,
            "label": self.label,
            "value": self.value,
            "visible": self.visible,
        }

    @property
    def api_json(self) -> dict:
        """Return the API view of the parameter.

        The label is exposed under the key ``name`` and the bounds under
        ``min`` / ``max``; ``visible`` is not included.
        """
        return {
            "name": self.label,
            "value": self.value,
            "min": self.minimum,
            "max": self.maximum,
            "step": self.step,
        }
@dataclass
class Preprocessor(ABC):
    """
    Abstract base class for preprocessors, doubling as their registry.

    Concrete preprocessors implement ``__call__`` and are registered through
    ``add_supported_preprocessor``; the class-level dictionaries then serve
    lookups by label or by name.

    Attributes:
        name (str): The name of the preprocessor. Also used for hashing and
            equality.
        _label (Optional[str]): Optional display label; ``label`` falls back
            to ``name`` when this is None.
        tags (List[str]): The tags associated with the preprocessor.
        slider_resolution (PreprocessorParameter): The parameter representing the resolution of the slider.
        slider_1 (PreprocessorParameter): The first parameter of the slider.
        slider_2 (PreprocessorParameter): The second parameter of the slider.
        slider_3 (PreprocessorParameter): The third parameter of the slider.
        returns_image (bool): Whether the preprocessor returns an image
            (per the attribute name). Default is True.
        show_control_mode (bool): Whether to show the control mode or not.
        do_not_need_model (bool): True when the preprocessor does NOT need a model.
        sorting_priority (int): The sorting priority of the preprocessor; higher goes to the top of the list.
        corp_image_with_a1111_mask_when_in_img2img_inpaint_tab (bool): Whether to crop the image with a1111 mask when in img2img inpaint tab or not.
        fill_mask_with_one_when_resize_and_fill (bool): Whether to fill the mask with one when resizing and filling or not.
        use_soft_projection_in_hr_fix (bool): Whether to use soft projection in hr fix or not.
        expand_mask_when_resize_and_fill (bool): Whether to expand the mask when resizing and filling or not.
        all_processors (ClassVar[Dict[str, Preprocessor]]): Registry keyed by label.
        all_processors_by_name (ClassVar[Dict[str, Preprocessor]]): Registry keyed by name.
    """

    name: str
    _label: Optional[str] = None
    tags: List[str] = field(default_factory=list)
    # The un-annotated attributes below are NOT dataclass fields: they are
    # plain class attributes, shared by all instances until an instance or
    # subclass assigns its own value.
    slider_resolution = PreprocessorParameter(
        label="Resolution",
        minimum=64,
        maximum=2048,
        value=512,
        step=8,
        visible=True,
    )
    slider_1 = PreprocessorParameter(visible=False)
    slider_2 = PreprocessorParameter(visible=False)
    slider_3 = PreprocessorParameter(visible=False)
    returns_image: bool = True
    show_control_mode = True
    do_not_need_model = False
    sorting_priority = 0  # higher goes to top in the list
    corp_image_with_a1111_mask_when_in_img2img_inpaint_tab = True
    fill_mask_with_one_when_resize_and_fill = False
    use_soft_projection_in_hr_fix = False
    expand_mask_when_resize_and_fill = False

    all_processors: ClassVar[Dict[str, "Preprocessor"]] = {}
    all_processors_by_name: ClassVar[Dict[str, "Preprocessor"]] = {}

    @property
    def label(self) -> str:
        """Display name on UI."""
        return self._label if self._label is not None else self.name

    @classmethod
    def add_supported_preprocessor(cls, p: "Preprocessor"):
        """Register ``p`` under both its label and its name.

        Raises:
            AssertionError: If the label or the name is already registered.
        """
        # Check both keys before touching either registry, so that a
        # collision on the name cannot leave the label registry updated alone.
        assert p.label not in cls.all_processors, f"{p.label} already registered!"
        assert p.name not in cls.all_processors_by_name, f"{p.name} already registered!"
        cls.all_processors[p.label] = p
        cls.all_processors_by_name[p.name] = p
        logger.debug(f"{p.name} registered. Total preprocessors ({len(cls.all_processors)}).")

    @classmethod
    def get_preprocessor(cls, name: str) -> Optional["Preprocessor"]:
        """Look up a preprocessor by label first, then by name.

        Returns:
            The registered preprocessor, or None when neither registry knows
            ``name``.
        """
        return cls.all_processors.get(name, cls.all_processors_by_name.get(name, None))

    @classmethod
    def get_sorted_preprocessors(cls) -> List["Preprocessor"]:
        """Return all registered preprocessors with "none" pinned first.

        The remaining entries are sorted in descending order of the string
        ``zero-padded sorting_priority + label``, so higher priorities come
        first and, within one priority, labels are in reverse order.

        Raises:
            KeyError: If no preprocessor labelled "none" has been registered.
        """
        preprocessors = [p for k, p in cls.all_processors.items() if k != "none"]
        return [cls.all_processors["none"]] + sorted(
            preprocessors,
            key=lambda x: str(x.sorting_priority).zfill(8) + x.label,
            reverse=True,
        )

    @classmethod
    def get_all_preprocessor_tags(cls) -> List[str]:
        """Return "All" followed by every distinct tag in sorted order."""
        tags = set()
        for p in cls.all_processors.values():
            tags.update(p.tags)
        return ["All"] + sorted(tags)

    @classmethod
    def get_filtered_preprocessors(cls, tag: str) -> List["Preprocessor"]:
        """Return the sorted preprocessors that carry ``tag``.

        The "none" preprocessor is always included. The special tag "All"
        selects every registered preprocessor.
        """
        if tag == "All":
            # Return a list like every other branch; handing back the registry
            # dict here would break index-based callers such as
            # get_default_preprocessor.
            return cls.get_sorted_preprocessors()
        return [
            p
            for p in cls.get_sorted_preprocessors()
            if tag in p.tags or p.label == "none"
        ]

    @classmethod
    def get_default_preprocessor(cls, tag: str) -> "Preprocessor":
        """Return the default preprocessor for ``tag``.

        The filtered list always starts with "none", so the second entry is
        the first preprocessor matching the tag; when nothing else matches,
        "none" itself is returned.
        """
        ps = cls.get_filtered_preprocessors(tag)
        assert len(ps) > 0
        return ps[0] if len(ps) == 1 else ps[1]

    @classmethod
    def tag_to_filters(cls, tag: str) -> Set[str]:
        """Return the lower-cased tag together with its known aliases."""
        filters_aliases = {
            "instructp2p": ["ip2p"],
            "segmentation": ["seg"],
            "normalmap": ["normal"],
            "t2i-adapter": ["t2i_adapter", "t2iadapter", "t2ia"],
            "ip-adapter": ["ip_adapter", "ipadapter"],
            "openpose": ["openpose", "densepose"],
            "instant-id": ["instant_id", "instantid"],
            "scribble": ["sketch"],
            "tile": ["blur"],
        }
        tag = tag.lower()
        return {tag, *filters_aliases.get(tag, [])}

    @ndarray_lru_cache(max_size=CACHE_SIZE)
    def cached_call(self, *args, **kwargs):
        """Call the preprocessor through the ndarray_lru_cache wrapper.

        The body (and its debug log line) only runs when the wrapper actually
        invokes the preprocessor.
        """
        logger.debug(f"Calling preprocessor {self.name} outside of cache.")
        return self(*args, **kwargs)

    def __hash__(self):
        """Hash on the preprocessor name."""
        return hash(self.name)

    def __eq__(self, other):
        """Two preprocessors are equal when their names are equal.

        Non-Preprocessor operands yield NotImplemented, so an instance never
        compares equal to an unrelated object that merely shares its hash
        (for example its own name string), and unhashable operands no longer
        raise.
        """
        if not isinstance(other, Preprocessor):
            return NotImplemented
        return self.name == other.name

    @abstractmethod
    def __call__(
        self,
        input_image,
        resolution,
        slider_1=None,
        slider_2=None,
        slider_3=None,
        input_mask=None,
        **kwargs,
    ):
        """Run the preprocessor; must be implemented by subclasses.

        NOTE(review): ``resolution`` and ``slider_1``..``slider_3`` presumably
        carry the values of the matching slider parameters declared on the
        class - confirm against the callers.
        """
        pass