import argparse
import time
import math
import os
import shutil
from joblib import load
import cv2
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from thop import profile
from torchvision import models, transforms

from extractor.visualise_vit_layer import VitGenerator
from relax_vqa import get_deep_feature, process_video_feature, process_patches, get_frame_patches, flow_to_rgb, merge_fragments, concatenate_features
from extractor.vf_extract import process_video_residual
from model_regression import Mlp  # preprocess_data is redefined locally below, so only Mlp is imported

def fix_state_dict(state_dict):
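    """Strip the 'module.' prefix left by nn.DataParallel / DistributedDataParallel
    checkpoints and drop the 'n_averaged' buffer added by SWA's AveragedModel."""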
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.'):
            name = k[7:]
        elif k == 'n_averaged':
            continue
        else:
            name = k
        new_state_dict[name] = v
    return new_state_dict

def preprocess_data(X, y=None, imp=None, scaler=None):
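    """Convert features to a tensor, replace NaN/Inf entries with zeros, and
    optionally apply a fitted imputer and scaler (both operate on NumPy arrays)."""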
    if not isinstance(X, torch.Tensor):
        X = torch.tensor(X, device='cuda' if torch.cuda.is_available() else 'cpu')
    X = torch.where(torch.isnan(X) | torch.isinf(X), torch.tensor(0.0, device=X.device), X)

    if imp is not None or scaler is not None:
        X_np = X.cpu().numpy()
        if imp is not None:
            X_np = imp.transform(X_np)
        if scaler is not None:
            X_np = scaler.transform(X_np)
        X = torch.from_numpy(X_np).to(X.device)

    if y is not None:
        if not isinstance(y, torch.Tensor):
            y = torch.tensor(y, device=X.device)
        # flatten labels to 1D; treat an empty array as "no labels"
        y = y.reshape(-1) if y.numel() > 0 else None
    else:
        y = None

    return X, y, imp, scaler

def load_model(config, device, input_features=35203):
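    """Build the MLP quality regressor and load either the fine-tuned or the
    pre-trained checkpoint selected by the config."""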
    network_name = 'relaxvqa'
    # input_features = X_test_processed.shape[1]
    model = Mlp(input_features=input_features, out_features=1, drop_rate=0.2, act_layer=nn.GELU).to(device)
    if config['is_finetune']:
        model_path = os.path.join(config['save_path'], f"fine_tune_model/{config['video_type']}_{network_name}_{config['select_criteria']}_fine_tuned_model.pth")
    else:
        model_path = os.path.join(config['save_path'], f"{config['train_data_name']}_{network_name}_{config['select_criteria']}_trained_median_model_param_onLSVQ_TEST.pth")
    print("Loading model from:", model_path)
    state_dict = torch.load(model_path, map_location=device)
    fixed_state_dict = fix_state_dict(state_dict)
    try:
        model.load_state_dict(fixed_state_dict)
    except RuntimeError as e:
        print(e)
    return model

def evaluate_video_quality(config, resnet50, vit, model_mlp, device):
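    """Extract ResNet-50 and ViT features from sampled frames, residual and
    optical-flow fragments, pool them into a single video-level feature vector,
    and regress a quality score with the trained MLP."""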
    is_finetune = config['is_finetune']
    save_path = config['save_path']
    video_type = config['video_type']
    video_name = config['video_name']
    framerate = config['framerate']
    sampled_fragment_path = os.path.join("../video_sampled_frame/sampled_frame/", "test_sampled_fragment")

    video_path = config.get("video_path")
    if video_path is None:
        if video_type == 'youtube_ugc':
            video_path = f'./ugc_original_videos/{video_name}.mkv'
        else:
            video_path = f'./ugc_original_videos/{video_name}.mp4'
    target_size = 224
    patch_size = 16
    top_n = int((target_size / patch_size) * (target_size / patch_size))

    # sampled video frames
    start_time = time.time()
    frames, frames_next = process_video_residual(video_type, video_name, framerate, video_path, sampled_fragment_path)

    # get ResNet50 layer-stack features and ViT pooling features
    all_frame_activations_resnet = []
    all_frame_activations_vit = []
    # get fragments ResNet50 features and ViT features
    all_frame_activations_sampled_resnet = []
    all_frame_activations_merged_resnet = []
    all_frame_activations_sampled_vit = []
    all_frame_activations_merged_vit = []

    batch_size = 64  # number of frames whose tensors are prepared per batch (feature extraction below runs per frame)
    for i in range(0, len(frames_next), batch_size):
        batch_frames = frames[i:i + batch_size]
        batch_rgb_frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in batch_frames]
        batch_frames_next = frames_next[i:i + batch_size]
        batch_tensors = torch.stack([transforms.ToTensor()(frame) for frame in batch_frames]).to(device)
        batch_rgb_tensors = torch.stack([transforms.ToTensor()(frame_rgb) for frame_rgb in batch_rgb_frames]).to(device)
        batch_tensors_next = torch.stack([transforms.ToTensor()(frame_next) for frame_next in batch_frames_next]).to(device)

        # compute residuals
        residuals = torch.abs(batch_tensors_next - batch_tensors)

        # calculate optical flows
        batch_gray_frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) for frame in batch_frames]
        batch_gray_frames_next = [cv2.cvtColor(frame_next, cv2.COLOR_BGR2GRAY) for frame_next in batch_frames_next]
        batch_gray_frames = [frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame for frame in batch_gray_frames]
        batch_gray_frames_next = [frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame for frame in batch_gray_frames_next]
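        # Farneback optical flow arguments: pyr_scale=0.5, levels=3, winsize=15, iterations=3, poly_n=5, poly_sigma=1.2, flags=0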
        flows = [cv2.calcOpticalFlowFarneback(batch_gray_frames[j], batch_gray_frames_next[j], None, 0.5, 3, 15, 3, 5, 1.2, 0)
                 for j in range(len(batch_gray_frames))]

        for j in range(batch_tensors.size(0)):
            '''sampled video frames'''
            frame_tensor = batch_tensors[j].unsqueeze(0)
            frame_rgb_tensor = batch_rgb_tensors[j].unsqueeze(0)
            # frame_next_tensor = batch_tensors_next[j].unsqueeze(0)
            frame_number = i + j + 1

            # ResNet50 layer-stack features
            activations_dict_resnet, _, _ = get_deep_feature('resnet50', video_name, frame_rgb_tensor, frame_number, resnet50, device, 'layerstack')
            all_frame_activations_resnet.append(activations_dict_resnet)
            # ViT pooling features
            activations_dict_vit, _, _ = get_deep_feature('vit', video_name, frame_rgb_tensor, frame_number, vit, device, 'pool')
            all_frame_activations_vit.append(activations_dict_vit)


            '''residual video frames'''
            residual = residuals[j].unsqueeze(0)
            flow = flows[j]
            original_path = os.path.join(sampled_fragment_path, f'{video_name}_{frame_number}.png')

            # Frame Differencing
            residual_frag_path, diff_frag, positions = process_patches(original_path, 'frame_diff', residual, patch_size, target_size, top_n)
            # Frame fragment
            frame_patches = get_frame_patches(frame_tensor, positions, patch_size, target_size)
            # Optical Flow
            opticalflow_rgb = flow_to_rgb(flow)
            opticalflow_rgb_tensor = transforms.ToTensor()(opticalflow_rgb).unsqueeze(0).to(device)
            opticalflow_frag_path, flow_frag, _ = process_patches(original_path, 'optical_flow', opticalflow_rgb_tensor, patch_size, target_size, top_n)

            merged_frag = merge_fragments(diff_frag, flow_frag)

            # fragments ResNet50 features
            sampled_frag_activations_resnet, _, _ = get_deep_feature('resnet50', video_name, frame_patches, frame_number, resnet50, device, 'layerstack')
            merged_frag_activations_resnet, _, _ = get_deep_feature('resnet50', video_name, merged_frag, frame_number, resnet50, device, 'pool')
            all_frame_activations_sampled_resnet.append(sampled_frag_activations_resnet)
            all_frame_activations_merged_resnet.append(merged_frag_activations_resnet)
            # fragments ViT features
            sampled_frag_activations_vit, _, _ = get_deep_feature('vit', video_name, frame_patches, frame_number, vit, device, 'pool')
            merged_frag_activations_vit, _, _ = get_deep_feature('vit', video_name, merged_frag, frame_number, vit, device, 'pool')
            all_frame_activations_sampled_vit.append(sampled_frag_activations_vit)
            all_frame_activations_merged_vit.append(merged_frag_activations_vit)

    print(f'video frame number: {len(all_frame_activations_resnet)}')
    averaged_frames_resnet = process_video_feature(all_frame_activations_resnet, 'resnet50', 'layerstack')
    averaged_frames_vit = process_video_feature(all_frame_activations_vit, 'vit', 'pool')
    # print("ResNet50 layer-stacking feature shape:", averaged_frames_resnet.shape)
    # print("ViT pooling feature shape:", averaged_frames_vit.shape)
    averaged_frames_sampled_resnet = process_video_feature(all_frame_activations_sampled_resnet, 'resnet50', 'layerstack')
    averaged_frames_merged_resnet = process_video_feature(all_frame_activations_merged_resnet, 'resnet50', 'pool')
    averaged_combined_feature_resnet = concatenate_features(averaged_frames_sampled_resnet, averaged_frames_merged_resnet)
    # print("Sampled fragments ResNet50 features shape:", averaged_frames_sampled_resnet.shape)
    # print("Merged fragments ResNet50 features shape:", averaged_frames_merged_resnet.shape)
    averaged_frames_sampled_vit = process_video_feature(all_frame_activations_sampled_vit, 'vit', 'pool')
    averaged_frames_merged_vit = process_video_feature(all_frame_activations_merged_vit, 'vit', 'pool')
    averaged_combined_feature_vit = concatenate_features(averaged_frames_sampled_vit, averaged_frames_merged_vit)
    # print("Sampled fragments ViT features shape:", averaged_frames_sampled_vit.shape)
    # print("Merged fragments ViT features shape:", averaged_frames_merged_vit.shape)

    # remove tmp folders
    shutil.rmtree(sampled_fragment_path)

    # temporally pool each feature group and concatenate into a single video-level feature vector
    combined_features = torch.cat([torch.mean(averaged_frames_resnet, dim=0), torch.mean(averaged_frames_vit, dim=0),
                                   torch.mean(averaged_combined_feature_resnet, dim=0), torch.mean(averaged_combined_feature_vit, dim=0)], dim=0).view(1, -1)
    imputer = load(f'{save_path}/scaler/{video_type}_imputer.pkl')
    scaler = load(f'{save_path}/scaler/{video_type}_scaler.pkl')
    X_test_processed, _, _, _ = preprocess_data(combined_features, None, imp=imputer, scaler=scaler)
    feature_tensor = X_test_processed

    # evaluation for test video
    model_mlp.eval()

    with torch.no_grad():
        # mixed precision only helps on GPU; autocast is disabled when running on CPU
        with torch.cuda.amp.autocast(enabled=(device.type == 'cuda')):
            prediction = model_mlp(feature_tensor)
            predicted_score = prediction.item()
            # print(f"Raw Predicted Quality Score: {predicted_score}")
            run_time = time.time() - start_time

            if not is_finetune:
                if video_type in ['konvid_1k', 'youtube_ugc']:
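                    # linearly map the raw 1-100 prediction onto the 1-5 MOS scale (1 -> 1, 100 -> 5)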
                    scaled_prediction = ((predicted_score - 1) / (99 / 4)) + 1.0
                    # print(f"Scaled Predicted Quality Score (1-5): {scaled_prediction}")
                    return scaled_prediction, run_time
                else:
                    scaled_prediction = predicted_score
                    return scaled_prediction, run_time
            else:
                return predicted_score, run_time

def parse_arguments():
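    """Parse command-line options for the device, model selection and the test video."""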
    parser = argparse.ArgumentParser()
    parser.add_argument('-device', type=str, default='gpu', help='cpu or gpu')
    parser.add_argument('-model_name', type=str, default='Mlp', help='Name of the regression model')
    parser.add_argument('-select_criteria', type=str, default='byrmse', help='Selection criteria')
    parser.add_argument('-train_data_name', type=str, default='lsvq_train', help='Name of the training data')
    # argparse's type=bool treats any non-empty string as True, so expose this as a flag instead
    parser.add_argument('-is_finetune', action='store_true', help='Use the fine-tuned checkpoint instead of the pre-trained one')
    parser.add_argument('-save_path', type=str, default='model/', help='Path to save models')
    parser.add_argument('-video_type', type=str, default='konvid_1k', help='Type of video')
    parser.add_argument('-video_name', type=str, default='5636101558_540p', help='Name of the video')
    parser.add_argument('-framerate', type=float, default=24, help='Frame rate of the video')

    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = parse_arguments()
    config = vars(args)
    if config['device'] == "gpu":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cpu")
    print(f"Running on {'GPU' if device.type == 'cuda' else 'CPU'}")

    # load models to device
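    # note: torchvision >= 0.13 prefers models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1) over pretrained=True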
    resnet50 = models.resnet50(pretrained=True).to(device)
    vit = VitGenerator('vit_base', 16, device, evaluate=True, random=False, verbose=True)
    model_mlp = load_model(config, device)

    total_time = 0
    num_runs = 1
    for i in range(num_runs):
        quality_prediction, run_time = evaluate_video_quality(config, resnet50, vit, model_mlp, device)
        print(f"Run {i + 1} - Time taken: {run_time:.4f} seconds")

        total_time += run_time
    average_time = total_time / num_runs

    print(f"Average running time over {num_runs} runs: {average_time:.4f} seconds")
    print("Predicted Quality Score:", quality_prediction)
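
# Example invocation (script filename assumed; argument values are the defaults above):
#   python evaluate_one_video.py -device gpu -video_type konvid_1k -video_name 5636101558_540p -framerate 24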