# Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333 (2ba4412)
import os
import os.path as osp
import shutil
import sys

import imageio
import numpy as np
import torch
from einops import rearrange
from PIL import Image, ImageDraw, ImageFont

sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))

from tools import *
import utils.transforms as data
from utils.seed import setup_seed
from tools.modules.config import cfg
from utils.config import Config as pConfig
from utils.registry_class import ENGINE, DATASETS, AUTO_ENCODER
def test_enc_dec(gpu=0):
    """Round-trip dataset videos through the autoencoder and save comparisons.

    For each video in the configured dataset, encodes it chunk-wise to the
    latent space, decodes it back, stacks original and reconstruction along
    the height axis, and writes the result as an mp4 under
    ``workspace/test_data/autoencoder/<auto_encoder type>/%04d.mp4``.

    Args:
        gpu: torch device index the autoencoder and video tensors are moved to.
    """
    setup_seed(0)

    # Merge the loaded config overrides into the global cfg (dict values are
    # merged key-wise, everything else is replaced wholesale).
    cfg_update = pConfig(load=True)
    for k, v in cfg_update.cfg_dict.items():
        if isinstance(v, dict) and k in cfg:
            cfg[k].update(v)
        else:
            cfg[k] = v

    save_dir = os.path.join('workspace/test_data/autoencoder', cfg.auto_encoder['type'])
    # Start from a clean output directory; shutil.rmtree is portable and
    # safer than shelling out to `rm -rf`.
    shutil.rmtree(save_dir, ignore_errors=True)
    os.makedirs(save_dir, exist_ok=True)

    train_trans = data.Compose([
        data.CenterCropWide(size=cfg.resolution),
        data.ToTensor(),
        data.Normalize(mean=cfg.mean, std=cfg.std)])
    vit_trans = data.Compose([
        data.CenterCropWide(size=(cfg.resolution[0], cfg.resolution[0])) if cfg.resolution[0] > cfg.vit_resolution[0] else data.CenterCropWide(size=cfg.vit_resolution),
        data.Resize(cfg.vit_resolution),
        data.ToTensor(),
        data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])

    # Broadcastable (1, C, 1, 1) stats for de-normalizing frames back to [0, 1].
    video_mean = torch.tensor(cfg.mean).view(1, -1, 1, 1)
    video_std = torch.tensor(cfg.std).view(1, -1, 1, 1)

    dataset = DATASETS.build(cfg.vid_dataset, sample_fps=4, transforms=train_trans, vit_transforms=vit_trans)
    print('There are %d videos' % (len(dataset)))

    autoencoder = AUTO_ENCODER.build(cfg.auto_encoder)
    autoencoder.eval()  # freeze: inference only
    for param in autoencoder.parameters():
        param.requires_grad = False
    autoencoder.to(gpu)

    for idx, item in enumerate(dataset):
        local_path = os.path.join(save_dir, '%04d.mp4' % idx)
        # Items also carry a caption beyond these three fields in some datasets.
        ref_frame, vit_frame, video_data = item[:3]
        video_data = video_data.to(gpu)

        image_list = []
        # Encode/decode in chunks of cfg.chunk_size frames to bound memory.
        video_data_list = torch.chunk(video_data, video_data.shape[0] // cfg.chunk_size, dim=0)
        with torch.no_grad():
            for chunk_data in video_data_list:
                # NOTE: `encode_firsr_stage` is the project's (misspelled) API name.
                latent_z = autoencoder.encode_firsr_stage(chunk_data).detach()
                kwargs = {"timesteps": chunk_data.shape[0]}
                recons_data = autoencoder.decode(latent_z, **kwargs)

                # Stack original over reconstruction along the height axis,
                # then de-normalize to [0, 1].
                vis_data = torch.cat([chunk_data, recons_data], dim=2).cpu()
                vis_data = vis_data.mul_(video_std).add_(video_mean)
                vis_data.clamp_(0, 1)
                vis_data = vis_data.permute(0, 2, 3, 1)  # NCHW -> NHWC
                image_list.extend((image.numpy() * 255).astype('uint8') for image in vis_data)

        # Dump frames as PNGs, then mux them into an mp4 with ffmpeg.
        frame_dir = os.path.join(save_dir, 'temp')
        os.makedirs(frame_dir, exist_ok=True)
        # Use a dedicated loop variable so the outer `idx` is not clobbered.
        for frame_idx, image in enumerate(image_list, start=1):
            tpth = os.path.join(frame_dir, '%04d.png' % frame_idx)
            # RGB -> BGR for cv2. (The JPEG-quality flag the original passed
            # was a no-op for PNG output and has been dropped.)
            cv2.imwrite(tpth, image[:, :, ::-1])
        cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate 8 -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}'
        os.system(cmd)
        shutil.rmtree(frame_dir, ignore_errors=True)
# Script entry point: run the autoencoder round-trip check with defaults.
if __name__ == '__main__':
    test_enc_dec()