File size: 3,775 Bytes
2ba4412
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import sys
import torch
import imageio
import numpy as np
import os.path as osp
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
from PIL import Image, ImageDraw, ImageFont

from einops import rearrange

from tools import *
import utils.transforms as data
from utils.seed import setup_seed
from tools.modules.config import cfg
from utils.config import Config as pConfig
from utils.registry_class import ENGINE, DATASETS, AUTO_ENCODER


def test_enc_dec(gpu=0):
    """Round-trip every video of ``cfg.vid_dataset`` through the autoencoder.

    The project config is loaded and merged into the global ``cfg``. Each
    dataset item's video frames are then encoded with
    ``autoencoder.encode_firsr_stage`` and decoded with ``autoencoder.decode``,
    in chunks of ``cfg.chunk_size`` frames. Each original frame is
    concatenated with its reconstruction along dim 2, then denormalized and
    saved as ``workspace/test_data/autoencoder/<auto_encoder type>/NNNN.mp4``.

    Any existing contents of that output directory are deleted first.

    Args:
        gpu: Device passed to ``.to()`` for the autoencoder and video tensors.
    """
    import shutil

    setup_seed(0)
    cfg_update = pConfig(load=True)

    # Merge the loaded config into the global ``cfg``. Dict-valued keys that
    # already exist are updated in place; everything else is overwritten.
    for k, v in cfg_update.cfg_dict.items():
        if isinstance(v, dict) and k in cfg:
            cfg[k].update(v)
        else:
            cfg[k] = v

    save_dir = os.path.join('workspace/test_data/autoencoder', cfg.auto_encoder['type'])
    # Start from an empty output directory. This replaces shelling out to
    # `rm -rf`, which broke on paths containing spaces or shell metacharacters.
    shutil.rmtree(save_dir, ignore_errors=True)
    os.makedirs(save_dir, exist_ok=True)

    train_trans = data.Compose([
        data.CenterCropWide(size=cfg.resolution),
        data.ToTensor(),
        data.Normalize(mean=cfg.mean, std=cfg.std)])

    vit_trans = data.Compose([
        data.CenterCropWide(size=(cfg.resolution[0], cfg.resolution[0])) if cfg.resolution[0]>cfg.vit_resolution[0] else data.CenterCropWide(size=cfg.vit_resolution),
        data.Resize(cfg.vit_resolution),
        data.ToTensor(),
        data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])

    # Per-channel statistics shaped to broadcast over dim 1 of a 4-D
    # (frames, channels, h, w) tensor, used to undo ``data.Normalize``.
    video_mean = torch.tensor(cfg.mean).view(1, -1, 1, 1)
    video_std = torch.tensor(cfg.std).view(1, -1, 1, 1)

    dataset = DATASETS.build(cfg.vid_dataset, sample_fps=4, transforms=train_trans, vit_transforms=vit_trans)
    print('There are %d videos' % (len(dataset)))

    autoencoder = AUTO_ENCODER.build(cfg.auto_encoder)
    autoencoder.eval()  # inference only: freeze the weights
    for param in autoencoder.parameters():
        param.requires_grad = False
    autoencoder.to(gpu)

    frame_dir = os.path.join(save_dir, 'temp')
    for video_idx, item in enumerate(dataset):
        local_path = os.path.join(save_dir, '%04d.mp4' % video_idx)
        # Only the third element (the transformed video frames) is needed here.
        _, _, video_data = item[:3]
        video_data = video_data.to(gpu)

        frames = _reconstruct_frames(autoencoder, video_data, cfg.chunk_size, video_mean, video_std)
        _write_mp4(frames, frame_dir, local_path, fps=8)


def _reconstruct_frames(autoencoder, video_data, chunk_size, video_mean, video_std):
    """Encode/decode ``video_data`` chunk-wise and return uint8 comparison frames.

    Args:
        autoencoder: Object exposing ``encode_firsr_stage(x)`` and
            ``decode(z, timesteps=...)``.
        video_data: Normalized frames, split along dim 0. Assumed layout is
            (frames, channels, h, w), which matches the ``view(1, -1, 1, 1)``
            statistics and the ``permute`` below.
        chunk_size: Maximum number of frames encoded/decoded at once.
        video_mean: Per-channel mean used to denormalize.
        video_std: Per-channel std used to denormalize.

    Returns:
        A list of (h', w, channels) uint8 numpy arrays, one per input frame.
        Each holds the original and its reconstruction concatenated along dim 2
        of the input tensor.
    """
    frames = []
    with torch.no_grad():
        # torch.split gives pieces of exactly ``chunk_size`` frames (the last
        # may be shorter). The earlier torch.chunk(x, F // chunk_size) asked
        # for zero chunks when F < chunk_size, and produced uneven chunks otherwise.
        for chunk in torch.split(video_data, chunk_size, dim=0):
            latent_z = autoencoder.encode_firsr_stage(chunk)
            recons = autoencoder.decode(latent_z, timesteps=chunk.shape[0])

            vis = torch.cat([chunk, recons], dim=2).cpu()
            vis = vis.mul_(video_std).add_(video_mean).clamp_(0, 1)
            vis = vis.permute(0, 2, 3, 1)  # -> (frames, h, w, channels)
            frames.extend((img.numpy() * 255).astype('uint8') for img in vis)
    return frames


def _write_mp4(frames, frame_dir, local_path, fps=8):
    """Encode uint8 RGB ``frames`` into an H.264 MP4 at ``local_path`` via ffmpeg.

    The frames are first written as numbered PNGs into ``frame_dir``. That
    directory is always removed afterwards, even if ffmpeg fails. A non-zero
    ffmpeg exit status is reported rather than silently ignored.
    """
    import shutil
    import subprocess

    os.makedirs(frame_dir, exist_ok=True)
    try:
        for frame_no, frame in enumerate(frames, start=1):
            # PIL takes RGB arrays directly, so no BGR flip is needed (the old
            # code used cv2, which was never imported in this file).
            Image.fromarray(frame).save(os.path.join(frame_dir, '%04d.png' % frame_no))
        # An argument list (no shell) is safe for paths with spaces or metacharacters.
        cmd = ['ffmpeg', '-y', '-f', 'image2', '-loglevel', 'quiet',
               '-framerate', str(fps), '-i', os.path.join(frame_dir, '%04d.png'),
               '-vcodec', 'libx264', '-crf', '17', '-pix_fmt', 'yuv420p', local_path]
        result = subprocess.run(cmd, check=False)
        if result.returncode != 0:
            print('ffmpeg failed (exit code %d) while writing %s' % (result.returncode, local_path))
    finally:
        shutil.rmtree(frame_dir, ignore_errors=True)


if __name__ == "__main__":
    # Script entry point: run the autoencoder round-trip check on device 0.
    test_enc_dec(gpu=0)