import io
import json
import math
from typing import List, Optional

import gradio as gr
import matplotlib.pyplot as plt  # noqa: F401  (no longer used here; kept so the module's imports are unchanged)
from matplotlib.figure import Figure
from matplotlib.patches import Patch
from PIL import Image

from scripts.enums import StableDiffusionVersion
from scripts.global_state import get_sd_version
from scripts.ipadapter.weight import calc_weights


INPUT_BLOCK_COLOR = "#61bdee"
MIDDLE_BLOCK_COLOR = "#e2e2e2"
OUTPUT_BLOCK_COLOR = "#dc6e55"

# Weight types that also take the composition weight; the composition slider
# is only shown while one of these is selected.
_COMPOSITION_WEIGHT_TYPES = ("style and composition", "strong style and composition")

# Control types for which the advanced weighting controls are enabled.
_ADVANCED_WEIGHTING_CONTROL_TYPES = ("IP-Adapter", "Instant-ID")


def _advanced_weighting_supported(control_type: str) -> bool:
    """Return True if the advanced weighting UI applies to ``control_type``."""
    return control_type in _ADVANCED_WEIGHTING_CONTROL_TYPES


def _is_finite_number(value) -> bool:
    """Return True for a finite int/float (bools are rejected explicitly)."""
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return False
    try:
        return math.isfinite(value)
    except OverflowError:
        # Integers too large to convert to float.
        return False


def _parse_weights(weights_string: Optional[str]) -> Optional[List[float]]:
    """Parse the weights textbox content.

    Returns the parsed list when ``weights_string`` is a JSON array whose
    items are all finite numbers, otherwise None. Explicit checks are used
    instead of ``assert`` so validation is not stripped under ``python -O``.
    """
    if not weights_string:
        return None
    try:
        weights = json.loads(weights_string)
    except (TypeError, ValueError):  # json.JSONDecodeError is a ValueError
        return None
    if not isinstance(weights, list):
        return None
    if not all(_is_finite_number(w) for w in weights):
        return None
    return weights


def get_bar_colors(
    sd_version: StableDiffusionVersion, input_color, middle_color, output_color
) -> list:
    """Return one bar color per transformer block of ``sd_version``.

    Indices before the middle block get ``input_color``, the middle block
    index gets ``middle_color`` and every later index gets ``output_color``.
    The middle block is at index 4 for SDXL and at index 6 otherwise.
    """
    middle_block_idx = 4 if sd_version == StableDiffusionVersion.SDXL else 6

    def get_color(idx: int):
        if idx < middle_block_idx:
            return input_color
        if idx == middle_block_idx:
            return middle_color
        return output_color

    return [get_color(i) for i in range(sd_version.transformer_block_num)]


def plot_weights(
    numbers: List[float],
    colors: List[str],
) -> Image.Image:
    """Render ``numbers`` as a bar chart and return it as a PIL PNG image.

    Bar ``i`` is meant to be drawn with ``colors[i]``. The legend always shows
    the module-level input/middle/output block colors.

    A standalone ``Figure`` is used rather than the pyplot state machine:
    pyplot keeps a global "current figure" that concurrent Gradio callbacks
    could draw onto or close for each other, and pyplot figures leak unless
    explicitly closed. A plain ``Figure`` is simply garbage-collected.
    """
    fig = Figure(figsize=(8, 4))
    ax = fig.subplots()
    ax.bar(range(len(numbers)), numbers, color=colors)
    ax.set_xlabel("Transformer Index")
    ax.set_ylabel("Weight")
    ax.legend(
        handles=[
            Patch(color=color, label=label)
            for color, label in (
                (INPUT_BLOCK_COLOR, "Input Block"),
                (MIDDLE_BLOCK_COLOR, "Middle Block"),
                (OUTPUT_BLOCK_COLOR, "Output Block"),
            )
        ],
        loc="best",
    )

    # Save the plot to an in-memory PNG and hand it back as a PIL image.
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png")
    buffer.seek(0)
    return Image.open(buffer)


