import json
import gradio as gr
import functools
from copy import copy
from typing import List, Optional, Union, Dict, Tuple, Literal
from dataclasses import dataclass
import numpy as np

from scripts.supported_preprocessor import Preprocessor
from scripts.utils import svg_preprocess, read_image
from scripts import (
    global_state,
    external_code,
)
from annotator.util import HWC3
from scripts.logging import logger
from scripts.controlnet_ui.openpose_editor import OpenposeEditor
from scripts.controlnet_ui.preset import ControlNetPresetUI
from scripts.controlnet_ui.tool_button import ToolButton
from scripts.controlnet_ui.photopea import Photopea
from scripts.controlnet_ui.advanced_weight_control import AdvancedWeightControl
from scripts.enums import InputMode
from modules import shared
from modules.ui_components import FormRow


@dataclass
class A1111Context:
    """Contains all components from A1111.

    Every field starts as `None` and is filled in by `set_component` when a
    component with a matching `elem_id` is seen.
    """

    img2img_batch_input_dir: Optional[gr.components.Component] = None
    img2img_batch_output_dir: Optional[gr.components.Component] = None
    txt2img_submit_button: Optional[gr.components.Component] = None
    img2img_submit_button: Optional[gr.components.Component] = None

    # Slider controls from A1111 WebUI.
    txt2img_w_slider: Optional[gr.components.Component] = None
    txt2img_h_slider: Optional[gr.components.Component] = None
    img2img_w_slider: Optional[gr.components.Component] = None
    img2img_h_slider: Optional[gr.components.Component] = None

    img2img_img2img_tab: Optional[gr.components.Component] = None
    img2img_img2img_sketch_tab: Optional[gr.components.Component] = None
    img2img_batch_tab: Optional[gr.components.Component] = None
    img2img_inpaint_tab: Optional[gr.components.Component] = None
    img2img_inpaint_sketch_tab: Optional[gr.components.Component] = None
    img2img_inpaint_upload_tab: Optional[gr.components.Component] = None

    img2img_inpaint_area: Optional[gr.components.Component] = None
    # txt2img_enable_hr is only available for A1111 > 1.7.0.
    txt2img_enable_hr: Optional[gr.components.Component] = None
    setting_sd_model_checkpoint: Optional[gr.components.Component] = None

    @property
    def img2img_inpaint_tabs(self) -> Tuple[Optional[gr.components.Component], ...]:
        """The three img2img inpaint tab components (entries may be `None`)."""
        return (
            self.img2img_inpaint_tab,
            self.img2img_inpaint_sketch_tab,
            self.img2img_inpaint_upload_tab,
        )

    @property
    def img2img_non_inpaint_tabs(
        self,
    ) -> Tuple[Optional[gr.components.Component], ...]:
        """The three img2img non-inpaint tab components (entries may be `None`)."""
        return (
            self.img2img_img2img_tab,
            self.img2img_img2img_sketch_tab,
            self.img2img_batch_tab,
        )

    @property
    def ui_initialized(self) -> bool:
        """Whether every non-optional component field holds a truthy value."""
        optional_components = (
            # Optional components are only available after A1111 v1.7.0.
            "img2img_img2img_tab",
            "img2img_img2img_sketch_tab",
            "img2img_batch_tab",
            "img2img_inpaint_tab",
            "img2img_inpaint_sketch_tab",
            "img2img_inpaint_upload_tab",
            # SDNext does not have this field. Temporarily disable the callback on
            # the checkpoint change until we find a way to register an event when
            # all A1111 UI components are ready.
            "setting_sd_model_checkpoint",
        )
        return all(
            c for name, c in vars(self).items() if name not in optional_components
        )

    def set_component(self, component: gr.components.Component):
        """Store `component` on the matching field, keyed by its `elem_id`.

        Components without an `elem_id`, with an unknown `elem_id`, or whose
        target field is already populated are ignored.
        """
        id_mapping = {
            "img2img_batch_input_dir": "img2img_batch_input_dir",
            "img2img_batch_output_dir": "img2img_batch_output_dir",
            "txt2img_generate": "txt2img_submit_button",
            "img2img_generate": "img2img_submit_button",
            "txt2img_width": "txt2img_w_slider",
            "txt2img_height": "txt2img_h_slider",
            "img2img_width": "img2img_w_slider",
            "img2img_height": "img2img_h_slider",
            "img2img_img2img_tab": "img2img_img2img_tab",
            "img2img_img2img_sketch_tab": "img2img_img2img_sketch_tab",
            "img2img_batch_tab": "img2img_batch_tab",
            "img2img_inpaint_tab": "img2img_inpaint_tab",
            "img2img_inpaint_sketch_tab": "img2img_inpaint_sketch_tab",
            "img2img_inpaint_upload_tab": "img2img_inpaint_upload_tab",
            "img2img_inpaint_full_res": "img2img_inpaint_area",
            "txt2img_hr-checkbox": "txt2img_enable_hr",
            # backward compatibility for webui < 1.6.0
            "txt2img_enable_hr": "txt2img_enable_hr",
            # setting_sd_model_checkpoint is expected to be initialized last.
            # "setting_sd_model_checkpoint": "setting_sd_model_checkpoint",
        }
        elem_id = getattr(component, "elem_id", None)
        # Do not set component if it has already been set.
        # https://github.com/Mikubill/sd-webui-controlnet/issues/2587
        if elem_id in id_mapping and getattr(self, id_mapping[elem_id]) is None:
            setattr(self, id_mapping[elem_id], component)
            logger.debug(f"Setting {elem_id}.")
            logger.debug(
                f"A1111 initialized {sum(c is not None for c in vars(self).values())}/{len(vars(self).keys())}."
            )


