|
""" |
|
This file is used to extract feature for visulization during training |
|
""" |
|
|
|
import os |
|
import sys |
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) |
|
|
|
import torch |
|
import os |
|
import numpy as np |
|
from tqdm import tqdm |
|
|
|
import libs.autoencoder |
|
from libs.clip import FrozenCLIPEmbedder |
|
from libs.t5 import T5Embedder |
|
|
|
|
|
def main():
    """Encode a fixed list of prompts with a text encoder and dump the results.

    For each prompt, the per-token embedding, token mask, and token ids are
    saved as a pickled dict in ``run_vis/{i}.npy`` so they can be inspected /
    visualized during training.

    Side effects: creates ``run_vis/`` and writes one ``.npy`` file per prompt.
    Requires a CUDA device and the project-local CLIP / T5 wrappers.
    """
    prompts = [
        'A road with traffic lights, street lights and cars.',
        'A bus driving in a city area with traffic signs.',
        'A bus pulls over to the curb close to an intersection.',
        'A group of people are walking and one is holding an umbrella.',
        'A baseball player taking a swing at an incoming ball.',
        'A dog next to a white cat with black-tipped ears.',
        'A tiger standing on a rooftop while singing and jamming on an electric guitar under a spotlight. anime illustration.',
        'A bird wearing headphones and speaking into a high-end microphone in a recording studio.',
        'A bus made of cardboard.',
        'A tower in the mountains.',
        'Two cups of coffee, one with latte art of a cat. The other has latter art of a bird.',
        'Oil painting of a robot made of sushi, holding chopsticks.',
        'Portrait of a dog wearing a hat and holding a flag that has a yin-yang symbol on it.',
        'A teddy bear wearing a motorcycle helmet and cape is standing in front of Loch Awe with Kilchurn Castle behind him. dslr photo.',
        'A man standing on the moon',
    ]
    save_dir = 'run_vis'  # fixed: was an f-string with no placeholders
    os.makedirs(save_dir, exist_ok=True)

    device = 'cuda'
    llm = 'clip'  # which text encoder to use: 'clip' or 't5'

    # Build the chosen encoder and encode all prompts in one batch.  Both
    # wrappers return (pooled latent, dict of per-token outputs); keeping
    # setup and encoding in a single branch avoids the two dispatch chains
    # drifting out of sync.
    if llm == 'clip':
        clip = FrozenCLIPEmbedder()
        clip.eval()
        clip.to(device)
        latent, latent_and_others = clip.encode(prompts)
        token_embedding = latent_and_others['token_embedding']
    elif llm == 't5':
        t5 = T5Embedder(device=device)
        latent, latent_and_others = t5.get_text_embeddings(prompts)
        # NOTE(review): the *10.0 scale presumably matches the training-time
        # convention for T5 embeddings elsewhere in the repo — confirm.
        token_embedding = latent_and_others['token_embedding'].to(torch.float32) * 10.0
    else:
        raise NotImplementedError(f'unsupported text encoder: {llm!r}')

    token_mask = latent_and_others['token_mask']
    token = latent_and_others['tokens']

    for i, prompt in enumerate(prompts):
        # NOTE(review): 'promt' (sic) is kept deliberately — it is a saved
        # dict key that downstream readers of these .npy files may rely on.
        data = {'promt': prompt,
                'token_embedding': token_embedding[i].detach().cpu().numpy(),
                'token_mask': token_mask[i].detach().cpu().numpy(),
                'token': token[i].detach().cpu().numpy()}
        np.save(os.path.join(save_dir, f'{i}.npy'), data)
|
|
|
|
|
# Script entry point: run the feature-extraction pass when executed directly.
if __name__ == '__main__':

    main()
|
|