class AdvancedWeightControl:
    """Gradio controls for viewing and editing per-transformer-block weights.

    ``render`` builds the (initially hidden) widgets; ``register_callbacks``
    wires them to the unit's weight slider, control type, advanced-weighting
    state and update counter.
    """

    def __init__(self):
        # Widgets are created in ``render``; they stay None until then.
        self.group = None
        self.weight_type = None
        self.weight_plot = None
        self.weight_editor = None
        self.weight_composition = None

    def render(self):
        """Create the hidden group with the weight-type, composition, editor and plot widgets."""
        with gr.Group(visible=False) as self.group:
            with gr.Row():
                self.weight_type = gr.Dropdown(
                    choices=[
                        "normal",
                        "ease in",
                        "ease out",
                        "ease in-out",
                        "reverse in-out",
                        "weak input",
                        "weak output",
                        "weak middle",
                        "strong middle",
                        "style transfer",
                        "composition",
                        "strong style transfer",
                        "style and composition",
                        "strong style and composition",
                    ],
                    label="Weight Type",
                    value="normal",
                )
                self.weight_composition = gr.Slider(
                    label="Composition Weight",
                    minimum=0,
                    maximum=2.0,
                    value=1.0,
                    step=0.01,
                    visible=False,
                )
            self.weight_editor = gr.Textbox(label="Weights", visible=False)
            self.weight_plot = gr.Image(
                value=None,
                label="Weight Plot",
                interactive=False,
                visible=False,
            )

    def register_callbacks(
        self,
        weight_input: gr.Slider,
        advanced_weighting: gr.State,
        control_type: gr.Radio,
        update_unit_counter: gr.Number,
    ):
        """Wire the rendered widgets to the unit's inputs and state.

        Must be called after ``render``.
        """
        # Show the composition slider only for weight types that use it.
        self.weight_type.change(
            fn=lambda weight_type: gr.update(
                visible=weight_type in _COMPOSITION_WEIGHT_TYPES
            ),
            inputs=[self.weight_type],
            outputs=[self.weight_composition],
        )

        def update_weight_textbox(
            control_type_value: str,
            weight_type: str,
            weight: float,
            weight_composition: float,
        ):
            """Recompute the weights and write them to the editor textbox."""
            if not _advanced_weighting_supported(control_type_value):
                return gr.update()

            weights = calc_weights(
                weight_type, weight, get_sd_version(), weight_composition
            )
            # Serialize with json.dumps over builtin floats so the textbox
            # always holds JSON that _parse_weights can read back. str() would
            # not, if calc_weights ever yields non-builtin numbers (e.g. NumPy
            # scalars) — its return type is not visible from here.
            text = json.dumps([round(float(w), 2) for w in weights])
            return gr.update(value=text, visible=True)

        for trigger_input in (self.weight_type, weight_input, self.weight_composition):
            trigger_input.change(
                fn=update_weight_textbox,
                inputs=[
                    control_type,
                    self.weight_type,
                    weight_input,
                    self.weight_composition,
                ],
                outputs=[self.weight_editor],
            )

        def update_plot(weights_string: str):
            """Redraw the weight plot, or hide it if the text is not a valid weight list."""
            weights = _parse_weights(weights_string)
            if weights is None:
                return gr.update(visible=False)

            colors = get_bar_colors(
                get_sd_version(),
                input_color=INPUT_BLOCK_COLOR,
                middle_color=MIDDLE_BLOCK_COLOR,
                output_color=OUTPUT_BLOCK_COLOR,
            )
            return gr.update(value=plot_weights(weights, colors), visible=True)

        def update_advanced_weighting(weights_string: str):
            """Return the parsed weights for the state, or None if the text is invalid."""
            return _parse_weights(weights_string)

        self.weight_editor.change(
            fn=update_plot,
            inputs=[self.weight_editor],
            outputs=[self.weight_plot],
        )
        self.weight_editor.change(
            fn=update_advanced_weighting,
            inputs=[self.weight_editor],
            outputs=[advanced_weighting],
        ).then(
            # Necessary to flush gr.State change to unit state.
            fn=lambda x: gr.update(value=x + 1),
            inputs=[update_unit_counter],
            outputs=[update_unit_counter],
        )

        # TODO: Expose advanced weighting control for other control types.
        def control_type_change(control_type_value: str, old_weights):
            """Show or hide the controls for the new control type.

            Keeps the existing weights for supported types; otherwise clears
            them and hides the editor and plot.
            """
            if _advanced_weighting_supported(control_type_value):
                return (
                    gr.update(visible=True),
                    old_weights,
                    gr.update(),
                    gr.update(),
                )
            return (
                gr.update(visible=False),
                None,
                gr.update(visible=False),
                gr.update(visible=False),
            )

        control_type.change(
            fn=control_type_change,
            inputs=[control_type, advanced_weighting],
            outputs=[
                self.group,
                advanced_weighting,
                self.weight_editor,
                self.weight_plot,
            ],
        )