# unitok/model.py
import timm
import torch
import torch.utils.checkpoint  # used for activation checkpointing in forward()
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from contextlib import nullcontext
from unitok.vitamin import GeGluMlp, ViTaminDecoder
from unitok.quant import VectorQuantizerM
from unitok.vqvae import AttnProjection


class UniTok(nn.Module):
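    """Unified tokenizer that couples discrete VQ image tokenization with
    CLIP-style image-text alignment.

    Components wired together in `__init__`:
      * a timm vision encoder (`args.model`) with `args.num_query` register
        tokens and the positional embedding replaced by a frozen all-zero parameter,
      * a projection into the code space, a multi-codebook `VectorQuantizerM`,
        and a projection back to the encoder width,
      * a `ViTaminDecoder` that reconstructs pixels from the quantized tokens,
      * an open_clip text tower plus a LayerNorm + linear head that maps pooled
        image tokens into the shared contrastive embedding space.
    """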
    def __init__(self, args):
        super().__init__()
        self.num_query = args.num_query

        # Vision encoder: a timm backbone with `args.num_query` register tokens;
        # the positional embedding is disabled and replaced by a frozen all-zero parameter.
        self.encoder = timm.create_model(
            args.model,
            patch_size=1,
            fc_norm=False,
            drop_rate=0.0,
            num_classes=0,
            global_pool='',
            pos_embed='none',
            class_token=False,
            mlp_layer=GeGluMlp,
            reg_tokens=args.num_query,
            img_size=args.img_size,
            drop_path_rate=args.drop_path,
        )
        self.encoder.pos_embed = nn.Parameter(torch.zeros(1, 1, self.encoder.embed_dim), requires_grad=False)

        # Projection from the encoder width into the quantizer code space.
        if args.quant_proj == 'linear':
            self.quant_proj = nn.Linear(self.encoder.embed_dim, args.vocab_width)
        elif args.quant_proj == 'attn':
            self.quant_proj = AttnProjection(self.encoder.embed_dim, args.vocab_width, self.encoder.embed_dim // args.vocab_width)
        else:
            raise NotImplementedError

        # Multi-codebook vector quantizer.
        self.quantizer = VectorQuantizerM(
            vocab_size=args.vocab_size,
            vocab_width=args.vocab_width,
            beta=args.vq_beta,
            use_entropy_loss=args.le > 0,
            entropy_temp=args.e_temp,
            num_codebooks=args.num_codebooks,
        )

        # Projection from the code space back to the encoder/decoder width.
        if args.quant_proj == 'linear':
            self.post_quant_proj = nn.Linear(args.vocab_width, self.encoder.embed_dim)
        elif args.quant_proj == 'attn':
            self.post_quant_proj = AttnProjection(args.vocab_width, self.encoder.embed_dim, self.encoder.embed_dim // args.vocab_width)
        else:
            raise NotImplementedError

        # Pixel decoder that reconstructs the image from quantized tokens.
        self.decoder = ViTaminDecoder(
            args.model,
            num_query=args.num_query,
            img_size=args.img_size,
            drop_path=args.drop_path,
            grad_ckpt=args.grad_ckpt,
        )

        # CLIP-style text tower and image projection head for contrastive training.
        text_cfg = {
            "width": args.text_width,
            "heads": args.text_heads,
            "layers": args.text_layers,
            "vocab_size": args.text_vocab_size,
            "context_length": args.text_context_length,
        }
        from open_clip.model import _build_text_tower
        self.text_encoder = _build_text_tower(args.embed_dim, text_cfg)
        self.fc_norm = nn.LayerNorm(self.encoder.embed_dim, eps=1e-6)
        self.projection = nn.Linear(self.encoder.embed_dim, args.embed_dim)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.context_length = self.text_encoder.context_length
        self.vocab_size = self.text_encoder.vocab_size
        self.maybe_record_function = nullcontext

        self.text_no_grad = False
        self.encoder.set_grad_checkpointing(args.grad_ckpt)
        self.text_encoder.set_grad_checkpointing(args.grad_ckpt)

    def forward(self, img, vae_bs, text=None, ret_usages=False):
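        """Training forward pass.

        Encodes `img`, projects into the code space, quantizes, and projects
        back; decodes the first `vae_bs` samples into a pixel reconstruction;
        and pools the quantized tokens into a CLIP-style image embedding
        (paired with text features when `text` is provided).

        Returns a dict with the reconstruction, VQ and entropy losses, codebook
        usage statistics, contrastive features, and the exponentiated logit scale.
        """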
        img_tokens = self.encoder(img).float()
        # The quantization path runs in fp32 even if the surrounding training loop uses autocast.
        with torch.cuda.amp.autocast(enabled=False):
            img_tokens = torch.utils.checkpoint.checkpoint(self.quant_proj, img_tokens, use_reentrant=False)
            img_tokens, vq_loss, entropy_loss, usages = self.quantizer(img_tokens)
            img_tokens = torch.utils.checkpoint.checkpoint(self.post_quant_proj, img_tokens, use_reentrant=False)

        # Decode only the first `vae_bs` samples for the reconstruction objective.
        img_rec = self.decoder(img_tokens[:vae_bs]).float()

        # CLIP-style image features: mean-pool the quantized tokens, apply LayerNorm
        # and the linear projection, then L2-normalize.
        clip_visual = img_tokens.mean(dim=1)
        clip_visual = self.projection(self.fc_norm(clip_visual))
        clip_visual = F.normalize(clip_visual, dim=-1)

        if text is not None:
            clip_text = self.text_encoder(text)
            clip_text = F.normalize(clip_text, dim=-1)
        else:
            clip_text = None

        output_dict = {
            "img_rec": img_rec,
            "vq_loss": vq_loss,
            "entropy_loss": entropy_loss,
            "codebook_usages": usages,
            "clip_image_features": clip_visual,
            "clip_text_features": clip_text,
            "logit_scale": self.logit_scale.exp(),
        }
        return output_dict

    def encode_image(self, image, normalize: bool = False):
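        """Produce a CLIP-style image feature, round-tripping the tokens through
        the quantizer (features -> indices -> codebook entries) before pooling
        and projection."""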
        img_tokens = self.encoder(image)
        img_tokens = self.quant_proj(img_tokens)
        img_indices = self.quantizer.f_to_idx(img_tokens)
        img_tokens = self.quantizer.idx_to_f(img_indices)
        img_tokens = self.post_quant_proj(img_tokens)
        features = img_tokens.mean(dim=1)
        features = self.projection(self.fc_norm(features))
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
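        """Produce a CLIP-style text feature from tokenized text."""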
        features = self.text_encoder(text)
        return F.normalize(features, dim=-1) if normalize else features

    def img_to_idx(self, img):
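        """Map an image to its discrete codebook indices."""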
        features = self.encoder(img).float()
        features = self.quant_proj(features)
        return self.quantizer.f_to_idx(features)

    def idx_to_img(self, indices):
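        """Decode discrete codebook indices back to an image in [-1, 1]."""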
        features = self.quantizer.idx_to_f(indices)
        features = self.post_quant_proj(features)
        img = self.decoder(features).clamp_(-1, 1)
        return img

    def img_to_reconstructed_img(self, image) -> torch.Tensor:
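        """Reconstruct an image through the full encode-quantize-decode round trip."""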
        img_tokens = self.encoder(image)
        img_tokens = self.quant_proj(img_tokens)
        img_tokens, _, _, _ = self.quantizer(img_tokens)
        img_tokens = self.post_quant_proj(img_tokens)
        img_rec = self.decoder(img_tokens).clamp_(-1, 1)
        return img_rec

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True, unlock_text_proj=False):
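        """Freeze the text tower; `unlocked_layers`, `freeze_layer_norm`, and
        `unlock_text_proj` are forwarded to the text tower's `lock()`."""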
        self.text_encoder.lock(unlocked_layers, freeze_layer_norm, unlock_text_proj)
        self.text_no_grad = True


if __name__ == '__main__':
    # Quick check: build the bare ViTamin backbone and load the visual trunk of a
    # pretrained ViTamin-B CLIP checkpoint, skipping head and positional-embedding weights.
    model = timm.create_model(
        'vitamin_base',
        patch_size=1,
        fc_norm=True,
        drop_rate=0.0,
        num_classes=0,
        global_pool='',
        pos_embed='none',
        class_token=False,
        mlp_layer=GeGluMlp,
        reg_tokens=0,
        img_size=256,
        drop_path_rate=0.1,
    )
    model.pos_embed = nn.Parameter(torch.zeros(1, 1, model.embed_dim), requires_grad=False)
    model_dict = model.state_dict()

    ckpt_dict = torch.load('ViTamin-B/pytorch_model.bin')
    visual_dict = dict()
    for k, v in ckpt_dict.items():
        if k.startswith('visual.'):
            if 'head' in k or 'pos_embed' in k:
                continue
            new_k = k.replace('visual.trunk.', '')
            visual_dict[new_k] = v
    model.load_state_dict(visual_dict, strict=False)

    # Report parameter keys that did not match in either direction.
    print(set(model_dict.keys()) - set(visual_dict.keys()))
    print(set(visual_dict.keys()) - set(model_dict.keys()))