class UiControlNetUnit(external_code.ControlNetUnit):
    """The data class that stores all states of a ControlNetUnit."""

    def __init__(
        self,
        input_mode: InputMode = InputMode.SIMPLE,
        batch_images: Optional[Union[str, List[external_code.InputImage]]] = None,
        output_dir: str = "",
        loopback: bool = False,
        merge_gallery_files: Optional[
            List[Dict[Union[Literal["name"], Literal["data"]], str]]
        ] = None,
        use_preview_as_input: bool = False,
        generated_image: Optional[np.ndarray] = None,
        mask_image: Optional[np.ndarray] = None,
        enabled: bool = True,
        module: Optional[str] = None,
        model: Optional[str] = None,
        weight: float = 1.0,
        image: Optional[Dict[str, np.ndarray]] = None,
        *args,
        **kwargs,
    ):
        """Build a unit from the raw values of the UI components.

        `merge_gallery_files`, `use_preview_as_input`, `generated_image` and
        `mask_image` are only consumed here to decide the input image; they
        are not stored on the instance. `merge_gallery_files` defaults to
        `None` (treated like an empty list) so that no list object is shared
        between calls.
        """
        if use_preview_as_input and generated_image is not None:
            input_image = generated_image
            module = "none"
        else:
            input_image = image

        # Prefer uploaded mask_image over hand-drawn mask.
        if input_image is not None and mask_image is not None:
            assert isinstance(input_image, dict)
            # Note: this writes into the dict object passed in by the caller.
            input_image["mask"] = mask_image

        if merge_gallery_files and input_mode == InputMode.MERGE:
            input_image = [
                {"image": read_image(file["name"])} for file in merge_gallery_files
            ]

        super().__init__(enabled, module, model, weight, input_image, *args, **kwargs)
        self.is_ui = True
        self.input_mode = input_mode
        self.batch_images = batch_images
        self.output_dir = output_dir
        self.loopback = loopback

    def unfold_merged(self) -> List[external_code.ControlNetUnit]:
        """Unfolds a merged unit to multiple units.

        Keeps the unit merged for preprocessors that can accept multiple
        input images; in that case `self.input_mode` is switched to
        `InputMode.SIMPLE` before copying. Otherwise one copy is produced per
        image, each with the weight divided by the number of images.
        """
        if self.input_mode != InputMode.MERGE:
            return [copy(self)]

        if self.accepts_multiple_inputs():
            self.input_mode = InputMode.SIMPLE
            return [copy(self)]

        assert isinstance(self.image, list)
        result = []
        for image in self.image:
            unit = copy(self)
            unit.image = image["image"]
            unit.input_mode = InputMode.SIMPLE
            unit.weight = self.weight / len(self.image)
            result.append(unit)
        return result


