import itertools
import torch
from typing import Union, Dict

from .ipadapter_model import ImageEmbed, IPAdapterModel
from ..enums import StableDiffusionVersion, TransformerID
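
# This module injects IP-Adapter image embeddings into a Stable Diffusion UNet by
# monkey-patching its cross-attention (attn2) modules: each block's original forward
# is saved in `all_hacks`, replaced with `attn_forward_hacked`, and every registered
# hack adds an extra image-conditioned attention term to the block's output.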


def get_block(model, flag):
    # Select the UNet block list that matches the requested block type.
    return {
        "input": model.input_blocks,
        "middle": [model.middle_block],
        "output": model.output_blocks,
    }[flag]


def attn_forward_hacked(self, x, context=None, **kwargs):
    # Replacement forward for a CrossAttention block: run the usual text
    # cross-attention, then add the output of every registered IP-Adapter hack.
    batch_size, sequence_length, inner_dim = x.shape
    h = self.heads
    head_dim = inner_dim // h

    if context is None:
        context = x

    q = self.to_q(x)
    k = self.to_k(context)
    v = self.to_v(context)
    del context

    q, k, v = map(
        lambda t: t.view(batch_size, -1, h, head_dim).transpose(1, 2),
        (q, k, v),
    )

    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
    )
    out = out.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
    del k, v

    # Each hack receives the hidden states and query and returns an additive
    # image-conditioned attention term.
    for f in self.ipadapter_hacks:
        out = out + f(self, x, q)

    del q, x
    return self.to_out(out)


# Maps each patched attention block to its original forward so it can be restored.
all_hacks = {}
current_model = None


def hack_blk(block, function, type):
    # Register an IP-Adapter hack on a block; the first registration swaps in
    # attn_forward_hacked as the block's forward.
    if not hasattr(block, "ipadapter_hacks"):
        block.ipadapter_hacks = []

    if len(block.ipadapter_hacks) == 0:
        all_hacks[block] = block.forward
        block.forward = attn_forward_hacked.__get__(block, type)

    block.ipadapter_hacks.append(function)
    return


def set_model_attn2_replace(
    model,
    target_cls,
    function,
    transformer_id: TransformerID,
):
    # Locate the attn2 module addressed by transformer_id and attach the hack.
    block = get_block(model, transformer_id.block_type.value)
    module = (
        block[transformer_id.block_id][1]
        .transformer_blocks[transformer_id.block_index]
        .attn2
    )
    hack_blk(module, function, target_cls)


def clear_all_ip_adapter():
    # Restore every patched block's original forward and drop all hacks.
    global all_hacks, current_model
    for k, v in all_hacks.items():
        k.forward = v
        k.ipadapter_hacks = []
    all_hacks = {}
    current_model = None
    return


class PlugableIPAdapter(torch.nn.Module):
    def __init__(self, ipadapter: IPAdapterModel):
        super().__init__()
        self.ipadapter = ipadapter
        self.disable_memory_management = True
        self.dtype = None
        # Either a single global weight or a per-transformer-index weight map.
        self.weight: Union[float, Dict[int, float]] = 1.0
        self.cache = None
        # Fraction of the sampling schedule during which the adapter is active.
        self.p_start = 0.0
        self.p_end = 1.0

    def reset(self):
        self.cache = {}

    @torch.no_grad()
    def hook(
        self, model, preprocessor_outputs, weight, start, end, dtype=torch.float32
    ):
        global current_model
        current_model = model

        self.p_start = start
        self.p_end = end
        self.cache = {}
        self.weight = weight
        device = torch.device("cpu")
        self.dtype = dtype

        # Compute the (averaged) image embedding once for all preprocessor outputs.
        self.ipadapter.to(device, dtype=self.dtype)
        if not isinstance(preprocessor_outputs, (list, tuple)):
            preprocessor_outputs = [preprocessor_outputs]
        self.image_emb = ImageEmbed.average_of(
            *[self.ipadapter.get_image_emb(o) for o in preprocessor_outputs]
        )

        # SDXL and SD1.x use different CrossAttention implementations.
        if self.ipadapter.is_sdxl:
            sd_version = StableDiffusionVersion.SDXL
            from sgm.modules.attention import CrossAttention
        else:
            sd_version = StableDiffusionVersion.SD1x
            from ldm.modules.attention import CrossAttention

        # Patch every attn2 module in the input, output, and middle blocks.
        input_ids, output_ids, middle_ids = sd_version.transformer_ids
        for i, transformer_id in enumerate(
            itertools.chain(input_ids, output_ids, middle_ids)
        ):
            set_model_attn2_replace(
                model,
                CrossAttention,
                self.patch_forward(i, transformer_id.transformer_index),
                transformer_id,
            )

    def weight_on_transformer(self, transformer_index: int) -> float:
        # Per-transformer weight lookup; a plain float applies to all transformers.
        if isinstance(self.weight, dict):
            return self.weight.get(transformer_index, 0.0)
        else:
            assert isinstance(self.weight, (float, int))
            return self.weight

    def call_ip(self, key: str, feat, device):
        # Project the image embedding through the adapter's to_k/to_v layer for
        # this key, caching the result.
        if key in self.cache:
            return self.cache[key]
        else:
            ip = self.ipadapter.ip_layers.to_kvs[key](feat).to(device)
            self.cache[key] = ip
            return ip

    @torch.no_grad()
    def patch_forward(self, number: int, transformer_index: int):
        @torch.no_grad()
        def forward(attn_blk, x, q):
            batch_size, sequence_length, inner_dim = x.shape
            h = attn_blk.heads
            head_dim = inner_dim // h
            weight = self.weight_on_transformer(transformer_index)

            # Skip the adapter outside its active sampling range or at zero weight.
            current_sampling_percent = getattr(
                current_model, "current_sampling_percent", 0.5
            )
            if (
                current_sampling_percent < self.p_start
                or current_sampling_percent > self.p_end
                or weight == 0.0
            ):
                return 0.0

            k_key = f"{number * 2 + 1}_to_k_ip"
            v_key = f"{number * 2 + 1}_to_v_ip"

            # Build image-conditioned key/value tensors from the cached embedding.
            cond_uncond_image_emb = self.image_emb.eval(current_model.cond_mark)
            ip_k = self.call_ip(k_key, cond_uncond_image_emb, device=q.device)
            ip_v = self.call_ip(v_key, cond_uncond_image_emb, device=q.device)
            ip_k, ip_v = map(
                lambda t: t.view(batch_size, -1, h, head_dim).transpose(1, 2),
                (ip_k, ip_v),
            )
            assert ip_k.dtype == ip_v.dtype

            # On macOS, q can be float16 instead of float32.
            # https://github.com/Mikubill/sd-webui-controlnet/issues/2208
            if q.dtype != ip_k.dtype:
                ip_k = ip_k.to(dtype=q.dtype)
                ip_v = ip_v.to(dtype=q.dtype)

            # Attend from the UNet query to the image tokens and scale by weight.
            ip_out = torch.nn.functional.scaled_dot_product_attention(
                q, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
            return ip_out * weight

        return forward
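

# Illustrative usage sketch: `loaded_ip_adapter`, `unet`, and `clip_vision_output`
# are hypothetical placeholders for objects produced elsewhere in the extension.
#
#   adapter = PlugableIPAdapter(loaded_ip_adapter)   # IPAdapterModel instance
#   adapter.hook(
#       unet,                    # patched in place via set_model_attn2_replace
#       clip_vision_output,      # preprocessor output(s); a single item or a list
#       weight=0.8,              # global weight, or a {transformer_index: float} dict
#       start=0.0,               # active fraction of the sampling schedule
#       end=1.0,
#       dtype=torch.float16,
#   )
#   ...                          # run sampling with the patched UNet
#   clear_all_ip_adapter()       # restore the original attention forwards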