Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
from typing import List, NamedTuple, Tuple, Union | |
import torch | |
import torch.nn as nn | |
import numpy as np | |
from transformers.models.clip.modeling_clip import CLIPVisionModelOutput | |
from .image_proj_models import ( | |
Resampler, | |
ImageProjModel, | |
MLPProjModel, | |
MLPProjModelFaceId, | |
ProjModelFaceIdPlus, | |
) | |
class ImageEmbed(NamedTuple):
    """Pair of conditional / unconditional image embeddings for one image."""

    # Embedding used where the condition mark is 1.
    cond_emb: torch.Tensor
    # Embedding used where the condition mark is 0.
    uncond_emb: torch.Tensor

    def eval(self, cond_mark: torch.Tensor) -> torch.Tensor:
        """Blend the two embeddings according to ``cond_mark``.

        ``cond_mark`` must be 4-D; only index 0 of its last axis is used.
        Both embeddings must be 3-D.  The result is
        ``cond_emb * mark + uncond_emb * (1 - mark)``, computed on the
        device and dtype of ``cond_emb``.

        Raises:
            AssertionError: if the rank or leading-dimension checks fail.
        """
        assert cond_mark.ndim == 4
        assert self.cond_emb.ndim == self.uncond_emb.ndim == 3

        cond_batch = self.cond_emb.shape[0]
        uncond_batch = self.uncond_emb.shape[0]
        # Leading dims must either match or be 1 so broadcasting works.
        assert uncond_batch in (1, cond_batch)
        assert cond_batch in (1, cond_mark.shape[0])

        # Drop the last axis and align device/dtype with the cond embedding.
        mark = cond_mark[:, :, :, 0].to(self.cond_emb)
        target = {"device": mark.device, "dtype": mark.dtype}

        cond_part = self.cond_emb.to(**target) * mark
        uncond_part = self.uncond_emb.to(**target) * (1 - mark)
        return cond_part + uncond_part

    def average_of(*args: List[Tuple[torch.Tensor, torch.Tensor]]) -> "ImageEmbed":
        """Return the element-wise mean of several ``(cond, uncond)`` pairs.

        There is no ``self`` parameter: call this on the class, e.g.
        ``ImageEmbed.average_of(a, b)``.  Every positional argument is
        unpacked as a 2-item ``(cond, uncond)`` pair.
        """
        cond_group, uncond_group = zip(*args)
        # Each group holds one tensor per input pair.
        count = len(args)

        mean_cond = torch.stack(cond_group).sum(dim=0) / count
        mean_uncond = torch.stack(uncond_group).sum(dim=0) / count
        return ImageEmbed(mean_cond, mean_uncond)
class To_KV(torch.nn.Module):
    """Container of bias-free linear layers built from a weight state dict.

    Each ``state_dict`` entry becomes one ``nn.Linear`` stored in
    ``self.to_kvs``.  The entry key is turned into a valid module name by
    dropping ``".weight"`` and replacing every ``"."`` with ``"_"``.
    """

    def __init__(self, state_dict):
        super().__init__()
        self.to_kvs = nn.ModuleDict()
        for name, weight in state_dict.items():
            # Weight layout is (out_features, in_features).
            in_features = weight.shape[1]
            out_features = weight.shape[0]
            layer = nn.Linear(in_features, out_features, bias=False)
            # Swap the freshly initialised weight for the stored tensor.
            layer.weight.data = weight
            module_key = name.replace(".weight", "").replace(".", "_")
            self.to_kvs[module_key] = layer