class ControlNetUiGroup(object):
    refresh_symbol = "\U0001f504"  # 🔄
    switch_values_symbol = "\U000021C5"  # ⇅
    camera_symbol = "\U0001F4F7"  # 📷
    reverse_symbol = "\U000021C4"  # ⇄
    tossup_symbol = "\u2934"
    trigger_symbol = "\U0001F4A5"  # 💥
    open_symbol = "\U0001F4DD"  # 📝

    tooltips = {
        "🔄": "Refresh",
        "\u2934": "Send dimensions to stable diffusion",
        "💥": "Run preprocessor",
        "📝": "Open new canvas",
        "📷": "Enable webcam",
        "⇄": "Mirror webcam",
    }

    global_batch_input_dir = gr.Textbox(
        label="Controlnet input directory",
        placeholder="Leave empty to use input directory",
        **shared.hide_dirs,
        elem_id="controlnet_batch_input_dir",
    )
    a1111_context = A1111Context()
    # All ControlNetUiGroup instances created.
    all_ui_groups: List["ControlNetUiGroup"] = []

    def __init__(
        self,
        is_img2img: bool,
        default_unit: external_code.ControlNetUnit,
        photopea: Optional[Photopea],
    ):
        """Declare all member slots and register the group in `all_ui_groups`.

        Most gradio components are left as `None` here and created later by
        `render`.
        """
        # Whether callbacks have been registered.
        self.callbacks_registered: bool = False
        # Whether the render method on this object has been called.
        self.ui_initialized: bool = False

        self.is_img2img = is_img2img
        self.default_unit = default_unit
        self.photopea = photopea
        self.webcam_enabled = False
        self.webcam_mirrored = False

        # Note: All gradio elements declared in `render` will be defined as member variable.
        # Update counter to trigger a force update of UiControlNetUnit.
        # This is useful when a field with no event subscriber available changes.
        # e.g. gr.Gallery, gr.State, etc.
        self.update_unit_counter = None
        self.upload_tab = None
        self.image = None
        self.generated_image_group = None
        self.generated_image = None
        self.mask_image_group = None
        self.mask_image = None
        self.batch_tab = None
        self.batch_image_dir = None
        self.merge_tab = None
        self.merge_gallery = None
        self.merge_upload_button = None
        self.merge_clear_button = None
        self.create_canvas = None
        self.canvas_width = None
        self.canvas_height = None
        self.canvas_create_button = None
        self.canvas_cancel_button = None
        self.open_new_canvas_button = None
        self.webcam_enable = None
        self.webcam_mirror = None
        self.send_dimen_button = None
        self.enabled = None
        self.low_vram = None
        self.pixel_perfect = None
        self.preprocessor_preview = None
        self.mask_upload = None
        self.type_filter = None
        self.module = None
        self.trigger_preprocessor = None
        self.model = None
        self.refresh_models = None
        self.weight = None
        self.guidance_start = None
        self.guidance_end = None
        self.advanced = None
        self.processor_res = None
        self.threshold_a = None
        self.threshold_b = None
        self.control_mode = None
        self.resize_mode = None
        self.loopback = None
        self.use_preview_as_input = None
        self.openpose_editor = None
        self.preset_panel = None
        self.upload_independent_img_in_img2img = None
        self.image_upload_panel = None
        self.save_detected_map = None
        self.input_mode = gr.State(InputMode.SIMPLE)
        self.inpaint_crop_input_image = None
        self.hr_option = None
        self.advanced_weight_control = AdvancedWeightControl()
        self.batch_image_dir_state = None
        self.output_dir_state = None

        # API-only fields
        self.advanced_weighting = gr.State(None)
        self.ipadapter_input = gr.State(None)

        ControlNetUiGroup.all_ui_groups.append(self)

    @staticmethod
    def _collect_event_subscribers(comp) -> list:
        """Return the event-registration methods of `comp` worth hooking.

        Exactly one of `edit`, `click`, `release` (sliders only) or `change`
        is picked, in that order of preference; `clear` is added on top when
        the object offers it. Objects exposing none of these attributes
        (including `None`) yield an empty list.
        """
        subscribers = []
        if hasattr(comp, "edit"):
            subscribers.append(comp.edit)
        elif hasattr(comp, "click"):
            subscribers.append(comp.click)
        elif isinstance(comp, gr.Slider) and hasattr(comp, "release"):
            subscribers.append(comp.release)
        elif hasattr(comp, "change"):
            subscribers.append(comp.change)

        if hasattr(comp, "clear"):
            subscribers.append(comp.clear)
        return subscribers

    def render(self, tabname: str, elem_id_tabname: str) -> gr.State:
        """The pure HTML structure of a single ControlNetUnit.

        Calling this function will populate `self` with all gradio element
        declared in local scope, wire every input component so that a fresh
        `UiControlNetUnit` is built when it changes, and register the core
        callbacks.

        Args:
            tabname: Used, together with `elem_id_tabname`, to build the
                `elem_id` of each component.
            elem_id_tabname: Prefix of every generated `elem_id`.

        Returns:
            The `gr.State` holding the unit, initialised with
            `self.default_unit`.
        """
        # NOTE(review): the nesting of the `with` blocks below was
        # reconstructed from a whitespace-collapsed source — confirm the
        # layout against the rendered UI.
        self.update_unit_counter = gr.Number(value=0, visible=False)
        self.openpose_editor = OpenposeEditor()

        with gr.Group(visible=not self.is_img2img) as self.image_upload_panel:
            self.save_detected_map = gr.Checkbox(value=True, visible=False)
            with gr.Tabs():
                with gr.Tab(label="Single Image") as self.upload_tab:
                    with gr.Row(elem_classes=["cnet-image-row"], equal_height=True):
                        with gr.Group(elem_classes=["cnet-input-image-group"]):
                            self.image = gr.Image(
                                source="upload",
                                brush_radius=20,
                                mirror_webcam=False,
                                type="numpy",
                                tool="sketch",
                                elem_id=f"{elem_id_tabname}_{tabname}_input_image",
                                elem_classes=["cnet-image"],
                                brush_color=shared.opts.img2img_inpaint_mask_brush_color
                                if hasattr(
                                    shared.opts, "img2img_inpaint_mask_brush_color"
                                )
                                else None,
                            )
                            # Route the component's preprocess through
                            # `svg_preprocess`, handing it the original one.
                            self.image.preprocess = functools.partial(
                                svg_preprocess, preprocess=self.image.preprocess
                            )
                            self.openpose_editor.render_upload()

                        with gr.Group(
                            visible=False, elem_classes=["cnet-generated-image-group"]
                        ) as self.generated_image_group:
                            self.generated_image = gr.Image(
                                value=None,
                                label="Preprocessor Preview",
                                elem_id=f"{elem_id_tabname}_{tabname}_generated_image",
                                elem_classes=["cnet-image"],
                                interactive=True,
                                height=242,
                            )  # Gradio's magic number. Only 242 works.

                            with gr.Group(
                                elem_classes=["cnet-generated-image-control-group"]
                            ):
                                if self.photopea:
                                    self.photopea.render_child_trigger()
                                self.openpose_editor.render_edit()
                                preview_check_elem_id = f"{elem_id_tabname}_{tabname}_controlnet_preprocessor_preview_checkbox"
                                # NOTE(review): this snippet is not referenced
                                # by the HTML below as it stands; presumably
                                # the markup that used it was lost — confirm.
                                preview_close_button_js = f"document.querySelector('#{preview_check_elem_id} input[type=\\'checkbox\\']').click();"
                                gr.HTML(
                                    value=f"""Close""",
                                    visible=True,
                                    elem_classes=["cnet-close-preview"],
                                )

                        with gr.Group(
                            visible=False, elem_classes=["cnet-mask-image-group"]
                        ) as self.mask_image_group:
                            self.mask_image = gr.Image(
                                value=None,
                                label="Upload Mask",
                                elem_id=f"{elem_id_tabname}_{tabname}_mask_image",
                                elem_classes=["cnet-mask-image"],
                                interactive=True,
                            )

                with gr.Tab(label="Batch") as self.batch_tab:
                    self.batch_image_dir = gr.Textbox(
                        label="Input Directory",
                        placeholder="Leave empty to use img2img batch controlnet input directory",
                        elem_id=f"{elem_id_tabname}_{tabname}_batch_image_dir",
                    )

                with gr.Tab(label="Multi-Inputs") as self.merge_tab:
                    self.merge_gallery = gr.Gallery(
                        columns=[4], rows=[2], object_fit="contain", height="auto"
                    )
                    with gr.Row():
                        self.merge_upload_button = gr.UploadButton(
                            "Upload Images",
                            file_types=["image"],
                            file_count="multiple",
                        )
                        self.merge_clear_button = gr.Button("Clear Images")

            if self.photopea:
                self.photopea.attach_photopea_output(self.generated_image)

            with gr.Accordion(
                label="Open New Canvas", visible=False
            ) as self.create_canvas:
                self.canvas_width = gr.Slider(
                    label="New Canvas Width",
                    minimum=256,
                    maximum=1024,
                    value=512,
                    step=64,
                    elem_id=f"{elem_id_tabname}_{tabname}_controlnet_canvas_width",
                )
                self.canvas_height = gr.Slider(
                    label="New Canvas Height",
                    minimum=256,
                    maximum=1024,
                    value=512,
                    step=64,
                    elem_id=f"{elem_id_tabname}_{tabname}_controlnet_canvas_height",
                )
                with gr.Row():
                    self.canvas_create_button = gr.Button(
                        value="Create New Canvas",
                        elem_id=f"{elem_id_tabname}_{tabname}_controlnet_canvas_create_button",
                    )
                    self.canvas_cancel_button = gr.Button(
                        value="Cancel",
                        elem_id=f"{elem_id_tabname}_{tabname}_controlnet_canvas_cancel_button",
                    )

            with gr.Row(elem_classes="controlnet_image_controls"):
                # NOTE(review): the text is surrounded by blank lines in the
                # source; presumably HTML markup around it was lost — confirm.
                gr.HTML(
                    value="\n\nSet the preprocessor to [invert] If your image has white background and black lines.\n\n",
                    elem_classes="controlnet_invert_warning",
                )
                self.open_new_canvas_button = ToolButton(
                    value=ControlNetUiGroup.open_symbol,
                    elem_id=f"{elem_id_tabname}_{tabname}_controlnet_open_new_canvas_button",
                    tooltip=ControlNetUiGroup.tooltips[ControlNetUiGroup.open_symbol],
                )
                self.webcam_enable = ToolButton(
                    value=ControlNetUiGroup.camera_symbol,
                    elem_id=f"{elem_id_tabname}_{tabname}_controlnet_webcam_enable",
                    tooltip=ControlNetUiGroup.tooltips[ControlNetUiGroup.camera_symbol],
                )
                self.webcam_mirror = ToolButton(
                    value=ControlNetUiGroup.reverse_symbol,
                    elem_id=f"{elem_id_tabname}_{tabname}_controlnet_webcam_mirror",
                    tooltip=ControlNetUiGroup.tooltips[
                        ControlNetUiGroup.reverse_symbol
                    ],
                )
                self.send_dimen_button = ToolButton(
                    value=ControlNetUiGroup.tossup_symbol,
                    elem_id=f"{elem_id_tabname}_{tabname}_controlnet_send_dimen_button",
                    tooltip=ControlNetUiGroup.tooltips[ControlNetUiGroup.tossup_symbol],
                )

        with FormRow(elem_classes=["controlnet_main_options"]):
            self.enabled = gr.Checkbox(
                label="Enable",
                value=self.default_unit.enabled,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_enable_checkbox",
                elem_classes=["cnet-unit-enabled"],
            )
            self.low_vram = gr.Checkbox(
                label="Low VRAM",
                value=self.default_unit.low_vram,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_low_vram_checkbox",
            )
            self.pixel_perfect = gr.Checkbox(
                label="Pixel Perfect",
                value=self.default_unit.pixel_perfect,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_pixel_perfect_checkbox",
            )
            self.preprocessor_preview = gr.Checkbox(
                label="Allow Preview",
                value=False,
                elem_classes=["cnet-allow-preview"],
                elem_id=preview_check_elem_id,
                visible=not self.is_img2img,
            )
            self.mask_upload = gr.Checkbox(
                label="Mask Upload",
                value=False,
                elem_classes=["cnet-mask-upload"],
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_mask_upload_checkbox",
                visible=not self.is_img2img,
            )
            self.use_preview_as_input = gr.Checkbox(
                label="Preview as Input",
                value=False,
                elem_classes=["cnet-preview-as-input"],
                visible=False,
            )

        with gr.Row(elem_classes="controlnet_img2img_options"):
            if self.is_img2img:
                self.upload_independent_img_in_img2img = gr.Checkbox(
                    label="Upload independent control image",
                    value=False,
                    elem_id=f"{elem_id_tabname}_{tabname}_controlnet_same_img2img_checkbox",
                    elem_classes=["cnet-unit-same_img2img"],
                )
            else:
                self.upload_independent_img_in_img2img = None

            # Note: The checkbox needs to exist for both img2img and txt2img as infotext
            # needs the checkbox value.
            self.inpaint_crop_input_image = gr.Checkbox(
                label="Crop input image based on A1111 mask",
                value=False,
                elem_classes=["cnet-crop-input-image"],
                visible=False,
            )

        with gr.Row(elem_classes=["controlnet_control_type", "controlnet_row"]):
            self.type_filter = gr.Radio(
                Preprocessor.get_all_preprocessor_tags(),
                label="Control Type",
                value="All",
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_type_filter_radio",
                elem_classes="controlnet_control_type_filter_group",
            )

        with gr.Row(elem_classes=["controlnet_preprocessor_model", "controlnet_row"]):
            self.module = gr.Dropdown(
                [p.label for p in Preprocessor.get_sorted_preprocessors()],
                label="Preprocessor",
                value=self.default_unit.module,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_preprocessor_dropdown",
            )
            self.trigger_preprocessor = ToolButton(
                value=ControlNetUiGroup.trigger_symbol,
                visible=not self.is_img2img,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_trigger_preprocessor",
                elem_classes=["cnet-run-preprocessor"],
                tooltip=ControlNetUiGroup.tooltips[ControlNetUiGroup.trigger_symbol],
            )
            self.model = gr.Dropdown(
                list(global_state.cn_models.keys()),
                label="Model",
                value=self.default_unit.model,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_model_dropdown",
            )
            self.refresh_models = ToolButton(
                value=ControlNetUiGroup.refresh_symbol,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_refresh_models",
                tooltip=ControlNetUiGroup.tooltips[ControlNetUiGroup.refresh_symbol],
            )

        with gr.Row(elem_classes=["controlnet_weight_steps", "controlnet_row"]):
            self.weight = gr.Slider(
                label="Control Weight",
                value=self.default_unit.weight,
                minimum=0.0,
                maximum=2.0,
                step=0.05,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_control_weight_slider",
                elem_classes="controlnet_control_weight_slider",
            )
            self.guidance_start = gr.Slider(
                label="Starting Control Step",
                value=self.default_unit.guidance_start,
                minimum=0.0,
                maximum=1.0,
                interactive=True,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_start_control_step_slider",
                elem_classes="controlnet_start_control_step_slider",
            )
            self.guidance_end = gr.Slider(
                label="Ending Control Step",
                value=self.default_unit.guidance_end,
                minimum=0.0,
                maximum=1.0,
                interactive=True,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_ending_control_step_slider",
                elem_classes="controlnet_ending_control_step_slider",
            )

        # advanced options
        with gr.Column(visible=False) as self.advanced:
            self.processor_res = gr.Slider(
                label="Preprocessor resolution",
                value=self.default_unit.processor_res,
                minimum=64,
                maximum=2048,
                visible=False,
                interactive=True,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_preprocessor_resolution_slider",
            )
            self.threshold_a = gr.Slider(
                label="Threshold A",
                value=self.default_unit.threshold_a,
                minimum=64,
                maximum=1024,
                visible=False,
                interactive=True,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_threshold_A_slider",
            )
            self.threshold_b = gr.Slider(
                label="Threshold B",
                value=self.default_unit.threshold_b,
                minimum=64,
                maximum=1024,
                visible=False,
                interactive=True,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_threshold_B_slider",
            )

        self.control_mode = gr.Radio(
            choices=[e.value for e in external_code.ControlMode],
            value=self.default_unit.control_mode.value,
            label="Control Mode",
            elem_id=f"{elem_id_tabname}_{tabname}_controlnet_control_mode_radio",
            elem_classes="controlnet_control_mode_radio",
        )

        self.resize_mode = gr.Radio(
            choices=[e.value for e in external_code.ResizeMode],
            value=self.default_unit.resize_mode.value,
            label="Resize Mode",
            elem_id=f"{elem_id_tabname}_{tabname}_controlnet_resize_mode_radio",
            elem_classes="controlnet_resize_mode_radio",
            visible=not self.is_img2img,
        )

        self.hr_option = gr.Radio(
            choices=[e.value for e in external_code.HiResFixOption],
            value=self.default_unit.hr_option.value,
            label="Hires-Fix Option",
            elem_id=f"{elem_id_tabname}_{tabname}_controlnet_hr_option_radio",
            elem_classes="controlnet_hr_option_radio",
            visible=False,
        )

        self.loopback = gr.Checkbox(
            label="[Batch Loopback] Automatically send generated images to this ControlNet unit in batch generation",
            value=self.default_unit.loopback,
            elem_id=f"{elem_id_tabname}_{tabname}_controlnet_automatically_send_generated_images_checkbox",
            elem_classes="controlnet_loopback_checkbox",
            visible=False,
        )

        self.advanced_weight_control.render()

        self.preset_panel = ControlNetPresetUI(
            id_prefix=f"{elem_id_tabname}_{tabname}_"
        )

        self.batch_image_dir_state = gr.State("")
        self.output_dir_state = gr.State("")
        # The order must match the positional parameters of UiControlNetUnit,
        # since the class itself is used as the event handler below.
        unit_args = (
            self.input_mode,
            self.batch_image_dir_state,
            self.output_dir_state,
            self.loopback,
            # Non-persistent fields.
            # Following inputs will not be persistent on `ControlNetUnit`.
            # They are only used during object construction.
            self.merge_gallery,
            self.use_preview_as_input,
            self.generated_image,
            self.mask_image,
            # End of Non-persistent fields.
            self.enabled,
            self.module,
            self.model,
            self.weight,
            self.image,
            self.resize_mode,
            self.low_vram,
            self.processor_res,
            self.threshold_a,
            self.threshold_b,
            self.guidance_start,
            self.guidance_end,
            self.pixel_perfect,
            self.control_mode,
            self.inpaint_crop_input_image,
            self.hr_option,
            self.save_detected_map,
            self.advanced_weighting,
        )

        unit = gr.State(self.default_unit)
        # Rebuild the unit whenever any input (or the force-update counter)
        # fires one of its relevant events.
        for comp in unit_args + (self.update_unit_counter,):
            for event_subscriber in self._collect_event_subscribers(comp):
                event_subscriber(
                    fn=UiControlNetUnit, inputs=list(unit_args), outputs=unit
                )

        # Also rebuild the unit when the matching A1111 submit button is clicked.
        (
            ControlNetUiGroup.a1111_context.img2img_submit_button
            if self.is_img2img
            else ControlNetUiGroup.a1111_context.txt2img_submit_button
        ).click(
            fn=UiControlNetUnit,
            inputs=list(unit_args),
            outputs=unit,
            queue=False,
        )
        self.register_core_callbacks()
        self.ui_initialized = True
        return unit

    def register_send_dimensions(self):
        """Register event handler for send dimension button."""

        def send_dimensions(image):
            """Return (width, height) of the input image, snapped to multiples of 8."""

            def closesteight(num):
                # Nearest multiple of 8; a remainder of exactly 4 rounds down.
                rem = num % 8
                if rem <= 4:
                    return round(num - rem)
                return round(num + (8 - rem))

            # Leave both sliders untouched when there is no image payload.
            if not image or image.get("image") is None:
                return gr.Slider.update(), gr.Slider.update()

            interm = np.asarray(image.get("image"))
            return closesteight(interm.shape[1]), closesteight(interm.shape[0])

        outputs = (
            [
                ControlNetUiGroup.a1111_context.img2img_w_slider,
                ControlNetUiGroup.a1111_context.img2img_h_slider,
            ]
            if self.is_img2img
            else [
                ControlNetUiGroup.a1111_context.txt2img_w_slider,
                ControlNetUiGroup.a1111_context.txt2img_h_slider,
            ]
        )
        self.send_dimen_button.click(
            fn=send_dimensions,
            inputs=[self.image],
            outputs=outputs,
            show_progress=False,
        )

    def register_webcam_toggle(self):
        """Toggle the input image source between webcam and upload on click."""

        def webcam_toggle():
            self.webcam_enabled = not self.webcam_enabled
            return {
                "value": None,
                "source": "webcam" if self.webcam_enabled else "upload",
                "__type__": "update",
            }

        self.webcam_enable.click(
            webcam_toggle, inputs=None, outputs=self.image, show_progress=False
        )

    def register_webcam_mirror_toggle(self):
        """Toggle webcam mirroring of the input image on click."""

        def webcam_mirror_toggle():
            self.webcam_mirrored = not self.webcam_mirrored
            return {"mirror_webcam": self.webcam_mirrored, "__type__": "update"}

        self.webcam_mirror.click(
            webcam_mirror_toggle, inputs=None, outputs=self.image, show_progress=False
        )

    def register_refresh_all_models(self):
        """Reload the model list and refresh the model dropdown on click."""

        def refresh_all_models(model: str):
            global_state.update_cn_models()
            choices = list(global_state.cn_models.keys())
            # Keep the current selection only if it still exists.
            return gr.Dropdown.update(
                value=model if model in global_state.cn_models else "None",
                choices=choices,
            )

        self.refresh_models.click(
            refresh_all_models,
            inputs=[self.model],
            outputs=[self.model],
            show_progress=False,
        )

    def register_build_sliders(self):
        """Update sliders and dropdowns when the preprocessor, pixel-perfect
        flag or control type changes."""

        def build_sliders(module: str, pp: bool):
            preprocessor = Preprocessor.get_preprocessor(module)
            slider_resolution_kwargs = (
                preprocessor.slider_resolution.gradio_update_kwargs.copy()
            )

            # The resolution slider is hidden while pixel perfect is on.
            if pp:
                slider_resolution_kwargs["visible"] = False

            grs = [
                gr.update(**slider_resolution_kwargs),
                gr.update(**preprocessor.slider_1.gradio_update_kwargs.copy()),
                gr.update(**preprocessor.slider_2.gradio_update_kwargs.copy()),
                gr.update(visible=True),
                gr.update(visible=not preprocessor.do_not_need_model),
                gr.update(visible=not preprocessor.do_not_need_model),
                gr.update(visible=preprocessor.show_control_mode),
            ]

            return grs

        inputs = [
            self.module,
            self.pixel_perfect,
        ]
        outputs = [
            self.processor_res,
            self.threshold_a,
            self.threshold_b,
            self.advanced,
            self.model,
            self.refresh_models,
            self.control_mode,
        ]
        self.module.change(
            build_sliders, inputs=inputs, outputs=outputs, show_progress=False
        )
        self.pixel_perfect.change(
            build_sliders, inputs=inputs, outputs=outputs, show_progress=False
        )

        def filter_selected(k: str):
            logger.debug(f"Switch to control type {k}")
            (
                filtered_preprocessor_list,
                filtered_model_list,
                default_option,
                default_model,
            ) = global_state.select_control_type(k, global_state.get_sd_version())
            return [
                gr.Dropdown.update(
                    value=default_option, choices=filtered_preprocessor_list
                ),
                gr.Dropdown.update(
                    value=default_model, choices=filtered_model_list
                ),
            ]

        self.type_filter.change(
            fn=filter_selected,
            inputs=[self.type_filter],
            outputs=[self.module, self.model],
            show_progress=False,
        )

    def register_sd_version_changed(self):
        """Refresh the model dropdown when the checkpoint setting changes.

        Does nothing when the checkpoint component is not available.
        """

        def sd_version_changed(type_filter: str, current_model: str):
            """When SD version changes, update model dropdown choices."""
            (
                _filtered_preprocessor_list,
                filtered_model_list,
                _default_option,
                default_model,
            ) = global_state.select_control_type(
                type_filter, global_state.get_sd_version()
            )

            # The current selection is still valid: leave the dropdown as is.
            if current_model in filtered_model_list:
                return gr.update()

            return gr.Dropdown.update(
                value=default_model,
                choices=filtered_model_list,
            )

        if ControlNetUiGroup.a1111_context.setting_sd_model_checkpoint:
            ControlNetUiGroup.a1111_context.setting_sd_model_checkpoint.change(
                fn=sd_version_changed,
                inputs=[self.type_filter, self.model],
                outputs=[self.model],
                show_progress=False,
            )

    def register_run_annotator(self):
        """Run the selected preprocessor on the input image when triggered."""

        def run_annotator(image, module, pres, pthr_a, pthr_b, t2i_w, t2i_h, pp, rm, model: str):
            if image is None:
                return (
                    gr.update(value=None, visible=True),
                    gr.update(),
                    *self.openpose_editor.update(""),
                )

            img = HWC3(image["image"])
            # A mask whose first channel is entirely <= 5 or entirely >= 250
            # is treated as "no mask".
            has_mask = not (
                (image["mask"][:, :, 0] <= 5).all()
                or (image["mask"][:, :, 0] >= 250).all()
            )
            if "inpaint" in module:
                # Append the first mask channel to the image as an extra channel.
                color = HWC3(image["image"])
                alpha = image["mask"][:, :, 0:1]
                img = np.concatenate([color, alpha], axis=2)
            elif has_mask and not shared.opts.data.get(
                "controlnet_ignore_noninpaint_mask", False
            ):
                img = HWC3(image["mask"][:, :, 0])

            preprocessor = Preprocessor.get_preprocessor(module)

            if pp:
                pres = external_code.pixel_perfect_resolution(
                    img,
                    target_H=t2i_h,
                    target_W=t2i_w,
                    resize_mode=external_code.resize_mode_from_value(rm),
                )

            class JsonAcceptor:
                def __init__(self) -> None:
                    self.value = ""

                def accept(self, json_dict: dict) -> None:
                    self.value = json.dumps(json_dict)

            json_acceptor = JsonAcceptor()

            logger.info(f"Preview Resolution = {pres}")

            def is_openpose(module: str):
                return "openpose" in module

            # Only openpose preprocessor returns a JSON output, pass json_acceptor
            # only when a JSON output is expected. This will make preprocessor cache
            # work for all other preprocessors other than openpose ones. JSON acceptor
            # instance are different every call, which means cache will never take
            # effect.
            # TODO: Maybe we should let `preprocessor` return a Dict to alleviate this issue?
            # This requires changing all callsites though.
            result = preprocessor.cached_call(
                img,
                resolution=pres,
                slider_1=pthr_a,
                slider_2=pthr_b,
                low_vram=(
                    ("clip" in module or module == "ip-adapter_face_id_plus")
                    and shared.opts.data.get("controlnet_clip_detector_on_cpu", False)
                ),
                json_pose_callback=(
                    json_acceptor.accept if is_openpose(module) else None
                ),
                model=model,
            )

            # Preview the input itself when the preprocessor yields no image.
            if not preprocessor.returns_image:
                result = img

            result = external_code.visualize_inpaint_mask(result)
            return (
                # Update to `generated_image`
                gr.update(value=result, visible=True, interactive=False),
                # preprocessor_preview
                gr.update(value=True),
                # openpose editor
                *self.openpose_editor.update(json_acceptor.value),
            )

        self.trigger_preprocessor.click(
            fn=run_annotator,
            inputs=[
                self.image,
                self.module,
                self.processor_res,
                self.threshold_a,
                self.threshold_b,
                ControlNetUiGroup.a1111_context.img2img_w_slider
                if self.is_img2img
                else ControlNetUiGroup.a1111_context.txt2img_w_slider,
                ControlNetUiGroup.a1111_context.img2img_h_slider
                if self.is_img2img
                else ControlNetUiGroup.a1111_context.txt2img_h_slider,
                self.pixel_perfect,
                self.resize_mode,
                self.model,
            ],
            outputs=[
                self.generated_image,
                self.preprocessor_preview,
                *self.openpose_editor.outputs(),
            ],
        )

    def register_shift_preview(self):
        """Show or hide the preview area when "Allow Preview" is toggled."""

        def shift_preview(is_on):
            return (
                # generated_image
                gr.update() if is_on else gr.update(value=None),
                # generated_image_group
                gr.update(visible=is_on),
                # use_preview_as_input,
                gr.update(visible=False),  # Now this is automatically managed
                # download_pose_link
                gr.update() if is_on else gr.update(value=None),
                # modal edit button
                gr.update() if is_on else gr.update(visible=False),
            )

        self.preprocessor_preview.change(
            fn=shift_preview,
            inputs=[self.preprocessor_preview],
            outputs=[
                self.generated_image,
                self.generated_image_group,
                self.use_preview_as_input,
                self.openpose_editor.download_link,
                self.openpose_editor.modal,
            ],
            show_progress=False,
        )

    def register_create_canvas(self):
        """Wire the buttons that open, cancel and create a new blank canvas."""
        self.open_new_canvas_button.click(
            lambda: gr.Accordion.update(visible=True),
            inputs=None,
            outputs=self.create_canvas,
            show_progress=False,
        )
        self.canvas_cancel_button.click(
            lambda: gr.Accordion.update(visible=False),
            inputs=None,
            outputs=self.create_canvas,
            show_progress=False,
        )

        def fn_canvas(h, w):
            # All-white (255) uint8 canvas of shape (h, w, 3); the accordion
            # is hidden again once the canvas is created.
            return np.full((h, w, 3), 255, dtype=np.uint8), gr.Accordion.update(
                visible=False
            )

        self.canvas_create_button.click(
            fn=fn_canvas,
            inputs=[self.canvas_height, self.canvas_width],
            outputs=[self.image, self.create_canvas],
            show_progress=False,
        )

    def register_img2img_same_input(self):
        """React to the "Upload independent control image" checkbox (img2img)."""

        def fn_same_checked(x):
            # One update per output below: the first three reset values, the
            # remaining four only toggle visibility.
            return [
                gr.update(value=None),
                gr.update(value=None),
                gr.update(value=False, visible=x),
            ] + [gr.update(visible=x)] * 4

        self.upload_independent_img_in_img2img.change(
            fn_same_checked,
            inputs=self.upload_independent_img_in_img2img,
            outputs=[
                self.image,
                self.batch_image_dir,
                self.preprocessor_preview,
                self.image_upload_panel,
                self.trigger_preprocessor,
                self.loopback,
                self.resize_mode,
            ],
            show_progress=False,
        )

    def register_shift_crop_input_image(self):
        """Control visibility of the "crop input image" checkbox."""
        # A1111 < 1.7.0 compatibility.
        if any(c is None for c in ControlNetUiGroup.a1111_context.img2img_inpaint_tabs):
            self.inpaint_crop_input_image.visible = True
            self.inpaint_crop_input_image.value = True
            return

        is_inpaint_tab = gr.State(False)

        def shift_crop_input_image(is_inpaint: bool, inpaint_area: int):
            # Note: inpaint_area (0: Whole picture, 1: Only masked)
            # By default set value to True, as most preprocessors need cropped result.
            return gr.update(value=True, visible=is_inpaint and inpaint_area == 1)

        gradio_kwargs = dict(
            fn=shift_crop_input_image,
            inputs=[
                is_inpaint_tab,
                ControlNetUiGroup.a1111_context.img2img_inpaint_area,
            ],
            outputs=[self.inpaint_crop_input_image],
            show_progress=False,
        )

        # Record which kind of tab is selected, then re-evaluate visibility.
        for elem in ControlNetUiGroup.a1111_context.img2img_inpaint_tabs:
            elem.select(fn=lambda: True, inputs=[], outputs=[is_inpaint_tab]).then(
                **gradio_kwargs
            )

        for elem in ControlNetUiGroup.a1111_context.img2img_non_inpaint_tabs:
            elem.select(fn=lambda: False, inputs=[], outputs=[is_inpaint_tab]).then(
                **gradio_kwargs
            )

        ControlNetUiGroup.a1111_context.img2img_inpaint_area.change(**gradio_kwargs)

    def register_shift_hr_options(self):
        """Show the hires-fix option only while the A1111 hires checkbox is on."""
        # A1111 version < 1.6.0.
        if not ControlNetUiGroup.a1111_context.txt2img_enable_hr:
            return

        ControlNetUiGroup.a1111_context.txt2img_enable_hr.change(
            fn=lambda checked: gr.update(visible=checked),
            inputs=[ControlNetUiGroup.a1111_context.txt2img_enable_hr],
            outputs=[self.hr_option],
            show_progress=False,
        )

    def register_shift_upload_mask(self):
        """Controls whether the upload mask input should be visible."""
        self.mask_upload.change(
            fn=lambda checked: (
                # Clear mask_image if unchecked.
                (gr.update(visible=False), gr.update(value=None))
                if not checked
                else (gr.update(visible=True), gr.update())
            ),
            inputs=[self.mask_upload],
            outputs=[self.mask_image_group, self.mask_image],
            show_progress=False,
        )

        if self.upload_independent_img_in_img2img is not None:
            self.upload_independent_img_in_img2img.change(
                fn=lambda checked: (
                    # Uncheck `upload_mask` when not using independent input.
                    gr.update(visible=False, value=False)
                    if not checked
                    else gr.update(visible=True)
                ),
                inputs=[self.upload_independent_img_in_img2img],
                outputs=[self.mask_upload],
                show_progress=False,
            )

    def register_sync_batch_dir(self):
        """Mirror the effective batch input / output directories into states."""

        def determine_batch_dir(batch_dir, fallback_dir, fallback_fallback_dir):
            # First non-empty value wins; the last fallback is returned as is.
            if batch_dir:
                return batch_dir
            elif fallback_dir:
                return fallback_dir
            else:
                return fallback_fallback_dir

        batch_dirs = [
            self.batch_image_dir,
            ControlNetUiGroup.global_batch_input_dir,
            ControlNetUiGroup.a1111_context.img2img_batch_input_dir,
        ]
        for batch_dir_comp in batch_dirs:
            # Skip entries that are missing or expose no `blur` event.
            subscriber = getattr(batch_dir_comp, "blur", None)
            if subscriber is None:
                continue
            subscriber(
                fn=determine_batch_dir,
                inputs=batch_dirs,
                outputs=[self.batch_image_dir_state],
                queue=False,
            )

        ControlNetUiGroup.a1111_context.img2img_batch_output_dir.blur(
            fn=lambda a: a,
            inputs=[ControlNetUiGroup.a1111_context.img2img_batch_output_dir],
            outputs=[self.output_dir_state],
            queue=False,
        )

    def register_clear_preview(self):
        """Cancel "preview as input" and clear the preview when inputs change."""

        def clear_preview(x):
            if x:
                logger.info("Preview as input is cancelled.")
            return gr.update(value=False), gr.update(value=None)

        for comp in (
            self.pixel_perfect,
            self.module,
            self.image,
            self.processor_res,
            self.threshold_a,
            self.threshold_b,
            # May be None (it is only created for img2img in `render`); the
            # helper then yields no subscribers.
            self.upload_independent_img_in_img2img,
        ):
            for event_subscriber in self._collect_event_subscribers(comp):
                event_subscriber(
                    fn=clear_preview,
                    inputs=self.use_preview_as_input,
                    outputs=[self.use_preview_as_input, self.generated_image],
                )

    def register_multi_images_upload(self):
        """Register callbacks on merge tab multiple images upload."""
        self.merge_clear_button.click(
            fn=lambda: [],
            inputs=[],
            outputs=[self.merge_gallery],
        ).then(
            # Bump the counter to force a rebuild of the unit.
            fn=lambda x: gr.update(value=x + 1),
            inputs=[self.update_unit_counter],
            outputs=[self.update_unit_counter],
        )

        def upload_file(files, current_files):
            """Merge already-listed and newly uploaded file names.

            Duplicates are dropped while the first-seen order is kept, so the
            gallery order is stable.
            """
            existing = [file_d["name"] for file_d in current_files or []]
            uploaded = [file.name for file in files or []]
            return list(dict.fromkeys(existing + uploaded))

        self.merge_upload_button.upload(
            upload_file,
            inputs=[self.merge_upload_button, self.merge_gallery],
            outputs=[self.merge_gallery],
            queue=False,
        ).then(
            # Bump the counter to force a rebuild of the unit.
            fn=lambda x: gr.update(value=x + 1),
            inputs=[self.update_unit_counter],
            outputs=[self.update_unit_counter],
        )

    def register_core_callbacks(self):
        """Register core callbacks that only involves gradio components defined
        within this ui group."""
        self.register_webcam_toggle()
        self.register_webcam_mirror_toggle()
        self.register_refresh_all_models()
        self.register_build_sliders()
        self.register_shift_preview()
        self.register_shift_upload_mask()
        self.register_create_canvas()
        self.register_clear_preview()
        self.register_multi_images_upload()
        self.openpose_editor.register_callbacks(
            self.generated_image,
            self.use_preview_as_input,
            self.model,
        )
        assert self.type_filter is not None
        self.preset_panel.register_callbacks(
            self,
            self.type_filter,
            # One component per attribute of a default ControlNetUnit, looked
            # up on this group by the same name.
            *[
                getattr(self, key)
                for key in vars(external_code.ControlNetUnit()).keys()
            ],
        )
        self.advanced_weight_control.register_callbacks(
            self.weight,
            self.advanced_weighting,
            self.type_filter,
            self.update_unit_counter,
        )
        if self.is_img2img:
            self.register_img2img_same_input()

    def register_callbacks(self):
        """Register callbacks that involves A1111 context gradio components."""
        # Prevent infinite recursion.
