Spaces:
Paused
Paused
import sys | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import weakref | |
from typing import Union, TYPE_CHECKING, Optional, Tuple | |
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer | |
from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPAttention | |
from toolkit.models.zipper_resampler import ZipperResampler, ZipperModule | |
if TYPE_CHECKING: | |
from toolkit.stable_diffusion_model import StableDiffusion | |
from toolkit.custom_adapter import CustomAdapter | |
class TEAugAdapterCLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper, with an adapter branch.

    On construction this module replaces ``attn_module.forward`` with its own
    ``forward``. The replacement recomputes the wrapped layer's multi-head
    attention. When the owning adapter is active and the custom adapter has
    ``conditional_embeds`` set, it runs a second attention pass:

    * queries come from the text hidden states;
    * keys/values come from the adapter embeds, projected by
      ``k_proj_adapter`` / ``v_proj_adapter``.

    The adapter result is concatenated with the base result along the token
    dimension, passed through ``self.zipper``, and added to the base result
    before ``out_proj``.

    Args:
        attn_module: the attention layer to patch. It must expose
            ``embed_dim``, ``num_heads``, ``head_dim``, ``scale``, ``dropout``
            and ``q_proj`` / ``k_proj`` / ``v_proj`` / ``out_proj``.
        adapter: the owning ``TEAugAdapter``, held via weakref.
        text_tokens: number of text tokens the zipper is sized for. Defaults
            to 77, the previously hard-coded value (presumably the CLIP
            tokenizer length — confirm before changing).
    """

    def __init__(self, attn_module: 'CLIPAttention', adapter: 'TEAugAdapter', text_tokens: int = 77):
        super().__init__()
        # owner adapter and wrapped layer are held via weakref
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.attn_module_ref: weakref.ref = weakref.ref(attn_module)
        embed_dim = attn_module.embed_dim
        self.k_proj_adapter = nn.Linear(embed_dim, embed_dim)
        self.v_proj_adapter = nn.Linear(embed_dim, embed_dim)
        # initialise from the original projections, scaled down (weights x0.01, bias x0.001)
        self.k_proj_adapter.weight.data = attn_module.k_proj.weight.data.clone() * 0.01
        self.v_proj_adapter.weight.data = attn_module.v_proj.weight.data.clone() * 0.01
        self.k_proj_adapter.bias.data = attn_module.k_proj.bias.data.clone() * 0.001
        self.v_proj_adapter.bias.data = attn_module.v_proj.bias.data.clone() * 0.001
        # maps [adapter output ; base output] (2 * text_tokens tokens) to text_tokens tokens
        self.zipper = ZipperModule(
            in_size=embed_dim,
            in_tokens=text_tokens * 2,
            out_size=embed_dim,
            out_tokens=text_tokens,
            hidden_size=embed_dim,
            hidden_tokens=text_tokens,
        )
        # keep the original forward reachable, then route the layer's calls through ours
        self.original_forward = attn_module.forward
        attn_module.forward = self.forward

    def is_active(self) -> bool:
        """Return whether the owning adapter is currently active.

        The owner's ``is_active`` may be a method or a plain attribute/property.
        Reading a method without calling it yields a bound method, which is
        always truthy, so it is called here when callable.
        """
        active = self.adapter_ref().is_active
        if callable(active):
            active = active()
        return bool(active)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
        """Split heads using the wrapped layer's ``num_heads`` / ``head_dim``.

        Views ``tensor`` as (bsz, seq_len, num_heads, head_dim) and transposes
        it to (bsz, num_heads, seq_len, head_dim). ``seq_len`` may be -1.
        """
        attn_module = self.attn_module_ref()
        return tensor.view(bsz, seq_len, attn_module.num_heads, attn_module.head_dim).transpose(1, 2).contiguous()

    def _merge_heads(
        self,
        attn_output: torch.Tensor,
        bsz: int,
        tgt_len: int,
        embed_dim: int,
        name: str,
    ) -> torch.Tensor:
        """Turn a (bsz * num_heads, tgt_len, head_dim) bmm result into (bsz, tgt_len, embed_dim).

        Raises:
            ValueError: if ``attn_output`` does not have the expected 3-D shape;
                ``name`` labels the tensor in the message.
        """
        attn_module = self.attn_module_ref()
        expected = (bsz * attn_module.num_heads, tgt_len, attn_module.head_dim)
        if attn_output.size() != expected:
            raise ValueError(f"`{name}` should be of size {expected}, but is {attn_output.size()}")
        attn_output = attn_output.view(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)
        return attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel.

        Returns:
            ``(attn_output, attn_weights_reshaped)``. The second element is
            (bsz, num_heads, tgt_len, src_len) base attention probabilities
            when ``output_attentions`` is truthy, otherwise ``None``.

        Raises:
            ValueError: on unexpected attention-weight, mask, or output shapes.
        """
        attn_module = self.attn_module_ref()
        bsz, tgt_len, embed_dim = hidden_states.size()
        num_heads = attn_module.num_heads
        proj_shape = (bsz * num_heads, -1, attn_module.head_dim)

        # base self-attention projections, flattened to (bsz * heads, seq, head_dim)
        query_states = attn_module.q_proj(hidden_states) * attn_module.scale
        key_states = self._shape(attn_module.k_proj(hidden_states), -1, bsz).view(*proj_shape)
        value_states = self._shape(attn_module.v_proj(hidden_states), -1, bsz).view(*proj_shape)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
        if attn_weights.size() != (bsz * num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, num_heads, tgt_len, src_len) + causal_attention_mask
            attn_weights = attn_weights.view(bsz * num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # reshape twice so the returned weights stay in the autograd graph
            attn_weights_reshaped = attn_weights.view(bsz, num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=attn_module.dropout, training=self.training)
        attn_output = self._merge_heads(
            torch.bmm(attn_probs, value_states), bsz, tgt_len, embed_dim, "attn_output"
        )

        adapter: 'CustomAdapter' = self.adapter_ref().adapter_ref()
        if self.is_active() and adapter.conditional_embeds is not None:
            if adapter.is_unconditional_run:
                embeds = adapter.unconditional_embeds
            else:
                embeds = adapter.conditional_embeds
            # batch mismatch means a CFG batch: prepend the unconditional embeds
            if embeds.size(0) != bsz:
                embeds = torch.cat([adapter.unconditional_embeds, embeds], dim=0)

            # second attention pass: text queries against adapter-projected embeds (no masks)
            adapter_keys = self._shape(self.k_proj_adapter(embeds), -1, bsz).view(*proj_shape)
            adapter_values = self._shape(self.v_proj_adapter(embeds), -1, bsz).view(*proj_shape)
            adapter_weights = nn.functional.softmax(
                torch.bmm(query_states, adapter_keys.transpose(1, 2)), dim=-1
            )
            adapter_probs = nn.functional.dropout(adapter_weights, p=attn_module.dropout, training=self.training)
            attn_output_adapter = self._merge_heads(
                torch.bmm(adapter_probs, adapter_values), bsz, tgt_len, embed_dim, "attn_output_adapter"
            )
            # fuse adapter and base outputs along the token axis, then add residually
            attn_output_adapter = self.zipper(torch.cat([attn_output_adapter, attn_output], dim=1))
            attn_output = attn_output + attn_output_adapter

        attn_output = attn_module.out_proj(attn_output)
        return attn_output, attn_weights_reshaped
class TEAugAdapter(torch.nn.Module):
    """Image-conditioned augmenter for a single CLIP text encoder.

    Builds ``image_proj_model``, a ``ZipperModule`` that maps vision-encoder
    features to the text encoder's hidden size. It also wraps every
    self-attention layer of the text encoder in a ``TEAugAdapterCLIPAttention``
    (stored in ``adapter_modules``).

    Args:
        adapter: the owning ``CustomAdapter``, held via weakref. Its
            ``vision_encoder`` and ``config`` size the projection.
        sd: the ``StableDiffusion`` wrapper whose ``text_encoder`` is
            augmented, held via weakref.

    Raises:
        ValueError: if ``sd.text_encoder`` is a list (dual text encoders).
    """

    def __init__(
        self,
        adapter: 'CustomAdapter',
        sd: 'StableDiffusion',
    ):
        super().__init__()
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.sd_ref: weakref.ref = weakref.ref(sd)

        if isinstance(sd.text_encoder, list):
            raise ValueError("Dual text encoders is not yet supported")

        # projection width comes from the text encoder (not the unet cross-attention dim)
        te_model: CLIPTextModel = sd.text_encoder
        te_width = te_model.config.hidden_size
        te_encoder: CLIPEncoder = te_model.text_model.encoder

        vision_encoder = adapter.vision_encoder
        vision_config = vision_encoder.config
        image_width = (
            vision_config.hidden_sizes[-1]
            if hasattr(vision_config, 'hidden_sizes')
            else vision_config.hidden_size
        )

        image_state = vision_encoder.state_dict()
        pos_key = "vision_model.embeddings.position_embedding.weight"
        if pos_key in image_state:
            # CLIP vision tower: one input token per position embedding row
            image_tokens = int(image_state[pos_key].shape[0])
        else:
            # fallback of 257 tokens (presumably 16x16 patches + 1 CLS token)
            image_tokens = 257
        if adapter.config.image_encoder_arch.startswith('convnext'):
            image_tokens = 16 * 16
            image_width = vision_config.hidden_sizes[-1]

        requested_tokens = adapter.config.num_tokens
        proj_tokens = requested_tokens if requested_tokens > 0 else image_tokens

        self.image_proj_model = ZipperModule(
            in_size=image_width,
            in_tokens=image_tokens,
            out_size=te_width,
            out_tokens=proj_tokens,
            hidden_size=te_width,
            hidden_tokens=proj_tokens,
        )

        # wrap every text-encoder attention layer, in layer order
        self.adapter_modules = torch.nn.ModuleList(
            [TEAugAdapterCLIPAttention(layer.self_attn, self) for layer in te_encoder.layers]
        )

    def is_active(self):
        """Return the owning ``CustomAdapter``'s ``is_active`` value unchanged.

        NOTE(review): this is a plain method, not a property; code that reads
        ``.is_active`` without calling it gets a bound method (always truthy).
        """
        owner = self.adapter_ref()
        return owner.is_active

    def forward(self, input):
        """Project image-encoder features through ``image_proj_model``."""
        projected = self.image_proj_model(input)
        return projected