|
import gradio as gr |
|
|
|
import sys |
|
import traceback |
|
|
|
from typing import Any |
|
from functools import partial |
|
|
|
from modules import script_callbacks, scripts |
|
from ldm_patched.contrib.external_freelunch import FreeU_V2 |
|
|
|
|
|
# One module-wide FreeU_V2 instance; the script below calls its patch()
# method to obtain a patched UNet.
opFreeU_V2 = FreeU_V2()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FreeUForForge(scripts.Script):
    """Always-visible script that applies the FreeU V2 patch to the UNet.

    The UI offers an enable checkbox and four sliders (B1, B2, S1, S2).
    In ``process_before_every_sampling`` those values -- optionally replaced
    by entries of the ``_freeu_xyz`` mapping stored on the processing
    object -- are passed to the module-level ``opFreeU_V2`` patcher.
    """

    sorting_priority = 12

    def title(self):
        """Return the script name, also used as the accordion label."""
        return "FreeU Integrated"

    def show(self, is_img2img):
        """Return ``scripts.AlwaysVisible`` regardless of ``is_img2img``."""
        return scripts.AlwaysVisible

    def ui(self, *args, **kwargs):
        """Create the controls inside an initially closed accordion.

        Returns:
            A 5-tuple of Gradio components ``(enabled, b1, b2, s1, s2)``,
            the order in which ``process_before_every_sampling`` unpacks
            its ``script_args``.
        """
        # (label, maximum, initial value). Every slider has minimum 0 and
        # step 0.01.
        slider_specs = (
            ('B1', 2, 1.01),
            ('B2', 2, 1.02),
            ('S1', 4, 0.99),
            ('S2', 4, 0.95),
        )

        with gr.Accordion(open=False, label=self.title()):
            enabled_box = gr.Checkbox(label='Enabled', value=False)
            sliders = [
                gr.Slider(
                    label=caption,
                    minimum=0,
                    maximum=upper,
                    step=0.01,
                    value=initial,
                )
                for caption, upper, initial in slider_specs
            ]

        return (enabled_box, *sliders)

    def process_before_every_sampling(self, p, *script_args, **kwargs):
        """Replace the UNet on ``p.sd_model`` with a FreeU-patched one.

        Args:
            p: Processing object. It must expose
                ``sd_model.forge_objects.unet`` and
                ``extra_generation_params``; an optional ``_freeu_xyz``
                mapping supplies per-run overrides.
            *script_args: Exactly five values, in order: enabled, b1, b2,
                s1, s2.
            **kwargs: Accepted but not used.

        An entry in ``p._freeu_xyz`` takes precedence over the matching UI
        value. The ``freeu_enabled`` override is compared against the
        string ``"True"``; any other value disables the patch. When the
        patch is disabled nothing on ``p`` is modified.
        """
        enabled, b1, b2, s1, s2 = script_args

        overrides = getattr(p, "_freeu_xyz", {})
        if "freeu_enabled" in overrides:
            enabled = overrides["freeu_enabled"] == "True"
        b1 = overrides.get("freeu_b1", b1)
        b2 = overrides.get("freeu_b2", b2)
        s1 = overrides.get("freeu_s1", s1)
        s2 = overrides.get("freeu_s2", s2)

        if not enabled:
            return

        forge_objects = p.sd_model.forge_objects
        # patch() returns an indexable result; element 0 becomes the UNet.
        forge_objects.unet = opFreeU_V2.patch(forge_objects.unet, b1, b2, s1, s2)[0]

        # Record the values that were actually applied.
        p.extra_generation_params.update({
            "freeu_enabled": enabled,
            "freeu_b1": b1,
            "freeu_b2": b2,
            "freeu_s1": s1,
            "freeu_s2": s2,
        })
|
|
|
def set_value(p, x: Any, xs: Any, *, field: str):
    """Store ``x`` under ``field`` in the ``_freeu_xyz`` mapping on ``p``.

    Args:
        p: Object that carries the overrides; ``_freeu_xyz`` is created as
            an empty dict on first use.
        x: Value to record.
        xs: Accepted for signature compatibility; not used.
        field: Key to write (keyword-only).
    """
    try:
        overrides = p._freeu_xyz
    except AttributeError:
        overrides = p._freeu_xyz = {}
    overrides[field] = x
|
|
|
def make_axis_on_xyz_grid():
    """Append the FreeU axis options to the loaded ``xyz_grid.py`` script.

    Scans ``scripts.scripts_data`` for the first entry whose script class
    lives in module ``"xyz_grid.py"`` and uses that entry's ``module``.
    Does nothing when no such module is found, or when an axis option whose
    label starts with ``"FreeU"`` is already registered.
    """
    xyz_grid = next(
        (
            entry.module
            for entry in scripts.scripts_data
            if entry.script_class.__module__ == "xyz_grid.py"
        ),
        None,
    )
    if xyz_grid is None:
        return

    # (axis label, key written by set_value) for the float-valued axes.
    float_axes = (
        ("FreeU B1", "freeu_b1"),
        ("FreeU B2", "freeu_b2"),
        ("FreeU S1", "freeu_s1"),
        ("FreeU S2", "freeu_s2"),
    )

    new_options = [
        xyz_grid.AxisOption(
            "FreeU Enabled",
            str,
            partial(set_value, field="freeu_enabled"),
            choices=lambda: ["True", "False"],
        ),
    ]
    new_options += [
        xyz_grid.AxisOption(label, float, partial(set_value, field=key))
        for label, key in float_axes
    ]

    # Leave the list untouched if FreeU options are already present.
    for registered in xyz_grid.axis_options:
        if registered.label.startswith("FreeU"):
            return

    xyz_grid.axis_options.extend(new_options)
|
|
|
def on_before_ui():
    """Run ``make_axis_on_xyz_grid`` and report any failure on stderr.

    Exceptions derived from ``Exception`` are not propagated; the formatted
    traceback is printed instead.
    """
    try:
        make_axis_on_xyz_grid()
    except Exception:
        # Must be captured while the exception is still being handled.
        details = traceback.format_exc()
    else:
        return

    print(f"[-] FreeU Integrated: xyz_grid error:\n{details}", file=sys.stderr)
|
|
|
# Register the hook so the FreeU axis options get added to the XYZ grid
# (presumably before the UI is built, going by the callback's name).
script_callbacks.on_before_ui(on_before_ui)
|
|