"""
This script extracts features from the MS-COCO validation set (used for zero-shot FID evaluation).
"""
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
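# The parent directory goes on sys.path so that the repo-local modules imported
# below (datasets, libs.autoencoder, libs.clip, libs.t5) can be resolved.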
|
|
|
import torch
import numpy as np
from datasets import MSCOCODatabase
import argparse
from tqdm import tqdm

import libs.autoencoder
from libs.clip import FrozenCLIPEmbedder
from libs.t5 import T5Embedder


def main(resolution=256):
    parser = argparse.ArgumentParser()
    parser.add_argument('--split', default='val')
    args = parser.parse_args()
    print(args)

    # Only the MS-COCO 2014 validation split is supported; the dataset paths
    # below are machine-specific and may need to be adjusted.
    if args.split == "val":
        datas = MSCOCODatabase(root='/data/qihao/dataset/coco2014/val2014',
                               annFile='/data/qihao/dataset/coco2014/annotations/captions_val2014.json',
                               size=resolution)
        save_dir = 'val'
    else:
        raise NotImplementedError

    device = "cuda"
    os.makedirs(save_dir, exist_ok=True)
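    # All extracted features are written as .npy files under save_dir.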
|
|
|
    autoencoder = libs.autoencoder.get_model('../assets/stable-diffusion/autoencoder_kl.pth')
    autoencoder.to(device)
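    # The autoencoder is only used through fn='encode_moments', which is expected
    # to return the moments (mean and log-variance) of the latent posterior.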
|
|
|
    # Text encoder used for the caption features: 'clip' or 't5'.
    llm = 'clip'

    if llm == 'clip':
        clip = FrozenCLIPEmbedder()
        clip.eval()
        clip.to(device)
    elif llm == 't5':
        t5 = T5Embedder(device=device)
    else:
        raise NotImplementedError

    with torch.no_grad():
        for idx, data in tqdm(enumerate(datas)):
            x, captions = data

            # Encode the image into latent moments and save them as <idx>.npy.
            if len(x.shape) == 3:
                x = x[None, ...]
            x = torch.tensor(x, device=device)
            moments = autoencoder(x, fn='encode_moments').squeeze(0)
            moments = moments.detach().cpu().numpy()
            np.save(os.path.join(save_dir, f'{idx}.npy'), moments)
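            # Assuming the standard Stable Diffusion KL autoencoder (8x downsampling,
            # 4 latent channels), this is an 8x32x32 array for 256x256 inputs.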
|
|
|
            # Encode all captions of the current image with the chosen text encoder.
            if llm == 'clip':
                latent, latent_and_others = clip.encode(captions)
                token_embedding = latent_and_others['token_embedding']
                token_mask = latent_and_others['token_mask']
                token = latent_and_others['tokens']
            elif llm == 't5':
                latent, latent_and_others = t5.get_text_embeddings(captions)
                token_embedding = latent_and_others['token_embedding'].to(torch.float32) * 10.0
                token_mask = latent_and_others['token_mask']
                token = latent_and_others['tokens']
|
|
|
            # Save one record per caption as <idx>_<i>.npy.
            for i in range(len(captions)):
                data = {'prompt': captions[i],
                        'token_embedding': token_embedding[i].detach().cpu().numpy(),
                        'token_mask': token_mask[i].detach().cpu().numpy(),
                        'token': token[i].detach().cpu().numpy()}
                np.save(os.path.join(save_dir, f'{idx}_{i}.npy'), data)
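                # Each record is a Python dict, so reading it back requires
                # np.load(path, allow_pickle=True).item().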
|
|
|
|
|
if __name__ == '__main__':
    main()
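
# Example invocation (the script's filename is assumed here):
#   python extract_mscoco_feature.py --split val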