File size: 6,708 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
from abc import ABC, abstractmethod
from typing import List, ClassVar, Dict, Optional, Set
from dataclasses import dataclass, field

from modules import shared
from scripts.logging import logger
from scripts.utils import ndarray_lru_cache


# Size limit passed as ``max_size`` to ``ndarray_lru_cache`` where that
# decorator is applied in this module. Read from the
# ``controlnet_preprocessor_cache_size`` attribute of ``shared.cmd_opts``;
# falls back to 0 when that attribute does not exist.
# NOTE(review): presumably 0 means "no caching" -- confirm against the
# implementation of ``ndarray_lru_cache``.
CACHE_SIZE = getattr(shared.cmd_opts, "controlnet_preprocessor_cache_size", 0)


@dataclass
class PreprocessorParameter:
    """
    Description of one numeric control exposed by a preprocessor.

    Attributes:
        label (str): Caption of the control. Default is "EMPTY_LABEL".
        minimum (float): Lower bound of the value. Default is 0.0.
        maximum (float): Upper bound of the value. Default is 1.0.
        step (float): Increment between values. Default is 0.01.
        value (float): Initial value. Default is 0.5.
        visible (bool): Whether the control is shown. Default is True.
    """

    label: str = "EMPTY_LABEL"
    minimum: float = 0.0
    maximum: float = 1.0
    step: float = 0.01
    value: float = 0.5
    visible: bool = True

    @property
    def gradio_update_kwargs(self) -> dict:
        """All six attributes as a keyword-argument dict, keyed by attribute name."""
        kwargs = {
            "minimum": self.minimum,
            "maximum": self.maximum,
            "step": self.step,
            "label": self.label,
            "value": self.value,
            "visible": self.visible,
        }
        return kwargs

    @property
    def api_json(self) -> dict:
        """Dict form for the API: ``label`` is exposed as ``name``; ``visible`` is omitted."""
        payload = {
            "name": self.label,
            "value": self.value,
            "min": self.minimum,
            "max": self.maximum,
            "step": self.step,
        }
        return payload


