|
""" |
|
This file extracts the text-encoder features of the empty prompt.
|
""" |
|
|
|
import os |
|
import sys |
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) |
|
|
|
import torch |
|
import os |
|
import numpy as np |
|
from libs.clip import FrozenCLIPEmbedder |
|
from libs.t5 import T5Embedder |
|
|
|
|
|
def main():
    """Encode the empty prompt with a frozen text encoder and save the result.

    Runs either a frozen CLIP or T5 text encoder (selected by ``llm``) on the
    prompt list, then saves, for each prompt, a dict with keys
    ``token_embedding``, ``token_mask``, ``token`` and ``batch_caption`` to
    ``empty_context.npy`` in ``save_dir``.
    """
    prompts = [
        '',
    ]

    device = 'cuda'
    llm = 'clip'  # which text encoder to use: 'clip' or 't5'

    if llm == 'clip':
        clip = FrozenCLIPEmbedder()
        clip.eval()
        clip.to(device)
    elif llm == 't5':
        t5 = T5Embedder(device=device)
    else:
        raise NotImplementedError

    save_dir = './'

    # Inference only: disable autograd so no graph is built for the encoder
    # forward pass (the results are detached before saving anyway).
    with torch.no_grad():
        if llm == 'clip':
            latent, latent_and_others = clip.encode(prompts)
            token_embedding = latent_and_others['token_embedding']
        elif llm == 't5':
            latent, latent_and_others = t5.get_text_embeddings(prompts)
            # NOTE(review): the 10.0 scale on T5 embeddings presumably matches
            # a normalization expected by downstream training — confirm.
            token_embedding = latent_and_others['token_embedding'].to(torch.float32) * 10.0
        # Common to both encoders (was duplicated in each branch).
        token_mask = latent_and_others['token_mask']
        token = latent_and_others['tokens']

    for i in range(len(prompts)):
        data = {'token_embedding': token_embedding[i].detach().cpu().numpy(),
                'token_mask': token_mask[i].detach().cpu().numpy(),
                'token': token[i].detach().cpu().numpy(),
                'batch_caption': prompts[i]}
        # NOTE(review): every iteration writes the same file name, so only the
        # last prompt survives. That is fine for the single empty prompt this
        # script targets, but add an index to the name before extending
        # `prompts`.
        np.save(os.path.join(save_dir, 'empty_context.npy'), data)
|
|
|
|
|
|
|
# Script entry point: run the extraction only when executed directly,
# not when imported as a module.
if __name__ == '__main__':

    main()
|
|