# Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
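"""Round-trip sanity check for the first-stage video autoencoder.

Builds the dataset and autoencoder from the merged config, encodes each
sampled clip into latents chunk by chunk, decodes the latents back, stacks
every input frame above its reconstruction, and writes one side-by-side mp4
per video under workspace/test_data/autoencoder/<auto_encoder type>.
"""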
import os
import sys
import cv2
import torch
import imageio
import numpy as np
import os.path as osp
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
from PIL import Image, ImageDraw, ImageFont
from einops import rearrange
from tools import *
import utils.transforms as data
from utils.seed import setup_seed
from tools.modules.config import cfg
from utils.config import Config as pConfig
from utils.registry_class import ENGINE, DATASETS, AUTO_ENCODER


def test_enc_dec(gpu=0):
    setup_seed(0)

    # Merge the user config (loaded via the command line) into the default cfg,
    # updating nested dicts in place and adding new keys directly.
    cfg_update = pConfig(load=True)
    for k, v in cfg_update.cfg_dict.items():
        if isinstance(v, dict) and k in cfg:
            cfg[k].update(v)
        else:
            cfg[k] = v

    save_dir = os.path.join('workspace/test_data/autoencoder', cfg.auto_encoder['type'])
    os.system('rm -rf %s' % save_dir)  # start from a clean output directory
    os.makedirs(save_dir, exist_ok=True)

    train_trans = data.Compose([
        data.CenterCropWide(size=cfg.resolution),
        data.ToTensor(),
        data.Normalize(mean=cfg.mean, std=cfg.std)])
    vit_trans = data.Compose([
        data.CenterCropWide(size=(cfg.resolution[0], cfg.resolution[0])) if cfg.resolution[0] > cfg.vit_resolution[0] else data.CenterCropWide(size=cfg.vit_resolution),
        data.Resize(cfg.vit_resolution),
        data.ToTensor(),
        data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])

    # Denormalization constants; broadcast over (f, c, h, w) frame tensors.
    video_mean = torch.tensor(cfg.mean).view(1, -1, 1, 1)
    video_std = torch.tensor(cfg.std).view(1, -1, 1, 1)

    # Caption-overlay settings (unused below: the caption drawing is commented out).
    txt_size = cfg.resolution[1]
    nc = int(38 * (txt_size / 256))
    font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=13)

    dataset = DATASETS.build(cfg.vid_dataset, sample_fps=4, transforms=train_trans, vit_transforms=vit_trans)
    print('There are %d videos' % len(dataset))

    autoencoder = AUTO_ENCODER.build(cfg.auto_encoder)
    autoencoder.eval()  # freeze
    for param in autoencoder.parameters():
        param.requires_grad = False
    autoencoder.to(gpu)

    for idx, item in enumerate(dataset):
        local_path = os.path.join(save_dir, '%04d.mp4' % idx)
        # ref_frame, video_data, caption = item
        ref_frame, vit_frame, video_data = item[:3]
        video_data = video_data.to(gpu)

        image_list = []
        # Encode/decode in chunks to bound GPU memory; guard against clips
        # shorter than chunk_size, where the integer division would give 0.
        video_data_list = torch.chunk(video_data, max(1, video_data.shape[0] // cfg.chunk_size), dim=0)
        with torch.no_grad():
            for chunk_data in video_data_list:
                # 'encode_firsr_stage' is the upstream method name (sic).
                latent_z = autoencoder.encode_firsr_stage(chunk_data).detach()
                # latent_z = get_first_stage_encoding(encoder_posterior).detach()
                kwargs = {"timesteps": chunk_data.shape[0]}
                recons_data = autoencoder.decode(latent_z, **kwargs)

                # Stack input above reconstruction along height, then denormalize to [0, 1].
                vis_data = torch.cat([chunk_data, recons_data], dim=2).cpu()
                vis_data = vis_data.mul_(video_std).add_(video_mean)
                vis_data.clamp_(0, 1)
                vis_data = vis_data.permute(0, 2, 3, 1)  # (f, c, h, w) -> (f, h, w, c)
                vis_data = [(image.numpy() * 255).astype('uint8') for image in vis_data]
                image_list.extend(vis_data)

        # Dump frames as PNGs (RGB -> BGR for OpenCV), then mux them into an mp4.
        num_image = len(image_list)
        frame_dir = os.path.join(save_dir, 'temp')
        os.makedirs(frame_dir, exist_ok=True)
        for fidx in range(num_image):
            tpth = os.path.join(frame_dir, '%04d.png' % (fidx + 1))
            cv2.imwrite(tpth, image_list[fidx][:, :, ::-1])
        cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate 8 -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}'
        os.system(cmd)
        os.system(f'rm -rf {frame_dir}')


if __name__ == '__main__':
    test_enc_dec()
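
# Usage sketch (an assumption, not fixed by this file): pConfig(load=True)
# parses command-line arguments in this codebase, with the exact flag names
# defined in utils/config.py. Assuming a '--cfg' flag pointing at a YAML file
# that defines auto_encoder, vid_dataset, resolution, mean/std, vit_* fields,
# and chunk_size, an invocation would look like:
#
#   python test_enc_dec.py --cfg configs/autoencoder_test.yaml
#
# Both the script name and the config path above are illustrative placeholders.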