bilegentile's picture
Upload folder using huggingface_hub
c19ca42 verified
from __future__ import annotations
from typing import List, NamedTuple, Tuple, Union
import torch
import torch.nn as nn
import numpy as np
from transformers.models.clip.modeling_clip import CLIPVisionModelOutput
from .image_proj_models import (
Resampler,
ImageProjModel,
MLPProjModel,
MLPProjModelFaceId,
ProjModelFaceIdPlus,
)
class ImageEmbed(NamedTuple):
"""Image embed for a single image."""
cond_emb: torch.Tensor
uncond_emb: torch.Tensor
def eval(self, cond_mark: torch.Tensor) -> torch.Tensor:
assert cond_mark.ndim == 4
assert self.cond_emb.ndim == self.uncond_emb.ndim == 3
assert (
self.uncond_emb.shape[0] == 1
or self.cond_emb.shape[0] == self.uncond_emb.shape[0]
)
assert (
self.cond_emb.shape[0] == 1 or self.cond_emb.shape[0] == cond_mark.shape[0]
)
cond_mark = cond_mark[:, :, :, 0].to(self.cond_emb)
device = cond_mark.device
dtype = cond_mark.dtype
return self.cond_emb.to(
device=device, dtype=dtype
) * cond_mark + self.uncond_emb.to(device=device, dtype=dtype) * (1 - cond_mark)
def average_of(*args: List[Tuple[torch.Tensor, torch.Tensor]]) -> "ImageEmbed":
conds, unconds = zip(*args)
def average_tensors(tensors: List[torch.Tensor]) -> torch.Tensor:
return torch.sum(torch.stack(tensors), dim=0) / len(tensors)
return ImageEmbed(average_tensors(conds), average_tensors(unconds))
class To_KV(torch.nn.Module):
def __init__(self, state_dict):
super().__init__()
self.to_kvs = nn.ModuleDict()
for key, value in state_dict.items():
k = key.replace(".weight", "").replace(".", "_")
self.to_kvs[k] = nn.Linear(value.shape[1], value.shape[0], bias=False)
self.to_kvs[k].weight.data = value
class IPAdapterModel(torch.nn.Module):
def __init__(
self,
state_dict,
clip_embeddings_dim,
cross_attention_dim,
is_plus,
is_sdxl: bool,
sdxl_plus,
is_full,
is_faceid: bool,
is_portrait: bool,
is_instantid: bool,
is_v2: bool,
):
super().__init__()
self.device = "cpu"
self.clip_embeddings_dim = clip_embeddings_dim
self.cross_attention_dim = cross_attention_dim
self.is_plus = is_plus
self.is_sdxl = is_sdxl
self.sdxl_plus = sdxl_plus
self.is_full = is_full
self.is_v2 = is_v2
self.is_faceid = is_faceid
self.is_instantid = is_instantid
self.clip_extra_context_tokens = 16 if (self.is_plus or is_portrait) else 4
if is_instantid:
self.image_proj_model = self.init_proj_instantid()
elif is_faceid:
self.image_proj_model = self.init_proj_faceid()
elif self.is_plus:
if self.is_full:
self.image_proj_model = MLPProjModel(
cross_attention_dim=cross_attention_dim,
clip_embeddings_dim=clip_embeddings_dim,
)
else:
self.image_proj_model = Resampler(
dim=1280 if sdxl_plus else cross_attention_dim,
depth=4,
dim_head=64,
heads=20 if sdxl_plus else 12,
num_queries=self.clip_extra_context_tokens,
embedding_dim=clip_embeddings_dim,
output_dim=self.cross_attention_dim,
ff_mult=4,
)
else:
self.clip_extra_context_tokens = (
state_dict["image_proj"]["proj.weight"].shape[0]
// self.cross_attention_dim
)
self.image_proj_model = ImageProjModel(
cross_attention_dim=self.cross_attention_dim,
clip_embeddings_dim=clip_embeddings_dim,
clip_extra_context_tokens=self.clip_extra_context_tokens,
)
self.image_proj_model.load_state_dict(state_dict["image_proj"])
self.ip_layers = To_KV(state_dict["ip_adapter"])
def init_proj_faceid(self):
if self.is_plus:
image_proj_model = ProjModelFaceIdPlus(
cross_attention_dim=self.cross_attention_dim,
id_embeddings_dim=512,
clip_embeddings_dim=self.clip_embeddings_dim,
num_tokens=4,
)
else:
image_proj_model = MLPProjModelFaceId(
cross_attention_dim=self.cross_attention_dim,
id_embeddings_dim=512,
num_tokens=self.clip_extra_context_tokens,
)
return image_proj_model
def init_proj_instantid(self, image_emb_dim=512, num_tokens=16):
image_proj_model = Resampler(
dim=1280,
depth=4,
dim_head=64,
heads=20,
num_queries=num_tokens,
embedding_dim=image_emb_dim,
output_dim=self.cross_attention_dim,
ff_mult=4,
)
return image_proj_model
@torch.inference_mode()
def _get_image_embeds(
self, clip_vision_output: CLIPVisionModelOutput
) -> ImageEmbed:
self.image_proj_model.to(self.device)
if self.is_plus:
from annotator.clipvision import clip_vision_h_uc, clip_vision_vith_uc
cond = self.image_proj_model(
clip_vision_output["hidden_states"][-2].to(
device=self.device, dtype=torch.float32
)
)
uncond = (
clip_vision_vith_uc.to(cond)
if self.sdxl_plus
else self.image_proj_model(clip_vision_h_uc.to(cond))
)
return ImageEmbed(cond, uncond)
clip_image_embeds = clip_vision_output["image_embeds"].to(
device=self.device, dtype=torch.float32
)
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
# input zero vector for unconditional.
uncond_image_prompt_embeds = self.image_proj_model(
torch.zeros_like(clip_image_embeds)
)
return ImageEmbed(image_prompt_embeds, uncond_image_prompt_embeds)
@torch.inference_mode()
def _get_image_embeds_faceid_plus(
self,
face_embed: torch.Tensor,
clip_vision_output: CLIPVisionModelOutput,
is_v2: bool,
) -> ImageEmbed:
face_embed = face_embed.to(self.device, dtype=torch.float32)
from annotator.clipvision import clip_vision_h_uc
clip_embed = clip_vision_output["hidden_states"][-2].to(
device=self.device, dtype=torch.float32
)
return ImageEmbed(
self.image_proj_model(face_embed, clip_embed, shortcut=is_v2),
self.image_proj_model(
torch.zeros_like(face_embed),
clip_vision_h_uc.to(clip_embed),
shortcut=is_v2,
),
)
@torch.inference_mode()
def _get_image_embeds_faceid(self, insightface_output: torch.Tensor) -> ImageEmbed:
"""Get image embeds for non-plus faceid. Multiple inputs are supported."""
self.image_proj_model.to(self.device)
faceid_embed = insightface_output.to(self.device, dtype=torch.float32)
return ImageEmbed(
self.image_proj_model(faceid_embed),
self.image_proj_model(torch.zeros_like(faceid_embed)),
)
@torch.inference_mode()
def _get_image_embeds_instantid(
self, prompt_image_emb: Union[torch.Tensor, np.ndarray]
) -> ImageEmbed:
"""Get image embeds for instantid."""
image_proj_model_in_features = 512
if isinstance(prompt_image_emb, torch.Tensor):
prompt_image_emb = prompt_image_emb.clone().detach()
else:
prompt_image_emb = torch.tensor(prompt_image_emb)
prompt_image_emb = prompt_image_emb.to(device=self.device, dtype=torch.float32)
prompt_image_emb = prompt_image_emb.reshape(
[1, -1, image_proj_model_in_features]
)
return ImageEmbed(
self.image_proj_model(prompt_image_emb),
self.image_proj_model(torch.zeros_like(prompt_image_emb)),
)
@staticmethod
def load(state_dict: dict, model_name: str) -> IPAdapterModel:
"""
Arguments:
- state_dict: model state_dict.
- model_name: file name of the model.
"""
is_v2 = "v2" in model_name
is_faceid = "faceid" in model_name
is_instantid = "instant_id" in model_name
is_portrait = "portrait" in model_name
is_full = "proj.3.weight" in state_dict["image_proj"]
is_plus = (
is_full
or "latents" in state_dict["image_proj"]
or "perceiver_resampler.proj_in.weight" in state_dict["image_proj"]
)
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[1]
sdxl = cross_attention_dim == 2048
sdxl_plus = sdxl and is_plus
if is_instantid:
# InstantID does not use clip embedding.
clip_embeddings_dim = None
elif is_faceid:
if is_plus:
clip_embeddings_dim = 1280
else:
# Plain faceid does not use clip_embeddings_dim.
clip_embeddings_dim = None
elif is_plus:
if sdxl_plus:
clip_embeddings_dim = int(state_dict["image_proj"]["latents"].shape[2])
elif is_full:
clip_embeddings_dim = int(
state_dict["image_proj"]["proj.0.weight"].shape[1]
)
else:
clip_embeddings_dim = int(
state_dict["image_proj"]["proj_in.weight"].shape[1]
)
else:
clip_embeddings_dim = int(state_dict["image_proj"]["proj.weight"].shape[1])
return IPAdapterModel(
state_dict,
clip_embeddings_dim=clip_embeddings_dim,
cross_attention_dim=cross_attention_dim,
is_plus=is_plus,
is_sdxl=sdxl,
sdxl_plus=sdxl_plus,
is_full=is_full,
is_faceid=is_faceid,
is_portrait=is_portrait,
is_instantid=is_instantid,
is_v2=is_v2,
)
def get_image_emb(self, preprocessor_output) -> ImageEmbed:
if self.is_instantid:
return self._get_image_embeds_instantid(preprocessor_output)
elif self.is_faceid and self.is_plus:
# Note: FaceID plus uses both face_embed and clip_embed.
# This should be the return value from preprocessor.
return self._get_image_embeds_faceid_plus(
preprocessor_output.face_embed,
preprocessor_output.clip_embed,
is_v2=self.is_v2,
)
elif self.is_faceid:
return self._get_image_embeds_faceid(preprocessor_output)
else:
return self._get_image_embeds(preprocessor_output)