import itertools
import torch
from typing import Union, Dict

from .ipadapter_model import ImageEmbed, IPAdapterModel
from ..enums import StableDiffusionVersion, TransformerID


def get_block(model, flag):
    # Map a block-type flag to the corresponding group of UNet blocks.
    return {
        "input": model.input_blocks,
        "middle": [model.middle_block],
        "output": model.output_blocks,
    }[flag]


def attn_forward_hacked(self, x, context=None, **kwargs):
    # Replacement forward for a CrossAttention block: run the normal
    # attention, then add the output of every registered IP-Adapter hook.
    batch_size, sequence_length, inner_dim = x.shape
    h = self.heads
    head_dim = inner_dim // h

    if context is None:
        context = x

    q = self.to_q(x)
    k = self.to_k(context)
    v = self.to_v(context)
    del context

    q, k, v = map(
        lambda t: t.view(batch_size, -1, h, head_dim).transpose(1, 2),
        (q, k, v),
    )

    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
    )
    out = out.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
    del k, v

    # Each hook computes an extra attention result against the image
    # embedding; the results are summed onto the text-attention output.
    for f in self.ipadapter_hacks:
        out = out + f(self, x, q)
    del q, x

    return self.to_out(out)


# Original forward methods of patched attention blocks, keyed by block,
# so they can be restored by clear_all_ip_adapter().
all_hacks = {}
current_model = None


def hack_blk(block, function, type):
    if not hasattr(block, "ipadapter_hacks"):
        block.ipadapter_hacks = []

    if len(block.ipadapter_hacks) == 0:
        # First hook on this block: remember the original forward and
        # rebind attn_forward_hacked as the block's forward method.
        all_hacks[block] = block.forward
        block.forward = attn_forward_hacked.__get__(block, type)

    block.ipadapter_hacks.append(function)
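

# Illustrative sketch (not part of the module): how `__get__` rebinds a plain
# function as a bound method, which is the mechanism hack_blk relies on.
# The `_Toy` class and `_toy_forward` function are hypothetical names used
# only for this example; the function below is defined but never called.
def _hacked_forward_binding_example():
    class _Toy:
        def forward(self):
            return "original"

    def _toy_forward(self):
        # `self` is the _Toy instance, exactly as in attn_forward_hacked.
        return "hacked"

    toy = _Toy()
    original = toy.forward  # keep a reference, like all_hacks[block]
    toy.forward = _toy_forward.__get__(toy, _Toy)
    assert toy.forward() == "hacked"
    toy.forward = original  # restore, like clear_all_ip_adapter()
    assert toy.forward() == "original"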


def set_model_attn2_replace(
    model,
    target_cls,
    function,
    transformer_id: TransformerID,
):
    block = get_block(model, transformer_id.block_type.value)
    module = (
        block[transformer_id.block_id][1]
        .transformer_blocks[transformer_id.block_index]
        .attn2
    )
    hack_blk(module, function, target_cls)


def clear_all_ip_adapter():
    # Restore every patched attention block to its original forward
    # and drop all registered IP-Adapter hooks.
    global all_hacks, current_model
    for k, v in all_hacks.items():
        k.forward = v
        k.ipadapter_hacks = []
    all_hacks = {}
    current_model = None


class PlugableIPAdapter(torch.nn.Module):
    def __init__(self, ipadapter: IPAdapterModel):
        super().__init__()
        self.ipadapter = ipadapter
        self.disable_memory_management = True
        self.dtype = None
        self.weight: Union[float, Dict[int, float]] = 1.0
        self.cache = None
        self.p_start = 0.0
        self.p_end = 1.0

    def reset(self):
        self.cache = {}

    def hook(
        self, model, preprocessor_outputs, weight, start, end, dtype=torch.float32
    ):
        global current_model
        current_model = model

        self.p_start = start
        self.p_end = end
        self.cache = {}
        self.weight = weight
        device = torch.device("cpu")
        self.dtype = dtype
        self.ipadapter.to(device, dtype=self.dtype)

        # Accept either a single preprocessor output or a list of them;
        # multiple outputs are averaged into one image embedding.
        if not isinstance(preprocessor_outputs, (list, tuple)):
            preprocessor_outputs = [preprocessor_outputs]
        self.image_emb = ImageEmbed.average_of(
            *[self.ipadapter.get_image_emb(o) for o in preprocessor_outputs]
        )

        if self.ipadapter.is_sdxl:
            sd_version = StableDiffusionVersion.SDXL
            from sgm.modules.attention import CrossAttention
        else:
            sd_version = StableDiffusionVersion.SD1x
            from ldm.modules.attention import CrossAttention

        # Patch attn2 of every cross-attention transformer block in the UNet.
        input_ids, output_ids, middle_ids = sd_version.transformer_ids
        for i, transformer_id in enumerate(
            itertools.chain(input_ids, output_ids, middle_ids)
        ):
            set_model_attn2_replace(
                model,
                CrossAttention,
                self.patch_forward(i, transformer_id.transformer_index),
                transformer_id,
            )

    def weight_on_transformer(self, transformer_index: int) -> float:
        if isinstance(self.weight, dict):
            # Per-transformer weights; unknown indices default to 0.0.
            return self.weight.get(transformer_index, 0.0)
        else:
            assert isinstance(self.weight, (float, int))
            return self.weight

    def call_ip(self, key: str, feat, device):
        # Cache the projected image key/value tensors per layer key, since
        # the image embedding does not change during sampling.
        if key in self.cache:
            return self.cache[key]
        ip = self.ipadapter.ip_layers.to_kvs[key](feat).to(device)
        self.cache[key] = ip
        return ip

    def patch_forward(self, number: int, transformer_index: int):
        def forward(attn_blk, x, q):
            batch_size, sequence_length, inner_dim = x.shape
            h = attn_blk.heads
            head_dim = inner_dim // h
            weight = self.weight_on_transformer(transformer_index)

            # Skip the IP-Adapter contribution outside the configured
            # sampling-step range or when this transformer has zero weight.
            current_sampling_percent = getattr(
                current_model, "current_sampling_percent", 0.5
            )
            if (
                current_sampling_percent < self.p_start
                or current_sampling_percent > self.p_end
                or weight == 0.0
            ):
                return 0.0

            # Keys into the IP-Adapter's to_kvs projection layers for this
            # attention block.
            k_key = f"{number * 2 + 1}_to_k_ip"
            v_key = f"{number * 2 + 1}_to_v_ip"
            cond_uncond_image_emb = self.image_emb.eval(current_model.cond_mark)
            ip_k = self.call_ip(k_key, cond_uncond_image_emb, device=q.device)
            ip_v = self.call_ip(v_key, cond_uncond_image_emb, device=q.device)

            ip_k, ip_v = map(
                lambda t: t.view(batch_size, -1, h, head_dim).transpose(1, 2),
                (ip_k, ip_v),
            )
            assert ip_k.dtype == ip_v.dtype

            # On macOS, q can be float16 instead of float32.
            # https://github.com/Mikubill/sd-webui-controlnet/issues/2208
            if q.dtype != ip_k.dtype:
                ip_k = ip_k.to(dtype=q.dtype)
                ip_v = ip_v.to(dtype=q.dtype)

            # Attend with the text-attention queries over the image
            # embedding's keys/values, scaled by the per-transformer weight.
            ip_out = torch.nn.functional.scaled_dot_product_attention(
                q, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
            return ip_out * weight

        return forward
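

# Illustrative usage sketch (not executed): the expected lifecycle of a
# PlugableIPAdapter. The names `unet_model`, `loaded_ipadapter`, and
# `clip_vision_output` are hypothetical placeholders; in the actual extension
# they come from the loaded UNet, the IP-Adapter checkpoint loader, and the
# CLIP-Vision preprocessor respectively.
def _usage_sketch(unet_model, loaded_ipadapter, clip_vision_output):
    plugable = PlugableIPAdapter(loaded_ipadapter)
    # Patch every attn2 block; apply the image prompt with weight 0.8
    # across the full sampling range.
    plugable.hook(
        unet_model,
        clip_vision_output,
        weight=0.8,
        start=0.0,
        end=1.0,
        dtype=torch.float16,
    )
    # ... run sampling with the patched model ...
    # Afterwards, restore the original attention forwards.
    clear_all_ip_adapter()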