File size: 4,445 Bytes
2ba4412
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import cv2
import json
import torch
import random
import logging
import tempfile
import numpy as np
from copy import copy
from PIL import Image
from torch.utils.data import Dataset
from utils.registry_class import DATASETS


@DATASETS.register_class()
class VideoDataset(Dataset):
    def __init__(self, 
            data_list,
            data_dir_list,
            max_words=1000,
            resolution=(384, 256),
            vit_resolution=(224, 224),
            max_frames=16,
            sample_fps=8,
            transforms=None,
            vit_transforms=None,
            get_first_frame=False,
            **kwargs):

        self.max_words = max_words
        self.max_frames = max_frames
        self.resolution = resolution
        self.vit_resolution = vit_resolution
        self.sample_fps = sample_fps
        self.transforms = transforms
        self.vit_transforms = vit_transforms
        self.get_first_frame = get_first_frame
        
        image_list = []
        for item_path, data_dir in zip(data_list, data_dir_list):
            lines = open(item_path, 'r').readlines()
            lines = [[data_dir, item] for item in lines]
            image_list.extend(lines)
        self.image_list = image_list


    def __getitem__(self, index):
        data_dir, file_path = self.image_list[index]
        video_key = file_path.split('|||')[0]
        try:
            ref_frame, vit_frame, video_data, caption = self._get_video_data(data_dir, file_path)
        except Exception as e:
            logging.info('{} get frames failed... with error: {}'.format(video_key, e))
            caption = ''
            video_key = '' 
            ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0])
            vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
            video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])        
        return ref_frame, vit_frame, video_data, caption, video_key
    
    
    def _get_video_data(self, data_dir, file_path):
        video_key, caption = file_path.split('|||')
        file_path = os.path.join(data_dir, video_key)

        for _ in range(5):
            try:
                capture = cv2.VideoCapture(file_path)
                _fps = capture.get(cv2.CAP_PROP_FPS)
                _total_frame_num = capture.get(cv2.CAP_PROP_FRAME_COUNT)
                stride = round(_fps / self.sample_fps)
                cover_frame_num = (stride * self.max_frames)
                if _total_frame_num < cover_frame_num + 5:
                    start_frame = 0
                    end_frame = _total_frame_num
                else:
                    start_frame = random.randint(0, _total_frame_num-cover_frame_num-5)
                    end_frame = start_frame + cover_frame_num
                
                pointer, frame_list = 0, []
                while(True):
                    ret, frame = capture.read()
                    pointer +=1 
                    if (not ret) or (frame is None): break
                    if pointer < start_frame: continue
                    if pointer >= end_frame - 1: break
                    if (pointer - start_frame) % stride == 0:
                        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        frame = Image.fromarray(frame)
                        frame_list.append(frame)
                break
            except Exception as e:
                logging.info('{} read video frame failed with error: {}'.format(video_key, e))
                continue

        video_data = torch.zeros(self.max_frames, 3,  self.resolution[1], self.resolution[0])
        if self.get_first_frame:
            ref_idx = 0
        else:
            ref_idx = int(len(frame_list)/2)
        try:
            if len(frame_list)>0:
                mid_frame = copy(frame_list[ref_idx])
                vit_frame = self.vit_transforms(mid_frame)
                frames = self.transforms(frame_list)
                video_data[:len(frame_list), ...] = frames
            else:
                vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
        except:
            vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
        ref_frame = copy(frames[ref_idx])
        
        return ref_frame, vit_frame, video_data, caption

    def __len__(self):
        return len(self.image_list)