"""
This file is used to extract feature for visulization during training
"""

import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import numpy as np
import torch

from libs.clip import FrozenCLIPEmbedder
from libs.t5 import T5Embedder


def main():
    prompts = [
        'A road with traffic lights, street lights and cars.',
        'A bus driving in a city area with traffic signs.',
        'A bus pulls over to the curb close to an intersection.',
        'A group of people are walking and one is holding an umbrella.',
        'A baseball player taking a swing at an incoming ball.',
        'A dog next to a white cat with black-tipped ears.',
        'A tiger standing on a rooftop while singing and jamming on an electric guitar under a spotlight. anime illustration.',
        'A bird wearing headphones and speaking into a high-end microphone in a recording studio.',
        'A bus made of cardboard.',
        'A tower in the mountains.',
        'Two cups of coffee, one with latte art of a cat. The other has latte art of a bird.',
        'Oil painting of a robot made of sushi, holding chopsticks.',
        'Portrait of a dog wearing a hat and holding a flag that has a yin-yang symbol on it.',
        'A teddy bear wearing a motorcycle helmet and cape is standing in front of Loch Awe with Kilchurn Castle behind him. dslr photo.',
        'A man standing on the moon',
    ]
    save_dir = 'run_vis'
    os.makedirs(save_dir, exist_ok=True)

    device = 'cuda'
    llm = 'clip'  # which text encoder to dump features from: 'clip' or 't5'

    if llm == 'clip':
        clip = FrozenCLIPEmbedder()
        clip.eval()
        clip.to(device)
    elif llm == 't5':
        t5 = T5Embedder(device=device)
    else:
        raise NotImplementedError(f'unsupported text encoder: {llm}')

    # Encode the prompts with the selected text encoder; no_grad since the
    # features are only saved, never backpropagated through.
    with torch.no_grad():
        if llm == 'clip':
            latent, latent_and_others = clip.encode(prompts)
            token_embedding = latent_and_others['token_embedding']
        elif llm == 't5':
            latent, latent_and_others = t5.get_text_embeddings(prompts)
            token_embedding = latent_and_others['token_embedding'].to(torch.float32) * 10.0
        token_mask = latent_and_others['token_mask']
        token = latent_and_others['tokens']

    # Save one dict per prompt. np.save pickles the dict, so it must be
    # loaded back with allow_pickle=True.
    for i in range(len(prompts)):
        data = {'prompt': prompts[i],
                'token_embedding': token_embedding[i].detach().cpu().numpy(),
                'token_mask': token_mask[i].detach().cpu().numpy(),
                'token': token[i].detach().cpu().numpy()}
        np.save(os.path.join(save_dir, f'{i}.npy'), data)


if __name__ == '__main__':
    main()
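
# A minimal usage sketch (an assumption, not part of the training pipeline):
# each '{i}.npy' file written above holds a pickled dict, so it has to be
# read back with allow_pickle=True, e.g.
#
#   import numpy as np
#   data = np.load('run_vis/0.npy', allow_pickle=True).item()
#   print(data['prompt'], data['token_embedding'].shape, data['token_mask'].sum())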