@dataclass
class Preprocessor(ABC):
    """
    Abstract base class for preprocessors, plus a class-level registry of them.

    Instances are registered with ``add_supported_preprocessor`` and can then
    be looked up by label or name, listed in sorted order, or filtered by tag.
    Subclasses must implement ``__call__``.

    Only ``name``, ``_label``, ``tags`` and ``returns_image`` are dataclass
    fields (constructor arguments). The remaining attributes below carry no
    annotation, so they are plain class attributes that every instance shares
    until an instance or subclass overrides them.

    Attributes:
        name (str): The name of the preprocessor. Also used for hashing and equality.
        _label (Optional[str]): Optional display name; ``label`` falls back to ``name`` when unset.
        tags (List[str]): The tags associated with the preprocessor.
        slider_resolution (PreprocessorParameter): The parameter representing the resolution of the slider.
        slider_1 (PreprocessorParameter): The first parameter of the slider.
        slider_2 (PreprocessorParameter): The second parameter of the slider.
        slider_3 (PreprocessorParameter): The third parameter of the slider.
        returns_image (bool): Presumably whether the preprocessor output is an image -- confirm with callers.
        show_control_mode (bool): Whether to show the control mode or not.
        do_not_need_model (bool): True when the preprocessor does NOT need a model.
        sorting_priority (int): The sorting priority of the preprocessor; higher goes to top in the list.
        corp_image_with_a1111_mask_when_in_img2img_inpaint_tab (bool): Whether to crop the image with a1111 mask when in img2img inpaint tab or not.
        fill_mask_with_one_when_resize_and_fill (bool): Whether to fill the mask with one when resizing and filling or not.
        use_soft_projection_in_hr_fix (bool): Whether to use soft projection in hr fix or not.
        expand_mask_when_resize_and_fill (bool): Whether to expand the mask when resizing and filling or not.
        all_processors (Dict[str, Preprocessor]): Registry keyed by ``label``.
        all_processors_by_name (Dict[str, Preprocessor]): Registry keyed by ``name``.
    """

    name: str
    _label: Optional[str] = None
    tags: List[str] = field(default_factory=list)
    slider_resolution = PreprocessorParameter(
        label="Resolution",
        minimum=64,
        maximum=2048,
        value=512,
        step=8,
        visible=True,
    )
    slider_1 = PreprocessorParameter(visible=False)
    slider_2 = PreprocessorParameter(visible=False)
    slider_3 = PreprocessorParameter(visible=False)
    returns_image: bool = True
    show_control_mode = True
    do_not_need_model = False
    sorting_priority = 0  # higher goes to top in the list
    corp_image_with_a1111_mask_when_in_img2img_inpaint_tab = True
    fill_mask_with_one_when_resize_and_fill = False
    use_soft_projection_in_hr_fix = False
    expand_mask_when_resize_and_fill = False

    all_processors: ClassVar[Dict[str, "Preprocessor"]] = {}
    all_processors_by_name: ClassVar[Dict[str, "Preprocessor"]] = {}

    @property
    def label(self) -> str:
        """Display name on UI."""
        return self._label if self._label is not None else self.name

    @classmethod
    def add_supported_preprocessor(cls, p: "Preprocessor"):
        """
        Register ``p`` under both its label and its name.

        Raises:
            AssertionError: If the label or the name is already registered.
        """
        # Check both keys before writing to either registry, so a rejected
        # registration cannot leave the two dicts out of sync.
        assert p.label not in cls.all_processors, f"{p.label} already registered!"
        assert p.name not in cls.all_processors_by_name, f"{p.name} already registered!"
        cls.all_processors[p.label] = p
        cls.all_processors_by_name[p.name] = p
        logger.debug(f"{p.name} registered. Total preprocessors ({len(cls.all_processors)}).")

    @classmethod
    def get_preprocessor(cls, name: str) -> Optional["Preprocessor"]:
        """Look up a preprocessor by label first, then by name; None if neither matches."""
        return cls.all_processors.get(name, cls.all_processors_by_name.get(name))

    @classmethod
    def get_sorted_preprocessors(cls) -> List["Preprocessor"]:
        """
        Return every registered preprocessor with the "none" entry first.

        The remaining entries are sorted in descending order of a string key
        built from the zero-padded ``sorting_priority`` followed by the label.

        Raises:
            KeyError: If no preprocessor is registered under the label "none".
        """
        preprocessors = [p for k, p in cls.all_processors.items() if k != "none"]
        return [cls.all_processors["none"]] + sorted(
            preprocessors,
            # NOTE: string comparison; the zero padding only orders correctly
            # for non-negative priorities of at most 8 digits.
            key=lambda x: str(x.sorting_priority).zfill(8) + x.label,
            reverse=True,
        )

    @classmethod
    def get_all_preprocessor_tags(cls) -> List[str]:
        """Return "All" followed by every distinct tag of the registered preprocessors, sorted."""
        tags = set()
        for p in cls.all_processors.values():
            tags.update(p.tags)
        return ["All"] + sorted(tags)

    @classmethod
    def get_filtered_preprocessors(cls, tag: str) -> List["Preprocessor"]:
        """
        Return the sorted preprocessors that carry ``tag``.

        The "none" preprocessor is always included. The special tag "All"
        selects every registered preprocessor.
        """
        if tag == "All":
            # Return the sorted list rather than the registry dict: the
            # annotated return type is a list, and get_default_preprocessor
            # indexes the result by position.
            return cls.get_sorted_preprocessors()
        return [
            p
            for p in cls.get_sorted_preprocessors()
            if tag in p.tags or p.label == "none"
        ]

    @classmethod
    def get_default_preprocessor(cls, tag: str) -> "Preprocessor":
        """
        Return the default preprocessor for ``tag``.

        The filtered list starts with "none", so the second entry is used
        when there is one; otherwise the only entry is returned.
        """
        ps = cls.get_filtered_preprocessors(tag)
        assert len(ps) > 0
        return ps[0] if len(ps) == 1 else ps[1]

    @classmethod
    def tag_to_filters(cls, tag: str) -> Set[str]:
        """Return the lower-cased ``tag`` together with its known aliases."""
        filters_aliases = {
            "instructp2p": ["ip2p"],
            "segmentation": ["seg"],
            "normalmap": ["normal"],
            "t2i-adapter": ["t2i_adapter", "t2iadapter", "t2ia"],
            "ip-adapter": ["ip_adapter", "ipadapter"],
            "openpose": ["openpose", "densepose"],
            "instant-id": ["instant_id", "instantid"],
            "scribble": ["sketch"],
            "tile": ["blur"],
        }

        tag = tag.lower()
        return {tag, *filters_aliases.get(tag, ())}

    @ndarray_lru_cache(max_size=CACHE_SIZE)
    def cached_call(self, *args, **kwargs):
        """Invoke the preprocessor through ``ndarray_lru_cache``; the body runs only when the wrapper calls it."""
        logger.debug(f"Calling preprocessor {self.name} outside of cache.")
        return self(*args, **kwargs)

    def __hash__(self):
        """Hash on ``name`` so instances can be used as cache keys."""
        return hash(self.name)

    def __eq__(self, other):
        """
        Two preprocessors are equal when their names are equal.

        Comparing names (instead of hash values) avoids false positives on
        hash collisions and on unrelated objects that happen to share the
        hash, such as the name string itself.
        """
        if not isinstance(other, Preprocessor):
            return NotImplemented
        return self.name == other.name

    @abstractmethod
    def __call__(
        self,
        input_image,
        resolution,
        slider_1=None,
        slider_2=None,
        slider_3=None,
        input_mask=None,
        **kwargs,
    ):
        """
        Run the preprocessor; must be implemented by subclasses.

        NOTE(review): presumably ``resolution`` and ``slider_1``..``slider_3``
        carry the values of the matching ``PreprocessorParameter`` attributes;
        confirm against the concrete subclasses and their callers.
        """
        pass