File size: 5,624 Bytes
ecc4278 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import gradio as gr
import sys
import traceback
from typing import Any
from functools import partial
from modules import script_callbacks, scripts
from ldm_patched.contrib.external_freelunch import FreeU_V2
# Module-level FreeU V2 node instance; its ``patch`` method is what
# FreeUForForge.process_before_every_sampling uses to patch the UNet.
opFreeU_V2 = FreeU_V2()
# def Fourier_filter(x, threshold, scale):
# x_freq = torch.fft.fftn(x.float(), dim=(-2, -1))
# x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))
# B, C, H, W = x_freq.shape
# mask = torch.ones((B, C, H, W), device=x.device)
# crow, ccol = H // 2, W //2
# mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
# x_freq = x_freq * mask
# x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1))
# x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real
# return x_filtered.to(x.dtype)
#
#
# def set_freeu_v2_patch(model, b1, b2, s1, s2):
# model_channels = model.model.model_config.unet_config["model_channels"]
# scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
#
# def output_block_patch(h, hsp, *args, **kwargs):
# scale = scale_dict.get(h.shape[1], None)
# if scale is not None:
# hidden_mean = h.mean(1).unsqueeze(1)
# B = hidden_mean.shape[0]
# hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
# hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
# hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / \
# (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
# h[:, :h.shape[1] // 2] = h[:, :h.shape[1] // 2] * ((scale[0] - 1) * hidden_mean + 1)
# hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
# return h, hsp
#
# m = model.clone()
# m.set_model_output_block_patch(output_block_patch)
# return m
class FreeUForForge(scripts.Script):
    """Always-visible script that applies FreeU V2 to the Forge UNet.

    The UI is an accordion with an enable checkbox and B1/B2/S1/S2 sliders.
    When enabled, the UNet stored in ``p.sd_model.forge_objects`` is replaced
    by the result of ``opFreeU_V2.patch`` before each sampling pass.
    """

    sorting_priority = 12

    def title(self):
        """Return the label used for the accordion."""
        return "FreeU Integrated"

    def show(self, is_img2img):
        """Make the script visible on both the txt2img and img2img tabs."""
        return scripts.AlwaysVisible

    def ui(self, *args, **kwargs):
        """Create the controls; their order defines the script-args order."""
        with gr.Accordion(open=False, label=self.title()):
            enabled = gr.Checkbox(label='Enabled', value=False)
            b1 = gr.Slider(label='B1', minimum=0, maximum=2, step=0.01, value=1.01)
            b2 = gr.Slider(label='B2', minimum=0, maximum=2, step=0.01, value=1.02)
            s1 = gr.Slider(label='S1', minimum=0, maximum=4, step=0.01, value=0.99)
            s2 = gr.Slider(label='S2', minimum=0, maximum=4, step=0.01, value=0.95)
        return enabled, b1, b2, s1, s2

    def process_before_every_sampling(self, p, *script_args, **kwargs):
        """Patch the UNet with FreeU V2 ahead of a sampling pass.

        Runs before every sampling pass (twice when highres fix is used).
        The UI values in ``script_args`` may be overridden per grid cell by
        entries the X/Y/Z plot stored in ``p._freeu_xyz``. Nothing happens
        when FreeU is disabled. Otherwise the patched UNet replaces the one
        in ``p.sd_model.forge_objects`` and the parameters are written to
        ``p.extra_generation_params`` (shown under the output images; it
        does not influence results).
        """
        enabled, b1, b2, s1, s2 = script_args

        overrides = getattr(p, "_freeu_xyz", {})
        if "freeu_enabled" in overrides:
            # The grid axis for "enabled" is a str axis with "True"/"False".
            enabled = overrides["freeu_enabled"] == "True"
        b1 = overrides.get("freeu_b1", b1)
        b2 = overrides.get("freeu_b2", b2)
        s1 = overrides.get("freeu_s1", s1)
        s2 = overrides.get("freeu_s2", s2)

        if not enabled:
            return

        unet = p.sd_model.forge_objects.unet
        p.sd_model.forge_objects.unet = opFreeU_V2.patch(unet, b1, b2, s1, s2)[0]

        p.extra_generation_params.update({
            "freeu_enabled": enabled,
            "freeu_b1": b1,
            "freeu_b2": b2,
            "freeu_s1": s1,
            "freeu_s2": s2,
        })
def set_value(p, x: Any, xs: Any, *, field: str):
    """Record an X/Y/Z grid override for one FreeU parameter on ``p``.

    Used (bound to a ``field`` via ``functools.partial``) as the apply
    callback of the FreeU axis options. The value ends up in
    ``p._freeu_xyz[field]``, which is where the FreeU script looks for
    per-cell overrides.

    Args:
        p: The processing object for the current grid cell.
        x: The axis value for this cell.
        xs: All values of the axis (unused; part of the callback signature).
        field: Key under which ``x`` is stored, e.g. ``"freeu_b1"``.
    """
    # Rebind to a fresh dict instead of mutating in place: when ``p`` is a
    # shallow copy of another processing object, an existing ``_freeu_xyz``
    # dict is shared with the original, and in-place writes would leak the
    # override into it (and into every other copy).
    p._freeu_xyz = {**getattr(p, "_freeu_xyz", {}), field: x}
def make_axis_on_xyz_grid():
    """Add FreeU axis options to the X/Y/Z plot script, if it is loaded.

    The X/Y/Z script is found in ``scripts.scripts_data`` as the first entry
    whose script class lives in the ``"xyz_grid.py"`` module. If none is
    found, this is a no-op. One axis is built per FreeU parameter, each
    storing its value through ``set_value``. Registration is skipped when any
    existing axis label already starts with ``"FreeU"``, so calling this more
    than once does not add duplicates.
    """
    xyz_grid = next(
        (
            entry.module
            for entry in scripts.scripts_data
            if entry.script_class.__module__ == "xyz_grid.py"
        ),
        None,
    )
    if xyz_grid is None:
        return

    # (label, value type, override key, extra AxisOption keyword arguments)
    specs = [
        ("FreeU Enabled", str, "freeu_enabled", {"choices": lambda: ["True", "False"]}),
        ("FreeU B1", float, "freeu_b1", {}),
        ("FreeU B2", float, "freeu_b2", {}),
        ("FreeU S1", float, "freeu_s1", {}),
        ("FreeU S2", float, "freeu_s2", {}),
    ]
    new_axes = [
        xyz_grid.AxisOption(label, kind, partial(set_value, field=key), **extra)
        for label, kind, key, extra in specs
    ]

    already_registered = any(
        option.label.startswith("FreeU") for option in xyz_grid.axis_options
    )
    if not already_registered:
        xyz_grid.axis_options.extend(new_axes)
def on_before_ui():
    """Install the FreeU X/Y/Z axes, reporting (not raising) any failure.

    Any exception from ``make_axis_on_xyz_grid`` is caught. Its traceback is
    printed to stderr instead of propagating to the caller.
    """
    try:
        make_axis_on_xyz_grid()
    except Exception:
        print(
            "[-] FreeU Integrated: xyz_grid error:\n" + traceback.format_exc(),
            file=sys.stderr,
        )
# Hook ``on_before_ui`` into the WebUI's before-UI callbacks so the FreeU
# axes get added to the X/Y/Z plot script (presumably before the UI is
# constructed, per the hook name).
script_callbacks.on_before_ui(on_before_ui)
|