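# Runtime IP-Adapter hooks for Stable Diffusion UNets: the functions below
# swap the `forward` of selected cross-attention (attn2) modules so that
# image-prompt keys/values contribute to the attention output, and restore
# the originals on teardown.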
import itertools
import torch
from typing import Union, Dict
from .ipadapter_model import ImageEmbed, IPAdapterModel
from ..enums import StableDiffusionVersion, TransformerID
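

# Map a block-type flag ("input" / "middle" / "output") to the matching list
# of UNet blocks.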
def get_block(model, flag):
return {
"input": model.input_blocks,
"middle": [model.middle_block],
"output": model.output_blocks,
}[flag]
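

# Replacement `forward` bound onto hooked CrossAttention modules: computes
# the usual scaled-dot-product attention, then lets every registered
# IP-Adapter hook add its own term before the output projection.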
def attn_forward_hacked(self, x, context=None, **kwargs):
batch_size, sequence_length, inner_dim = x.shape
h = self.heads
head_dim = inner_dim // h
if context is None:
context = x
q = self.to_q(x)
k = self.to_k(context)
v = self.to_v(context)
del context
q, k, v = map(
lambda t: t.view(batch_size, -1, h, head_dim).transpose(1, 2),
(q, k, v),
)
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
)
out = out.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
del k, v
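    # Add each registered IP-Adapter hook's image-conditioned attention term
    # (hooks return 0.0 when inactive for the current step).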
for f in self.ipadapter_hacks:
out = out + f(self, x, q)
del q, x
return self.to_out(out)
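

# Original `forward` methods keyed by patched module (so hooks can be
# reverted) and the UNet currently being hooked.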
all_hacks = {}
current_model = None
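

# Install `attn_forward_hacked` on `block` the first time it is hooked and
# register `function` as one of its IP-Adapter hooks.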
def hack_blk(block, function, attn_cls):
    if not hasattr(block, "ipadapter_hacks"):
        block.ipadapter_hacks = []
    if len(block.ipadapter_hacks) == 0:
        all_hacks[block] = block.forward
        block.forward = attn_forward_hacked.__get__(block, attn_cls)
    block.ipadapter_hacks.append(function)
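

# Resolve the attn2 module addressed by `transformer_id` within the UNet and
# hook it with `function`.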
def set_model_attn2_replace(
model,
target_cls,
function,
transformer_id: TransformerID,
):
block = get_block(model, transformer_id.block_type.value)
module = (
block[transformer_id.block_id][1]
.transformer_blocks[transformer_id.block_index]
.attn2
)
hack_blk(module, function, target_cls)
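

# Undo every installed hook, restoring the original `forward` methods.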
def clear_all_ip_adapter():
    global all_hacks, current_model
    for block, original_forward in all_hacks.items():
        block.forward = original_forward
        block.ipadapter_hacks = []
    all_hacks = {}
    current_model = None
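

# nn.Module wrapper around an IPAdapterModel that installs the attention
# hooks on a UNet and evaluates the per-layer image-prompt attention.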
class PlugableIPAdapter(torch.nn.Module):
def __init__(self, ipadapter: IPAdapterModel):
super().__init__()
self.ipadapter = ipadapter
self.disable_memory_management = True
self.dtype = None
self.weight: Union[float, Dict[int, float]] = 1.0
self.cache = None
self.p_start = 0.0
self.p_end = 1.0
def reset(self):
self.cache = {}
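
    # Install the adapter on `model`: store weight and step-range settings,
    # average the image embeddings, then patch every attn2 block.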
@torch.no_grad()
def hook(
self, model, preprocessor_outputs, weight, start, end, dtype=torch.float32
):
global current_model
current_model = model
self.p_start = start
self.p_end = end
self.cache = {}
self.weight = weight
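        # The adapter weights stay on CPU; projected k/v tensors are moved to
        # the sampling device lazily in `call_ip`.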
device = torch.device("cpu")
self.dtype = dtype
self.ipadapter.to(device, dtype=self.dtype)
        if not isinstance(preprocessor_outputs, (list, tuple)):
            preprocessor_outputs = [preprocessor_outputs]
self.image_emb = ImageEmbed.average_of(
*[self.ipadapter.get_image_emb(o) for o in preprocessor_outputs]
)
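        # Choose the CrossAttention class for the checkpoint family: SDXL
        # models use sgm's attention module, SD1.x uses ldm's.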
if self.ipadapter.is_sdxl:
sd_version = StableDiffusionVersion.SDXL
from sgm.modules.attention import CrossAttention
else:
sd_version = StableDiffusionVersion.SD1x
from ldm.modules.attention import CrossAttention
input_ids, output_ids, middle_ids = sd_version.transformer_ids
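        # Patch every attn2 block; `i` enumerates the adapter's per-layer
        # to_k/to_v projections consumed in `patch_forward`.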
for i, transformer_id in enumerate(
itertools.chain(input_ids, output_ids, middle_ids)
):
set_model_attn2_replace(
model,
CrossAttention,
self.patch_forward(i, transformer_id.transformer_index),
transformer_id,
)
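
    # Resolve the effective weight for one transformer block: a dict maps
    # transformer_index -> weight (missing entries mean 0.0), while a scalar
    # applies everywhere.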
def weight_on_transformer(self, transformer_index: int) -> float:
if isinstance(self.weight, dict):
return self.weight.get(transformer_index, 0.0)
else:
assert isinstance(self.weight, (float, int))
return self.weight
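
    # Project the image embedding through one of the adapter's to_k/to_v
    # layers, caching the result per key (e.g. "1_to_k_ip").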
def call_ip(self, key: str, feat, device):
if key in self.cache:
return self.cache[key]
else:
ip = self.ipadapter.ip_layers.to_kvs[key](feat).to(device)
self.cache[key] = ip
return ip
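
    # Build the hook for attn2 block `number`; the returned closure computes
    # the extra image-conditioned attention added by `attn_forward_hacked`.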
@torch.no_grad()
def patch_forward(self, number: int, transformer_index: int):
@torch.no_grad()
def forward(attn_blk, x, q):
batch_size, sequence_length, inner_dim = x.shape
h = attn_blk.heads
head_dim = inner_dim // h
weight = self.weight_on_transformer(transformer_index)
current_sampling_percent = getattr(
current_model, "current_sampling_percent", 0.5
)
if (
current_sampling_percent < self.p_start
or current_sampling_percent > self.p_end
or weight == 0.0
):
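                # Outside the configured step range, or zero weight: add nothing.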
return 0.0
k_key = f"{number * 2 + 1}_to_k_ip"
v_key = f"{number * 2 + 1}_to_v_ip"
cond_uncond_image_emb = self.image_emb.eval(current_model.cond_mark)
ip_k = self.call_ip(k_key, cond_uncond_image_emb, device=q.device)
ip_v = self.call_ip(v_key, cond_uncond_image_emb, device=q.device)
ip_k, ip_v = map(
lambda t: t.view(batch_size, -1, h, head_dim).transpose(1, 2),
(ip_k, ip_v),
)
assert ip_k.dtype == ip_v.dtype
            # On macOS, q can be float16 instead of float32.
            # https://github.com/Mikubill/sd-webui-controlnet/issues/2208
if q.dtype != ip_k.dtype:
ip_k = ip_k.to(dtype=q.dtype)
ip_v = ip_v.to(dtype=q.dtype)
ip_out = torch.nn.functional.scaled_dot_product_attention(
q, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
return ip_out * weight
return forward
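

# A minimal usage sketch (hypothetical variable names; the real call sites
# live in the ControlNet extension, and `unet` / `clip_vision_output` stand
# in for the hooked diffusion model and preprocessor result):
#
#     adapter = PlugableIPAdapter(ip_adapter_model)
#     adapter.hook(unet, clip_vision_output, weight=1.0, start=0.0, end=1.0)
#     ...  # run sampling; hooked attn2 blocks add the image-prompt term
#     clear_all_ip_adapter()  # restore the original forward methods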