import argparse
import os
import shutil
import time

import cv2
import torch
import torch.nn as nn
from joblib import load
from torchvision import models, transforms

from extractor.vf_extract import process_video_residual
from extractor.visualise_vit_layer import VitGenerator
from model_regression import Mlp  # preprocess_data is (re)defined locally below
from relax_vqa import (get_deep_feature, process_video_feature, process_patches,
                       get_frame_patches, flow_to_rgb, merge_fragments,
                       concatenate_features)


def fix_state_dict(state_dict):
    """Strip the 'module.' prefix added by DataParallel/SWA wrappers and drop
    the SWA bookkeeping entry 'n_averaged' so the weights load into a plain model."""
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.'):
            name = k[7:]  # remove 'module.' prefix
        elif k == 'n_averaged':
            continue  # SWA counter, not a model parameter
        else:
            name = k
        new_state_dict[name] = v
    return new_state_dict
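
# Illustrative only: fix_state_dict({'module.fc.weight': w, 'n_averaged': t})
# returns {'fc.weight': w} -- the DataParallel prefix is stripped and the SWA
# counter dropped (w and t here are hypothetical tensors).
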
def preprocess_data(X, y=None, imp=None, scaler=None):
    """Zero out NaN/Inf entries, then apply the fitted sklearn imputer and
    scaler (if given) to the feature matrix. Returns the processed X and y."""
    if not isinstance(X, torch.Tensor):
        X = torch.tensor(X, device='cuda' if torch.cuda.is_available() else 'cpu')
    X = torch.where(torch.isnan(X) | torch.isinf(X), torch.tensor(0.0, device=X.device), X)
    if imp is not None or scaler is not None:
        X_np = X.cpu().numpy()
        if imp is not None:
            X_np = imp.transform(X_np)
        if scaler is not None:
            X_np = scaler.transform(X_np)
        X = torch.from_numpy(X_np).to(X.device)
    if y is not None:
        if not isinstance(y, torch.Tensor):
            y = torch.tensor(y, device=X.device)
        # flatten labels to 1-D; treat an empty array as "no labels"
        y = y.reshape(-1) if y.numel() > 0 else None
    return X, y, imp, scaler
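
# Typical inference-time call (sketch; `feats` is a hypothetical feature tensor,
# imputer/scaler are the fitted sklearn objects loaded with joblib below):
#   X, _, _, _ = preprocess_data(feats, None, imp=imputer, scaler=scaler)
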
def load_model(config, device, input_features=35203):
    """Build the MLP quality regressor and load either the fine-tuned or the
    LSVQ-pretrained checkpoint, depending on config['is_finetune']."""
    network_name = 'relaxvqa'
    # input_features = X_test_processed.shape[1]
    model = Mlp(input_features=input_features, out_features=1, drop_rate=0.2, act_layer=nn.GELU).to(device)
    if config['is_finetune']:
        model_path = os.path.join(config['save_path'], f"fine_tune_model/{config['video_type']}_{network_name}_{config['select_criteria']}_fine_tuned_model.pth")
    else:
        model_path = os.path.join(config['save_path'], f"{config['train_data_name']}_{network_name}_{config['select_criteria']}_trained_median_model_param_onLSVQ_TEST.pth")
    print("Loading model from:", model_path)
    state_dict = torch.load(model_path, map_location=device)
    fixed_state_dict = fix_state_dict(state_dict)
    try:
        model.load_state_dict(fixed_state_dict)
    except RuntimeError as e:
        # report key/shape mismatches but still return the model
        print(e)
    return model
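
# Checkpoint layout implied by the paths above (relative to config['save_path']):
#   fine-tuned:  fine_tune_model/<video_type>_relaxvqa_<select_criteria>_fine_tuned_model.pth
#   pretrained:  <train_data_name>_relaxvqa_<select_criteria>_trained_median_model_param_onLSVQ_TEST.pth
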
def evaluate_video_quality(config, resnet50, vit, model_mlp, device):
    """Extract ReLaX-VQA features for one video and return the predicted
    quality score together with the wall-clock run time."""
    is_finetune = config['is_finetune']
    save_path = config['save_path']
    video_type = config['video_type']
    video_name = config['video_name']
    framerate = config['framerate']
    sampled_fragment_path = os.path.join("../video_sampled_frame/sampled_frame/", "test_sampled_fragment")
    video_path = config.get("video_path")
    if video_path is None:
        if video_type == 'youtube_ugc':
            video_path = f'./ugc_original_videos/{video_name}.mkv'
        else:
            video_path = f'./ugc_original_videos/{video_name}.mp4'
    target_size = 224
    patch_size = 16
    top_n = int((target_size / patch_size) * (target_size / patch_size))
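    # With target_size=224 and patch_size=16 the frame forms a 14x14 patch grid,
    # so top_n = 196, i.e. every patch position is a candidate fragment.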
    # sample video frames
    start_time = time.time()
    frames, frames_next = process_video_residual(video_type, video_name, framerate, video_path, sampled_fragment_path)
    # ResNet50 layer-stack features and ViT pooling features of sampled frames
    all_frame_activations_resnet = []
    all_frame_activations_vit = []
    # ResNet50 and ViT features of sampled / merged fragments
    all_frame_activations_sampled_resnet = []
    all_frame_activations_merged_resnet = []
    all_frame_activations_sampled_vit = []
    all_frame_activations_merged_vit = []
    batch_size = 64  # number of frames to process in parallel
    for i in range(0, len(frames_next), batch_size):
        batch_frames = frames[i:i + batch_size]
        batch_rgb_frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in batch_frames]
        batch_frames_next = frames_next[i:i + batch_size]
        batch_tensors = torch.stack([transforms.ToTensor()(frame) for frame in batch_frames]).to(device)
        batch_rgb_tensors = torch.stack([transforms.ToTensor()(frame_rgb) for frame_rgb in batch_rgb_frames]).to(device)
        batch_tensors_next = torch.stack([transforms.ToTensor()(frame_next) for frame_next in batch_frames_next]).to(device)
        # compute residuals between consecutive frames
        residuals = torch.abs(batch_tensors_next - batch_tensors)
        # calculate dense optical flow on grayscale frames
        batch_gray_frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) for frame in batch_frames]
        batch_gray_frames_next = [cv2.cvtColor(frame_next, cv2.COLOR_BGR2GRAY) for frame_next in batch_frames_next]
        batch_gray_frames = [frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame for frame in batch_gray_frames]
        batch_gray_frames_next = [frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame for frame in batch_gray_frames_next]
        flows = [cv2.calcOpticalFlowFarneback(batch_gray_frames[j], batch_gray_frames_next[j], None, 0.5, 3, 15, 3, 5, 1.2, 0) for j in range(len(batch_gray_frames))]
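        # Farneback positional arguments above: pyr_scale=0.5, levels=3,
        # winsize=15, iterations=3, poly_n=5, poly_sigma=1.2, flags=0.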
        for j in range(batch_tensors.size(0)):
            '''sampled video frames'''
            frame_tensor = batch_tensors[j].unsqueeze(0)
            frame_rgb_tensor = batch_rgb_tensors[j].unsqueeze(0)
            frame_number = i + j + 1
            # ResNet50 layer-stack features
            activations_dict_resnet, _, _ = get_deep_feature('resnet50', video_name, frame_rgb_tensor, frame_number, resnet50, device, 'layerstack')
            all_frame_activations_resnet.append(activations_dict_resnet)
            # ViT pooling features
            activations_dict_vit, _, _ = get_deep_feature('vit', video_name, frame_rgb_tensor, frame_number, vit, device, 'pool')
            all_frame_activations_vit.append(activations_dict_vit)

            '''residual video frames'''
            residual = residuals[j].unsqueeze(0)
            flow = flows[j]
            original_path = os.path.join(sampled_fragment_path, f'{video_name}_{frame_number}.png')
            # frame differencing
            residual_frag_path, diff_frag, positions = process_patches(original_path, 'frame_diff', residual, patch_size, target_size, top_n)
            # frame fragments at the same patch positions
            frame_patches = get_frame_patches(frame_tensor, positions, patch_size, target_size)
            # optical flow
            opticalflow_rgb = flow_to_rgb(flow)
            opticalflow_rgb_tensor = transforms.ToTensor()(opticalflow_rgb).unsqueeze(0).to(device)
            opticalflow_frag_path, flow_frag, _ = process_patches(original_path, 'optical_flow', opticalflow_rgb_tensor, patch_size, target_size, top_n)
            merged_frag = merge_fragments(diff_frag, flow_frag)
            # fragment ResNet50 features
            sampled_frag_activations_resnet, _, _ = get_deep_feature('resnet50', video_name, frame_patches, frame_number, resnet50, device, 'layerstack')
            merged_frag_activations_resnet, _, _ = get_deep_feature('resnet50', video_name, merged_frag, frame_number, resnet50, device, 'pool')
            all_frame_activations_sampled_resnet.append(sampled_frag_activations_resnet)
            all_frame_activations_merged_resnet.append(merged_frag_activations_resnet)
            # fragment ViT features
            sampled_frag_activations_vit, _, _ = get_deep_feature('vit', video_name, frame_patches, frame_number, vit, device, 'pool')
            merged_frag_activations_vit, _, _ = get_deep_feature('vit', video_name, merged_frag, frame_number, vit, device, 'pool')
            all_frame_activations_sampled_vit.append(sampled_frag_activations_vit)
            all_frame_activations_merged_vit.append(merged_frag_activations_vit)
    print(f'video frame number: {len(all_frame_activations_resnet)}')
    # aggregate per-frame activations into video-level features
    averaged_frames_resnet = process_video_feature(all_frame_activations_resnet, 'resnet50', 'layerstack')
    averaged_frames_vit = process_video_feature(all_frame_activations_vit, 'vit', 'pool')
    averaged_frames_sampled_resnet = process_video_feature(all_frame_activations_sampled_resnet, 'resnet50', 'layerstack')
    averaged_frames_merged_resnet = process_video_feature(all_frame_activations_merged_resnet, 'resnet50', 'pool')
    averaged_combined_feature_resnet = concatenate_features(averaged_frames_sampled_resnet, averaged_frames_merged_resnet)
    averaged_frames_sampled_vit = process_video_feature(all_frame_activations_sampled_vit, 'vit', 'pool')
    averaged_frames_merged_vit = process_video_feature(all_frame_activations_merged_vit, 'vit', 'pool')
    averaged_combined_feature_vit = concatenate_features(averaged_frames_sampled_vit, averaged_frames_merged_vit)

    # remove tmp folders
    shutil.rmtree(sampled_fragment_path)

    # concatenate the four feature groups into one video-level vector
    combined_features = torch.cat([torch.mean(averaged_frames_resnet, dim=0), torch.mean(averaged_frames_vit, dim=0),
                                   torch.mean(averaged_combined_feature_resnet, dim=0), torch.mean(averaged_combined_feature_vit, dim=0)], dim=0).view(1, -1)
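    # The concatenated vector is expected to have 35203 dimensions here,
    # matching the input_features default in load_model above.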
    # apply the fitted imputer/scaler, then regress the quality score
    imputer = load(f'{save_path}/scaler/{video_type}_imputer.pkl')
    scaler = load(f'{save_path}/scaler/{video_type}_scaler.pkl')
    X_test_processed, _, _, _ = preprocess_data(combined_features, None, imp=imputer, scaler=scaler)
    feature_tensor = X_test_processed

    # evaluation for the test video
    model_mlp.eval()
    with torch.no_grad():
        with torch.cuda.amp.autocast():
            prediction = model_mlp(feature_tensor)
        predicted_score = prediction.item()
    run_time = time.time() - start_time
    if not is_finetune:
        if video_type in ['konvid_1k', 'youtube_ugc']:
            # linearly map the LSVQ-scale score from [1, 100] to the MOS range [1, 5]
            scaled_prediction = ((predicted_score - 1) / (99 / 4)) + 1.0
            return scaled_prediction, run_time
        else:
            scaled_prediction = predicted_score
            return scaled_prediction, run_time
    else:
        return predicted_score, run_time


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('-device', type=str, default='gpu', help='cpu or gpu')
    parser.add_argument('-model_name', type=str, default='Mlp', help='Name of the regression model')
    parser.add_argument('-select_criteria', type=str, default='byrmse', help='Selection criteria')
    parser.add_argument('-train_data_name', type=str, default='lsvq_train', help='Name of the training data')
    # note: argparse's type=bool treats any non-empty string as True, so a
    # store_true flag is used instead
    parser.add_argument('-is_finetune', action='store_true', help='With or without finetune')
    parser.add_argument('-save_path', type=str, default='model/', help='Path to save models')
    parser.add_argument('-video_type', type=str, default='konvid_1k', help='Type of video')
    parser.add_argument('-video_name', type=str, default='5636101558_540p', help='Name of the video')
    parser.add_argument('-framerate', type=float, default=24, help='Frame rate of the video')
    args = parser.parse_args()
    return args
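
# Example invocation (sketch; assumes this file is saved as relax_vqa_demo.py
# and the checkpoints/scalers exist under model/):
#   python relax_vqa_demo.py -device gpu -video_type konvid_1k \
#       -video_name 5636101558_540p -framerate 24
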
if __name__ == '__main__':
    args = parse_arguments()
    config = vars(args)
    if config['device'] == "gpu":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cpu")
    print(f"Running on {'GPU' if device.type == 'cuda' else 'CPU'}")

    # load feature extractors and the regression head onto the device
    resnet50 = models.resnet50(pretrained=True).to(device)
    vit = VitGenerator('vit_base', 16, device, evaluate=True, random=False, verbose=True)
    model_mlp = load_model(config, device)

    total_time = 0
    num_runs = 1
    for i in range(num_runs):
        quality_prediction, run_time = evaluate_video_quality(config, resnet50, vit, model_mlp, device)
        print(f"Run {i + 1} - Time taken: {run_time:.4f} seconds")
        total_time += run_time
    average_time = total_time / num_runs
    print(f"Average running time over {num_runs} runs: {average_time:.4f} seconds")
    print("Predicted Quality Score:", quality_prediction)