class IPAdapterModel(torch.nn.Module):
    """IP-Adapter wrapper: an image projection model plus per-layer K/V weights.

    The concrete projection model is chosen from the boolean flags passed to
    ``__init__``; its weights are loaded from ``state_dict["image_proj"]`` and
    the attention K/V projections from ``state_dict["ip_adapter"]``.
    Use :meth:`load` to derive the flags from a checkpoint automatically.
    """

    def __init__(
        self,
        state_dict,
        clip_embeddings_dim,
        cross_attention_dim,
        is_plus,
        is_sdxl: bool,
        sdxl_plus,
        is_full,
        is_faceid: bool,
        is_portrait: bool,
        is_instantid: bool,
        is_v2: bool,
    ):
        """Build the projection model and K/V layers from ``state_dict``.

        Args:
            state_dict: mapping with ``"image_proj"`` and ``"ip_adapter"``
                sub-dicts of weights.
            clip_embeddings_dim: input width of the CLIP embedding (may be
                ``None`` for variants that do not consume CLIP embeddings).
            cross_attention_dim: output width of the projected tokens.
            is_plus / is_full / sdxl_plus / is_faceid / is_instantid:
                select which projection architecture is instantiated.
            is_sdxl: stored on the instance; not otherwise used here.
            is_portrait: only affects the default number of context tokens.
            is_v2: forwarded as ``shortcut`` to the FaceID-plus projector.
        """
        super().__init__()
        # Target device for inputs and the projection model; callers may
        # reassign this attribute before requesting embeddings.
        self.device = "cpu"

        self.clip_embeddings_dim = clip_embeddings_dim
        self.cross_attention_dim = cross_attention_dim
        self.is_plus = is_plus
        self.is_sdxl = is_sdxl
        self.sdxl_plus = sdxl_plus
        self.is_full = is_full
        self.is_v2 = is_v2
        self.is_faceid = is_faceid
        self.is_instantid = is_instantid
        self.clip_extra_context_tokens = 16 if (self.is_plus or is_portrait) else 4

        if is_instantid:
            self.image_proj_model = self.init_proj_instantid()
        elif is_faceid:
            self.image_proj_model = self.init_proj_faceid()
        elif self.is_plus:
            if self.is_full:
                self.image_proj_model = MLPProjModel(
                    cross_attention_dim=cross_attention_dim,
                    clip_embeddings_dim=clip_embeddings_dim,
                )
            else:
                self.image_proj_model = Resampler(
                    dim=1280 if sdxl_plus else cross_attention_dim,
                    depth=4,
                    dim_head=64,
                    heads=20 if sdxl_plus else 12,
                    num_queries=self.clip_extra_context_tokens,
                    embedding_dim=clip_embeddings_dim,
                    output_dim=self.cross_attention_dim,
                    ff_mult=4,
                )
        else:
            # Plain adapter: the token count is implied by the projection
            # weight's row count, overriding the default chosen above.
            self.clip_extra_context_tokens = (
                state_dict["image_proj"]["proj.weight"].shape[0]
                // self.cross_attention_dim
            )
            self.image_proj_model = ImageProjModel(
                cross_attention_dim=self.cross_attention_dim,
                clip_embeddings_dim=clip_embeddings_dim,
                clip_extra_context_tokens=self.clip_extra_context_tokens,
            )

        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        self.ip_layers = To_KV(state_dict["ip_adapter"])

    def init_proj_faceid(self):
        """Create the (unloaded) projection model for FaceID variants."""
        if self.is_plus:
            image_proj_model = ProjModelFaceIdPlus(
                cross_attention_dim=self.cross_attention_dim,
                id_embeddings_dim=512,
                clip_embeddings_dim=self.clip_embeddings_dim,
                num_tokens=4,
            )
        else:
            image_proj_model = MLPProjModelFaceId(
                cross_attention_dim=self.cross_attention_dim,
                id_embeddings_dim=512,
                num_tokens=self.clip_extra_context_tokens,
            )
        return image_proj_model

    def init_proj_instantid(self, image_emb_dim=512, num_tokens=16):
        """Create the (unloaded) ``Resampler`` used for InstantID."""
        image_proj_model = Resampler(
            dim=1280,
            depth=4,
            dim_head=64,
            heads=20,
            num_queries=num_tokens,
            embedding_dim=image_emb_dim,
            output_dim=self.cross_attention_dim,
            ff_mult=4,
        )
        return image_proj_model

    def _get_image_embeds(
        self, clip_vision_output: CLIPVisionModelOutput
    ) -> ImageEmbed:
        """Project CLIP vision output into cond / uncond image embeddings.

        Plus models read ``hidden_states[-2]``; plain models read
        ``image_embeds`` and use an all-zero input for the uncond branch.
        """
        self.image_proj_model.to(self.device)

        if self.is_plus:
            # Imported lazily, as in the rest of this class.
            from annotator.clipvision import clip_vision_h_uc, clip_vision_vith_uc

            cond = self.image_proj_model(
                clip_vision_output["hidden_states"][-2].to(
                    device=self.device, dtype=torch.float32
                )
            )
            # NOTE(review): the sdxl_plus branch uses clip_vision_vith_uc
            # directly, without running it through the projection model --
            # presumably it is already projected; confirm against annotator.
            uncond = (
                clip_vision_vith_uc.to(cond)
                if self.sdxl_plus
                else self.image_proj_model(clip_vision_h_uc.to(cond))
            )
            return ImageEmbed(cond, uncond)

        clip_image_embeds = clip_vision_output["image_embeds"].to(
            device=self.device, dtype=torch.float32
        )
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        # Input zero vector for unconditional.
        uncond_image_prompt_embeds = self.image_proj_model(
            torch.zeros_like(clip_image_embeds)
        )
        return ImageEmbed(image_prompt_embeds, uncond_image_prompt_embeds)

    def _get_image_embeds_faceid_plus(
        self,
        face_embed: torch.Tensor,
        clip_vision_output: CLIPVisionModelOutput,
        is_v2: bool,
    ) -> ImageEmbed:
        """Get image embeds for FaceID-plus (face embedding + CLIP hidden state).

        The uncond branch feeds a zero face embedding together with
        ``clip_vision_h_uc``; ``is_v2`` is passed through as ``shortcut``.
        """
        # Keep the projector on the same device as its inputs, matching the
        # other _get_image_embeds* methods.
        self.image_proj_model.to(self.device)
        face_embed = face_embed.to(self.device, dtype=torch.float32)

        from annotator.clipvision import clip_vision_h_uc

        clip_embed = clip_vision_output["hidden_states"][-2].to(
            device=self.device, dtype=torch.float32
        )
        return ImageEmbed(
            self.image_proj_model(face_embed, clip_embed, shortcut=is_v2),
            self.image_proj_model(
                torch.zeros_like(face_embed),
                clip_vision_h_uc.to(clip_embed),
                shortcut=is_v2,
            ),
        )

    def _get_image_embeds_faceid(self, insightface_output: torch.Tensor) -> ImageEmbed:
        """Get image embeds for non-plus faceid. Multiple inputs are supported."""
        self.image_proj_model.to(self.device)
        faceid_embed = insightface_output.to(self.device, dtype=torch.float32)
        return ImageEmbed(
            self.image_proj_model(faceid_embed),
            self.image_proj_model(torch.zeros_like(faceid_embed)),
        )

    def _get_image_embeds_instantid(
        self, prompt_image_emb: Union[torch.Tensor, np.ndarray]
    ) -> ImageEmbed:
        """Get image embeds for instantid.

        The input (tensor or ndarray) is copied, cast to float32 and reshaped
        to ``[1, -1, 512]`` before projection; the uncond branch uses zeros of
        the same shape.
        """
        # Keep the projector on the same device as its inputs, matching the
        # other _get_image_embeds* methods.
        self.image_proj_model.to(self.device)

        image_proj_model_in_features = 512
        if isinstance(prompt_image_emb, torch.Tensor):
            # Copy so the caller's tensor is never modified or tracked.
            prompt_image_emb = prompt_image_emb.clone().detach()
        else:
            prompt_image_emb = torch.tensor(prompt_image_emb)

        prompt_image_emb = prompt_image_emb.to(device=self.device, dtype=torch.float32)
        prompt_image_emb = prompt_image_emb.reshape(
            [1, -1, image_proj_model_in_features]
        )
        return ImageEmbed(
            self.image_proj_model(prompt_image_emb),
            self.image_proj_model(torch.zeros_like(prompt_image_emb)),
        )

    @staticmethod
    def load(state_dict: dict, model_name: str) -> IPAdapterModel:
        """Construct an ``IPAdapterModel`` by inspecting a checkpoint.

        Arguments:
            - state_dict: model state_dict.
            - model_name: file name of the model.

        Variant flags come from case-sensitive substring checks on
        ``model_name`` (``"v2"``, ``"faceid"``, ``"instant_id"``,
        ``"portrait"``) and from which keys exist in
        ``state_dict["image_proj"]``.

        Raises:
            KeyError: if an expected weight key is missing from ``state_dict``.
        """
        is_v2 = "v2" in model_name
        is_faceid = "faceid" in model_name
        is_instantid = "instant_id" in model_name
        is_portrait = "portrait" in model_name
        is_full = "proj.3.weight" in state_dict["image_proj"]
        is_plus = (
            is_full
            or "latents" in state_dict["image_proj"]
            or "perceiver_resampler.proj_in.weight" in state_dict["image_proj"]
        )
        cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[1]
        # A 2048-wide cross-attention is treated as the SDXL marker.
        sdxl = cross_attention_dim == 2048
        sdxl_plus = sdxl and is_plus

        if is_instantid:
            # InstantID does not use clip embedding.
            clip_embeddings_dim = None
        elif is_faceid:
            if is_plus:
                clip_embeddings_dim = 1280
            else:
                # Plain faceid does not use clip_embeddings_dim.
                clip_embeddings_dim = None
        elif is_plus:
            if sdxl_plus:
                clip_embeddings_dim = int(state_dict["image_proj"]["latents"].shape[2])
            elif is_full:
                clip_embeddings_dim = int(
                    state_dict["image_proj"]["proj.0.weight"].shape[1]
                )
            else:
                clip_embeddings_dim = int(
                    state_dict["image_proj"]["proj_in.weight"].shape[1]
                )
        else:
            clip_embeddings_dim = int(state_dict["image_proj"]["proj.weight"].shape[1])

        return IPAdapterModel(
            state_dict,
            clip_embeddings_dim=clip_embeddings_dim,
            cross_attention_dim=cross_attention_dim,
            is_plus=is_plus,
            is_sdxl=sdxl,
            sdxl_plus=sdxl_plus,
            is_full=is_full,
            is_faceid=is_faceid,
            is_portrait=is_portrait,
            is_instantid=is_instantid,
            is_v2=is_v2,
        )

    def get_image_emb(self, preprocessor_output) -> ImageEmbed:
        """Dispatch ``preprocessor_output`` to the matching embedding routine.

        Priority: InstantID, then FaceID-plus, then plain FaceID, then the
        CLIP-based default.
        """
        if self.is_instantid:
            return self._get_image_embeds_instantid(preprocessor_output)
        elif self.is_faceid and self.is_plus:
            # Note: FaceID plus uses both face_embed and clip_embed.
            # This should be the return value from preprocessor.
            return self._get_image_embeds_faceid_plus(
                preprocessor_output.face_embed,
                preprocessor_output.clip_embed,
                is_v2=self.is_v2,
            )
        elif self.is_faceid:
            return self._get_image_embeds_faceid(preprocessor_output)
        else:
            return self._get_image_embeds(preprocessor_output)