if self.callbacks_registered: return self.callbacks_registered = True self.register_sd_version_changed() self.register_send_dimensions() self.register_run_annotator() self.register_sync_batch_dir() if self.is_img2img: self.register_shift_crop_input_image() else: self.register_shift_hr_options() @staticmethod def register_input_mode_sync(ui_groups: List["ControlNetUiGroup"]): """ - ui_group.input_mode should be updated when user switch tabs. - Loopback checkbox should only be visible if at least one ControlNet unit is set to batch mode. Argument: ui_groups: All ControlNetUiGroup instances defined in current Script context. Returns: None """ if not ui_groups: return for ui_group in ui_groups: batch_fn = lambda: InputMode.BATCH simple_fn = lambda: InputMode.SIMPLE merge_fn = lambda: InputMode.MERGE for input_tab, fn in ( (ui_group.upload_tab, simple_fn), (ui_group.batch_tab, batch_fn), (ui_group.merge_tab, merge_fn), ): # Sync input_mode. input_tab.select( fn=fn, inputs=[], outputs=[ui_group.input_mode], show_progress=False, ).then( # Update visibility of loopback checkbox. fn=lambda *mode_values: ( ( gr.update( visible=any(m == InputMode.BATCH for m in mode_values) ), ) * len(ui_groups) ), inputs=[g.input_mode for g in ui_groups], outputs=[g.loopback for g in ui_groups], show_progress=False, ) @staticmethod def reset(): ControlNetUiGroup.a1111_context = A1111Context() ControlNetUiGroup.all_ui_groups = [] @staticmethod def try_register_all_callbacks(): unit_count = shared.opts.data.get("control_net_unit_count", 3) all_unit_count = unit_count * 2 # txt2img + img2img. if ( # All A1111 components ControlNet units care about are all registered. 
ControlNetUiGroup.a1111_context.ui_initialized and all_unit_count == len(ControlNetUiGroup.all_ui_groups) and all( g.ui_initialized and (not g.callbacks_registered) for g in ControlNetUiGroup.all_ui_groups ) ): for ui_group in ControlNetUiGroup.all_ui_groups: ui_group.register_callbacks() ControlNetUiGroup.register_input_mode_sync( [g for g in ControlNetUiGroup.all_ui_groups if g.is_img2img] ) ControlNetUiGroup.register_input_mode_sync( [g for g in ControlNetUiGroup.all_ui_groups if not g.is_img2img] ) logger.info("ControlNet UI callback registered.") @staticmethod def on_after_component(component, **_kwargs): """Register the A1111 component.""" if getattr(component, "elem_id", None) == "img2img_batch_inpaint_mask_dir": ControlNetUiGroup.global_batch_input_dir.render() return ControlNetUiGroup.a1111_context.set_component(component) ControlNetUiGroup.try_register_all_callbacks()