import timm import torch import numpy as np import torch.nn as nn import torch.nn.functional as F from contextlib import nullcontext from unitok.vitamin import GeGluMlp, ViTaminDecoder from unitok.quant import VectorQuantizerM from unitok.vqvae import AttnProjection class UniTok(nn.Module): def __init__(self, args): super().__init__() self.num_query = args.num_query self.encoder = timm.create_model( args.model, patch_size=1, fc_norm=False, drop_rate=0.0, num_classes=0, global_pool='', pos_embed='none', class_token=False, mlp_layer=GeGluMlp, reg_tokens=args.num_query, img_size=args.img_size, drop_path_rate=args.drop_path, ) self.encoder.pos_embed = nn.Parameter(torch.zeros(1, 1, self.encoder.embed_dim), requires_grad=False) if args.quant_proj == 'linear': self.quant_proj = nn.Linear(self.encoder.embed_dim, args.vocab_width) elif args.quant_proj == 'attn': self.quant_proj = AttnProjection(self.encoder.embed_dim, args.vocab_width, self.encoder.embed_dim // args.vocab_width) else: raise NotImplementedError self.quantizer = VectorQuantizerM( vocab_size=args.vocab_size, vocab_width=args.vocab_width, beta=args.vq_beta, use_entropy_loss=args.le > 0, entropy_temp=args.e_temp, num_codebooks=args.num_codebooks, ) if args.quant_proj == 'linear': self.post_quant_proj = nn.Linear(args.vocab_width, self.encoder.embed_dim) elif args.quant_proj == 'attn': self.post_quant_proj = AttnProjection(args.vocab_width, self.encoder.embed_dim, self.encoder.embed_dim // args.vocab_width) else: raise NotImplementedError self.decoder = ViTaminDecoder( args.model, num_query=args.num_query, img_size=args.img_size, drop_path=args.drop_path, grad_ckpt=args.grad_ckpt, ) text_cfg = { "width": args.text_width, "heads": args.text_heads, "layers": args.text_layers, "vocab_size": args.text_vocab_size, "context_length": args.text_context_length, } from open_clip.model import _build_text_tower self.text_encoder = _build_text_tower(args.embed_dim, text_cfg) self.fc_norm = nn.LayerNorm(self.encoder.embed_dim, eps=1e-6) self.projection = nn.Linear(self.encoder.embed_dim, args.embed_dim) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) self.context_length = self.text_encoder.context_length self.vocab_size = self.text_encoder.vocab_size self.maybe_record_function = nullcontext self.text_no_grad = False self.encoder.set_grad_checkpointing(args.grad_ckpt) self.text_encoder.set_grad_checkpointing(args.grad_ckpt) def forward(self, img, vae_bs, text=None, ret_usages=False): img_tokens = self.encoder(img).float() with torch.cuda.amp.autocast(enabled=False): img_tokens = torch.utils.checkpoint.checkpoint(self.quant_proj, img_tokens, use_reentrant=False) img_tokens, vq_loss, entropy_loss, usages = self.quantizer(img_tokens) img_tokens = torch.utils.checkpoint.checkpoint(self.post_quant_proj, img_tokens, use_reentrant=False) img_rec = self.decoder(img_tokens[:vae_bs]).float() clip_visual = img_tokens.mean(dim=1) clip_visual = self.projection(self.fc_norm(clip_visual)) clip_visual = F.normalize(clip_visual, dim=-1) if text is not None: clip_text = self.text_encoder(text) clip_text = F.normalize(clip_text, dim=-1) else: clip_text = None output_dict = { "img_rec": img_rec, "vq_loss": vq_loss, "entropy_loss": entropy_loss, "codebook_usages": usages, "clip_image_features": clip_visual, "clip_text_features": clip_text, "logit_scale": self.logit_scale.exp() } return output_dict def encode_image(self, image, normalize: bool = False): img_tokens = self.encoder(image) img_tokens = self.quant_proj(img_tokens) img_indices = self.quantizer.f_to_idx(img_tokens) img_tokens = self.quantizer.idx_to_f(img_indices) img_tokens = self.post_quant_proj(img_tokens) features = img_tokens.mean(dim=1) features = self.projection(self.fc_norm(features)) return F.normalize(features, dim=-1) if normalize else features def encode_text(self, text, normalize: bool = False): features = self.text_encoder(text) return F.normalize(features, dim=-1) if normalize else features def img_to_idx(self, img): features = self.encoder(img).float() features = self.quant_proj(features) return self.quantizer.f_to_idx(features) def idx_to_img(self, indices): features = self.quantizer.idx_to_f(indices) features = self.post_quant_proj(features) img = self.decoder(features).clamp_(-1, 1) return img def img_to_reconstructed_img(self, image) -> torch.Tensor: img_tokens = self.encoder(image) img_tokens = self.quant_proj(img_tokens) img_tokens, _, _, _ = self.quantizer(img_tokens) img_tokens = self.post_quant_proj(img_tokens) img_rec = self.decoder(img_tokens).clamp_(-1, 1) return img_rec def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True, unlock_text_proj=False): self.text.lock(unlocked_layers, freeze_layer_norm, unlock_text_proj) self.text_no_grad = True if __name__ == '__main__': model = timm.create_model( 'vitamin_base', patch_size=1, fc_norm=True, drop_rate=0.0, num_classes=0, global_pool='', pos_embed='none', class_token=False, mlp_layer=GeGluMlp, reg_tokens=0, img_size=256, drop_path_rate=0.1, ) model.pos_embed = nn.Parameter(torch.zeros(1, 1, model.embed_dim), requires_grad=False) model_dict = model.state_dict() ckpt_dict = torch.load('ViTamin-B/pytorch_model.bin') visual_dict = dict() for k, v in ckpt_dict.items(): if k.startswith('visual.'): if 'head' in k or 'pos_embed' in k: continue new_k = k.replace('visual.trunk.', '') visual_dict[new_k] = v model.load_state_dict(visual_dict, strict=False) print(set(model_dict.keys()) - set(visual_dict.keys())) print(set(visual_dict.keys() - set(model_dict.keys())))