# Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
import os
import cv2
import json
import torch
import random
import logging
import tempfile
import numpy as np
from copy import copy
from PIL import Image
from torch.utils.data import Dataset

from utils.registry_class import DATASETS
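
# Each entry in `data_list` is a plain-text file whose lines pair a video path
# (relative to the matching entry in `data_dir_list`) with a caption, separated
# by '|||'. The concrete filename and caption below are illustrative only:
#   clips/dance_001.mp4|||a person dancing in a studio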
class VideoDataset(Dataset):
    """Video-caption dataset that samples fixed-length clips from raw videos.

    Each item yields a reference frame, a ViT-resolution frame, a zero-padded
    clip tensor of shape (max_frames, 3, resolution[1], resolution[0]), the
    caption, and the video key.
    """

    def __init__(self,
                 data_list,
                 data_dir_list,
                 max_words=1000,
                 resolution=(384, 256),
                 vit_resolution=(224, 224),
                 max_frames=16,
                 sample_fps=8,
                 transforms=None,
                 vit_transforms=None,
                 get_first_frame=False,
                 **kwargs):
        self.max_words = max_words
        self.max_frames = max_frames
        self.resolution = resolution
        self.vit_resolution = vit_resolution
        self.sample_fps = sample_fps
        self.transforms = transforms
        self.vit_transforms = vit_transforms
        self.get_first_frame = get_first_frame

        # Build a flat index of (data_dir, "video_key|||caption") pairs from
        # every list file, so one dataset can span multiple data roots.
        image_list = []
        for item_path, data_dir in zip(data_list, data_dir_list):
            with open(item_path, 'r') as f:
                lines = [line.strip() for line in f if line.strip()]
            image_list.extend([[data_dir, item] for item in lines])
        self.image_list = image_list
    def __getitem__(self, index):
        data_dir, file_path = self.image_list[index]
        video_key = file_path.split('|||')[0]
        try:
            ref_frame, vit_frame, video_data, caption = self._get_video_data(data_dir, file_path)
        except Exception as e:
            logging.info('{} get frames failed... with error: {}'.format(video_key, e))
            # Fall back to empty placeholders so one bad sample never kills a worker.
            caption = ''
            video_key = ''
            ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0])
            vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
            video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
        return ref_frame, vit_frame, video_data, caption, video_key
    def _get_video_data(self, data_dir, file_path):
        video_key, caption = file_path.split('|||')
        file_path = os.path.join(data_dir, video_key)

        frame_list = []
        for _ in range(5):  # retry a few times on flaky decodes
            try:
                capture = cv2.VideoCapture(file_path)
                _fps = capture.get(cv2.CAP_PROP_FPS)
                _total_frame_num = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
                # Keep every `stride`-th frame so the clip is sampled at roughly
                # `sample_fps` frames per second (e.g. 30 fps / 8 -> stride 4).
                stride = max(1, round(_fps / self.sample_fps))
                cover_frame_num = stride * self.max_frames
                if _total_frame_num < cover_frame_num + 5:
                    start_frame = 0
                    end_frame = _total_frame_num
                else:
                    # Pick a random window of `cover_frame_num` frames.
                    start_frame = random.randint(0, _total_frame_num - cover_frame_num - 5)
                    end_frame = start_frame + cover_frame_num
                pointer, frame_list = 0, []
                while True:
                    ret, frame = capture.read()
                    pointer += 1
                    if (not ret) or (frame is None):
                        break
                    if pointer < start_frame:
                        continue
                    if pointer >= end_frame - 1:
                        break
                    if (pointer - start_frame) % stride == 0:
                        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        frame_list.append(Image.fromarray(frame))
                capture.release()
                break
            except Exception as e:
                logging.info('{} read video frame failed with error: {}'.format(video_key, e))
                continue
        # Zero-padded clip tensor; any missing trailing frames stay zero.
        video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
        # Reference frame: the first sampled frame, or the middle one by default.
        if self.get_first_frame:
            ref_idx = 0
        else:
            ref_idx = len(frame_list) // 2
        try:
            if len(frame_list) > 0:
                mid_frame = copy(frame_list[ref_idx])
                vit_frame = self.vit_transforms(mid_frame)
                frames = self.transforms(frame_list)
                video_data[:len(frame_list), ...] = frames
                ref_frame = copy(frames[ref_idx])
            else:
                vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
                ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0])
        except Exception:
            vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
            ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0])
        return ref_frame, vit_frame, video_data, caption
    def __len__(self):
        return len(self.image_list)
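

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes list
# files with "relative/path.mp4|||caption" lines and stands in plain
# torchvision transforms for whatever the real training config passes as
# `transforms` / `vit_transforms`; the file paths below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    import torchvision.transforms as T

    # Per-frame transform at the clip resolution (H=256, W=384 for the
    # default resolution=(384, 256)), plus a ViT-resolution transform.
    frame_tf = T.Compose([T.Resize((256, 384)), T.ToTensor()])
    vit_tf = T.Compose([T.Resize((224, 224)), T.ToTensor()])

    def clip_transforms(frames):
        # The dataset hands the clip transform a list of PIL frames and
        # expects a (num_frames, 3, H, W) tensor back.
        return torch.stack([frame_tf(f) for f in frames])

    dataset = VideoDataset(
        data_list=['data/video_list.txt'],   # placeholder list file
        data_dir_list=['data/videos'],       # placeholder video root
        max_frames=16,
        sample_fps=8,
        transforms=clip_transforms,
        vit_transforms=vit_tf,
    )
    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
    ref_frame, vit_frame, video_data, caption, video_key = next(iter(loader))
    print(ref_frame.shape, vit_frame.shape, video_data.shape)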