Spaces:
Running
Running
File size: 4,445 Bytes
2ba4412 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import os
import cv2
import json
import torch
import random
import logging
import tempfile
import numpy as np
from copy import copy
from PIL import Image
from torch.utils.data import Dataset
from utils.registry_class import DATASETS
@DATASETS.register_class()
class VideoDataset(Dataset):
    """Dataset yielding sampled video frames together with their captions.

    Every index file named in ``data_list`` is read line by line; each line is
    expected to look like ``<video path>|||<caption>``, where the video path is
    joined onto the matching entry of ``data_dir_list``. Lines are stored
    verbatim (including the trailing newline kept by file iteration), so the
    caption returned to the caller still carries that newline.

    Args:
        data_list: Paths of the index files.
        data_dir_list: Root directories, paired positionally with ``data_list``.
        max_words: Stored on the instance; not used inside this class.
        resolution: ``(width, height)`` used to size the video tensors.
        vit_resolution: ``(width, height)`` used to size the ViT frame tensor.
        max_frames: Number of frame slots in the returned video tensor.
        sample_fps: Target sampling rate used to derive the frame stride.
        transforms: Callable applied to the list of sampled PIL frames.
        vit_transforms: Callable applied to the single reference PIL frame.
        get_first_frame: Use the first sampled frame as reference instead of
            the middle one.
        **kwargs: Accepted and ignored.
    """

    # Number of times decoding of one video is attempted before giving up.
    _READ_ATTEMPTS = 5

    def __init__(self,
                 data_list,
                 data_dir_list,
                 max_words=1000,
                 resolution=(384, 256),
                 vit_resolution=(224, 224),
                 max_frames=16,
                 sample_fps=8,
                 transforms=None,
                 vit_transforms=None,
                 get_first_frame=False,
                 **kwargs):
        self.max_words = max_words
        self.max_frames = max_frames
        self.resolution = resolution
        self.vit_resolution = vit_resolution
        self.sample_fps = sample_fps
        self.transforms = transforms
        self.vit_transforms = vit_transforms
        self.get_first_frame = get_first_frame

        image_list = []
        for item_path, data_dir in zip(data_list, data_dir_list):
            # Context manager so the index file is closed deterministically.
            with open(item_path, 'r') as index_file:
                image_list.extend([data_dir, item] for item in index_file)
        self.image_list = image_list

    def __getitem__(self, index):
        """Return ``(ref_frame, vit_frame, video_data, caption, video_key)``.

        Any failure while loading the sample is logged and replaced by an
        all-zero placeholder whose ``caption`` and ``video_key`` are empty
        strings, so a single broken video does not abort iteration.
        """
        data_dir, file_path = self.image_list[index]
        video_key = file_path.split('|||')[0]
        try:
            ref_frame, vit_frame, video_data, caption = self._get_video_data(data_dir, file_path)
        except Exception as e:
            logging.info('{} get frames failed... with error: {}'.format(video_key, e))
            caption = ''
            video_key = ''
            ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0])
            vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
            video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
        return ref_frame, vit_frame, video_data, caption, video_key

    def _read_frames(self, file_path, video_key):
        """Decode up to ``max_frames`` RGB PIL frames from ``file_path``.

        A window of ``stride * max_frames`` source frames is chosen (at a
        random offset when the video is long enough, otherwise starting at
        frame 0) and every ``stride``-th frame inside it is kept. Decoding is
        attempted up to ``_READ_ATTEMPTS`` times; errors are logged. Returns
        whatever the last attempt collected, possibly an empty list.
        """
        frame_list = []
        for _ in range(self._READ_ATTEMPTS):
            capture = None
            frame_list = []
            try:
                capture = cv2.VideoCapture(file_path)
                fps = capture.get(cv2.CAP_PROP_FPS)
                # capture.get() returns a float; random.randint needs ints.
                total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
                # Clamp to 1 so a low-fps source cannot produce a zero stride
                # (which would make the modulo below divide by zero).
                stride = max(1, round(fps / self.sample_fps))
                cover_frame_num = stride * self.max_frames
                if total_frames < cover_frame_num + 5:
                    start_frame, end_frame = 0, total_frames
                else:
                    start_frame = random.randint(0, total_frames - cover_frame_num - 5)
                    end_frame = start_frame + cover_frame_num

                # frame_idx is the zero-based index of the frame just read.
                # Frames are read sequentially from the start of the stream.
                frame_idx = 0
                # The length cap guarantees the result fits the video tensor
                # even when a short clip holds more strided frames than slots.
                while len(frame_list) < self.max_frames:
                    ret, frame = capture.read()
                    if (not ret) or (frame is None):
                        break
                    if frame_idx >= end_frame:
                        break
                    if frame_idx >= start_frame and (frame_idx - start_frame) % stride == 0:
                        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        frame_list.append(Image.fromarray(frame))
                    frame_idx += 1
                break
            except Exception as e:
                logging.info('{} read video frame failed with error: {}'.format(video_key, e))
            finally:
                # Always free the decoder handle, on success and on failure.
                if capture is not None:
                    capture.release()
        return frame_list

    def _get_video_data(self, data_dir, file_path):
        """Load one sample described by an index line.

        Args:
            data_dir: Root directory the video path is relative to.
            file_path: Raw index line of the form ``<video path>|||<caption>``.

        Returns:
            ``(ref_frame, vit_frame, video_data, caption)``. ``video_data`` is
            a ``(max_frames, 3, resolution[1], resolution[0])`` tensor whose
            leading slots hold the transformed frames and whose remaining
            slots stay zero.

        Raises:
            ValueError: If the line does not contain exactly one ``|||``.
            RuntimeError: If no frame could be decoded from the video.
            Exception: Whatever the transform callables raise; the caller
                (``__getitem__``) turns any of these into a placeholder sample.
        """
        video_key, caption = file_path.split('|||')
        file_path = os.path.join(data_dir, video_key)

        frame_list = self._read_frames(file_path, video_key)
        if not frame_list:
            # Fail explicitly rather than through an undefined local below.
            raise RuntimeError('no frames decoded from {}'.format(file_path))

        ref_idx = 0 if self.get_first_frame else len(frame_list) // 2

        ref_image = copy(frame_list[ref_idx])
        vit_frame = self.vit_transforms(ref_image)
        # NOTE(review): assumes transforms returns a tensor shaped
        # (len(frame_list), 3, resolution[1], resolution[0]) - confirm.
        frames = self.transforms(frame_list)

        video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
        video_data[:len(frame_list), ...] = frames

        ref_frame = copy(frames[ref_idx])
        return ref_frame, vit_frame, video_data, caption

    def __len__(self):
        """Return the number of index lines collected from all index files."""
        return len(self.image_list)
|