Spaces:
Running
on
Zero
Running
on
Zero
import timm | |
import torch | |
import numpy as np | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from contextlib import nullcontext | |
from unitok.vitamin import GeGluMlp, ViTaminDecoder | |
from unitok.quant import VectorQuantizerM | |
from unitok.vqvae import AttnProjection | |
class UniTok(nn.Module):
    """Unified image tokenizer with an attached CLIP-style text tower.

    Pipeline: a timm ViT encoder produces image tokens, which are projected
    into a code space, vector-quantized, projected back, and then used both by
    a ViTamin decoder (pixel reconstruction) and, after mean pooling, by a
    CLIP image head. An open_clip text tower provides the matching text
    embeddings.
    """

    def __init__(self, args):
        """Build all sub-modules from the hyper-parameters stored on ``args``.

        Args:
            args: config namespace. Attributes read here: ``num_query``,
                ``model``, ``img_size``, ``drop_path``, ``grad_ckpt``,
                ``quant_proj`` (``'linear'`` or ``'attn'``), ``vocab_size``,
                ``vocab_width``, ``vq_beta``, ``le``, ``e_temp``,
                ``num_codebooks``, ``embed_dim``, ``text_width``,
                ``text_heads``, ``text_layers``, ``text_vocab_size`` and
                ``text_context_length``.

        Raises:
            NotImplementedError: if ``args.quant_proj`` is neither ``'linear'``
                nor ``'attn'``.
        """
        super().__init__()
        self.num_query = args.num_query

        self.encoder = timm.create_model(
            args.model,
            patch_size=1,
            fc_norm=False,
            drop_rate=0.0,
            num_classes=0,
            global_pool='',
            pos_embed='none',
            class_token=False,
            mlp_layer=GeGluMlp,
            reg_tokens=args.num_query,
            img_size=args.img_size,
            drop_path_rate=args.drop_path,
        )
        # Overwrite pos_embed with a frozen all-zero (1, 1, embed_dim) parameter.
        # NOTE(review): presumably this neutralizes the positional embedding while
        # keeping the attribute the timm model expects — confirm for the timm version.
        self.encoder.pos_embed = nn.Parameter(torch.zeros(1, 1, self.encoder.embed_dim), requires_grad=False)

        # Projection from encoder width into the quantizer's code width.
        if args.quant_proj == 'linear':
            self.quant_proj = nn.Linear(self.encoder.embed_dim, args.vocab_width)
        elif args.quant_proj == 'attn':
            self.quant_proj = AttnProjection(self.encoder.embed_dim, args.vocab_width, self.encoder.embed_dim // args.vocab_width)
        else:
            raise NotImplementedError(
                f"unsupported quant_proj: {args.quant_proj!r} (expected 'linear' or 'attn')"
            )

        self.quantizer = VectorQuantizerM(
            vocab_size=args.vocab_size,
            vocab_width=args.vocab_width,
            beta=args.vq_beta,
            use_entropy_loss=args.le > 0,  # entropy loss only when its weight is positive
            entropy_temp=args.e_temp,
            num_codebooks=args.num_codebooks,
        )

        # Projection from code width back to encoder width. The else branch is
        # unreachable after the check above; kept defensively. Construction order
        # is left unchanged so parameter init consumes the RNG identically.
        if args.quant_proj == 'linear':
            self.post_quant_proj = nn.Linear(args.vocab_width, self.encoder.embed_dim)
        elif args.quant_proj == 'attn':
            self.post_quant_proj = AttnProjection(args.vocab_width, self.encoder.embed_dim, self.encoder.embed_dim // args.vocab_width)
        else:
            raise NotImplementedError(
                f"unsupported quant_proj: {args.quant_proj!r} (expected 'linear' or 'attn')"
            )

        self.decoder = ViTaminDecoder(
            args.model,
            num_query=args.num_query,
            img_size=args.img_size,
            drop_path=args.drop_path,
            grad_ckpt=args.grad_ckpt,
        )

        text_cfg = {
            "width": args.text_width,
            "heads": args.text_heads,
            "layers": args.text_layers,
            "vocab_size": args.text_vocab_size,
            "context_length": args.text_context_length,
        }
        from open_clip.model import _build_text_tower
        self.text_encoder = _build_text_tower(args.embed_dim, text_cfg)

        # CLIP image head: LayerNorm + linear projection to the shared embed_dim.
        self.fc_norm = nn.LayerNorm(self.encoder.embed_dim, eps=1e-6)
        self.projection = nn.Linear(self.encoder.embed_dim, args.embed_dim)
        # Learnable log temperature, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.context_length = self.text_encoder.context_length
        self.vocab_size = self.text_encoder.vocab_size
        self.maybe_record_function = nullcontext
        # Set by lock_text_tower; not read anywhere in this class.
        self.text_no_grad = False

        self.encoder.set_grad_checkpointing(args.grad_ckpt)
        self.text_encoder.set_grad_checkpointing(args.grad_ckpt)

    def forward(self, img, vae_bs, text=None, ret_usages=False):
        """Run the joint reconstruction + contrastive training pass.

        Args:
            img: image batch passed to ``self.encoder``.
            vae_bs: only the first ``vae_bs`` samples are decoded back to pixels.
            text: optional text token ids; when ``None`` no text features are
                computed and ``clip_text_features`` is ``None``.
            ret_usages: not used by this method (``codebook_usages`` is always
                returned); kept for call-site compatibility.

        Returns:
            dict with keys ``img_rec``, ``vq_loss``, ``entropy_loss``,
            ``codebook_usages``, ``clip_image_features`` (L2-normalized over the
            last dim), ``clip_text_features`` (L2-normalized, or ``None``) and
            ``logit_scale`` (already exponentiated).
        """
        img_tokens = self.encoder(img).float()
        # Quantization runs in float32 with CUDA autocast disabled. The two
        # projections are activation-checkpointed (recomputed during backward).
        with torch.cuda.amp.autocast(enabled=False):
            img_tokens = torch.utils.checkpoint.checkpoint(self.quant_proj, img_tokens, use_reentrant=False)
            img_tokens, vq_loss, entropy_loss, usages = self.quantizer(img_tokens)
            img_tokens = torch.utils.checkpoint.checkpoint(self.post_quant_proj, img_tokens, use_reentrant=False)
        img_rec = self.decoder(img_tokens[:vae_bs]).float()

        # Mean-pool over dim 1 (presumably the token axis — confirm), then project.
        clip_visual = img_tokens.mean(dim=1)
        clip_visual = self.projection(self.fc_norm(clip_visual))
        clip_visual = F.normalize(clip_visual, dim=-1)

        if text is not None:
            clip_text = self.text_encoder(text)
            clip_text = F.normalize(clip_text, dim=-1)
        else:
            clip_text = None

        output_dict = {
            "img_rec": img_rec,
            "vq_loss": vq_loss,
            "entropy_loss": entropy_loss,
            "codebook_usages": usages,
            "clip_image_features": clip_visual,
            "clip_text_features": clip_text,
            "logit_scale": self.logit_scale.exp()
        }
        return output_dict

    def encode_image(self, image, normalize: bool = False):
        """Return CLIP image features computed from the quantized tokens.

        Tokens are mapped to codebook indices and back (``f_to_idx`` then
        ``idx_to_f``), projected back to encoder width, mean-pooled over dim 1,
        then passed through ``fc_norm`` and ``projection``.

        Args:
            image: image batch passed to ``self.encoder``.
            normalize: if True, L2-normalize the result along the last dim.
        """
        img_tokens = self.encoder(image)
        img_tokens = self.quant_proj(img_tokens)
        img_indices = self.quantizer.f_to_idx(img_tokens)
        img_tokens = self.quantizer.idx_to_f(img_indices)
        img_tokens = self.post_quant_proj(img_tokens)
        features = img_tokens.mean(dim=1)
        features = self.projection(self.fc_norm(features))
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        """Return text-tower features, optionally L2-normalized along the last dim."""
        features = self.text_encoder(text)
        return F.normalize(features, dim=-1) if normalize else features

    def img_to_idx(self, img):
        """Encode images and return the quantizer's code indices (via ``f_to_idx``)."""
        features = self.encoder(img).float()
        features = self.quant_proj(features)
        return self.quantizer.f_to_idx(features)

    def idx_to_img(self, indices):
        """Decode code indices to images, clamped in place to [-1, 1]."""
        features = self.quantizer.idx_to_f(indices)
        features = self.post_quant_proj(features)
        img = self.decoder(features).clamp_(-1, 1)
        return img

    def img_to_reconstructed_img(self, image) -> torch.Tensor:
        """Encode, quantize and decode ``image``; result clamped in place to [-1, 1].

        The quantizer's auxiliary outputs (losses, usages) are discarded.
        """
        img_tokens = self.encoder(image)
        img_tokens = self.quant_proj(img_tokens)
        img_tokens, _, _, _ = self.quantizer(img_tokens)
        img_tokens = self.post_quant_proj(img_tokens)
        img_rec = self.decoder(img_tokens).clamp_(-1, 1)
        return img_rec

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True, unlock_text_proj=False):
        """Freeze the text tower by delegating to its ``lock`` method.

        Also sets ``self.text_no_grad`` to True.

        Args:
            unlocked_layers: forwarded positionally to ``text_encoder.lock``.
            freeze_layer_norm: forwarded positionally to ``text_encoder.lock``.
            unlock_text_proj: forwarded positionally to ``text_encoder.lock``.

        NOTE(review): assumes the open_clip text tower's ``lock`` accepts these
        three positional arguments — confirm for the installed open_clip version.
        """
        # The tower is stored as ``text_encoder``; there is no ``self.text`` attribute.
        self.text_encoder.lock(unlocked_layers, freeze_layer_norm, unlock_text_proj)
        self.text_no_grad = True
if __name__ == '__main__':
    # Ad-hoc check: build a bare timm 'vitamin_base' encoder with options close to
    # UniTok.encoder (note: fc_norm=True and reg_tokens=0 here), load the visual
    # trunk of a local ViTamin-B checkpoint into it, and print key mismatches.
    model = timm.create_model(
        'vitamin_base',
        patch_size=1,
        fc_norm=True,
        drop_rate=0.0,
        num_classes=0,
        global_pool='',
        pos_embed='none',
        class_token=False,
        mlp_layer=GeGluMlp,
        reg_tokens=0,
        img_size=256,
        drop_path_rate=0.1,
    )
    model.pos_embed = nn.Parameter(torch.zeros(1, 1, model.embed_dim), requires_grad=False)
    model_dict = model.state_dict()

    # map_location='cpu' lets a checkpoint saved from GPU tensors load on a host
    # without that GPU; load_state_dict copies the values into the model anyway.
    ckpt_dict = torch.load('ViTamin-B/pytorch_model.bin', map_location='cpu')

    # Keep only 'visual.*' entries, skip any key containing 'head' or 'pos_embed',
    # and strip the 'visual.trunk.' prefix so names line up with the timm model.
    visual_dict = {}
    for k, v in ckpt_dict.items():
        if not k.startswith('visual.'):
            continue
        if 'head' in k or 'pos_embed' in k:
            continue
        visual_dict[k.replace('visual.trunk.', '')] = v

    model.load_state_dict(visual_dict, strict=False)
    # Keys the model expects but the checkpoint lacks, then keys only in the checkpoint.
    print(set(model_dict) - set(visual_dict))
    print(set(visual_dict) - set(model_dict))