# ReLaX-VQA / demo_test_gpu.py
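"""Demo script: predict the perceptual quality of a single UGC video with ReLaX-VQA, running on GPU when available (CPU fallback otherwise)."""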
import argparse
import time
import math
import os
import shutil
from joblib import load
import cv2
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from thop import profile
from torchvision import models, transforms
from extractor.visualise_vit_layer import VitGenerator
from relax_vqa import get_deep_feature, process_video_feature, process_patches, get_frame_patches, flow_to_rgb, merge_fragments, concatenate_features
from extractor.vf_extract import process_video_residual
from model_regression import Mlp  # note: preprocess_data is redefined below with a GPU-aware version
def fix_state_dict(state_dict):
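    """Strip the 'module.' prefix that DataParallel adds to checkpoint keys and drop the SWA 'n_averaged' entry."""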
new_state_dict = {}
for k, v in state_dict.items():
if k.startswith('module.'):
name = k[7:]
elif k == 'n_averaged':
continue
else:
name = k
new_state_dict[name] = v
return new_state_dict
def preprocess_data(X, y=None, imp=None, scaler=None):
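    """Replace NaN/Inf values with 0, apply the fitted imputer/scaler (if given), and return X (and optionally y) as tensors on the working device."""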
if not isinstance(X, torch.Tensor):
X = torch.tensor(X, device='cuda' if torch.cuda.is_available() else 'cpu')
X = torch.where(torch.isnan(X) | torch.isinf(X), torch.tensor(0.0, device=X.device), X)
if imp is not None or scaler is not None:
X_np = X.cpu().numpy()
if imp is not None:
X_np = imp.transform(X_np)
if scaler is not None:
X_np = scaler.transform(X_np)
X = torch.from_numpy(X_np).to(X.device)
    if y is not None and (y.numel() if isinstance(y, torch.Tensor) else y.size) > 0:
if not isinstance(y, torch.Tensor):
y = torch.tensor(y, device=X.device)
y = y.reshape(-1).squeeze()
else:
y = None
return X, y, imp, scaler
def load_model(config, device, input_features=35203):
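    """Build the MLP quality regressor and load either the fine-tuned checkpoint or the LSVQ-pretrained weights, depending on config['is_finetune']."""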
network_name = 'relaxvqa'
# input_features = X_test_processed.shape[1]
model = Mlp(input_features=input_features, out_features=1, drop_rate=0.2, act_layer=nn.GELU).to(device)
if config['is_finetune']:
model_path = os.path.join(config['save_path'], f"fine_tune_model/{config['video_type']}_{network_name}_{config['select_criteria']}_fine_tuned_model.pth")
else:
model_path = os.path.join(config['save_path'], f"{config['train_data_name']}_{network_name}_{config['select_criteria']}_trained_median_model_param_onLSVQ_TEST.pth")
print("Loading model from:", model_path)
state_dict = torch.load(model_path, map_location=device)
fixed_state_dict = fix_state_dict(state_dict)
try:
model.load_state_dict(fixed_state_dict)
except RuntimeError as e:
print(e)
return model
def evaluate_video_quality(config, resnet50, vit, model_mlp, device):
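    """Run the ReLaX-VQA pipeline on a single video: sample frames and residuals,
    extract ResNet-50 layer-stack and ViT pooling features from full frames and fragments,
    pool them over time, concatenate, normalise, and regress a quality score."""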
is_finetune = config['is_finetune']
save_path = config['save_path']
video_type = config['video_type']
video_name = config['video_name']
framerate = config['framerate']
sampled_fragment_path = os.path.join("../video_sampled_frame/sampled_frame/", "test_sampled_fragment")
video_path = config.get("video_path")
if video_path is None:
if video_type == 'youtube_ugc':
video_path = f'./ugc_original_videos/{video_name}.mkv'
else:
video_path = f'./ugc_original_videos/{video_name}.mp4'
target_size = 224
patch_size = 16
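    # number of 16x16 patches that tile a 224x224 frame (14 x 14 = 196)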
top_n = int((target_size / patch_size) * (target_size / patch_size))
# sampled video frames
start_time = time.time()
frames, frames_next = process_video_residual(video_type, video_name, framerate, video_path, sampled_fragment_path)
# get ResNet50 layer-stack features and ViT pooling features
all_frame_activations_resnet = []
all_frame_activations_vit = []
# get fragments ResNet50 features and ViT features
all_frame_activations_sampled_resnet = []
all_frame_activations_merged_resnet = []
all_frame_activations_sampled_vit = []
all_frame_activations_merged_vit = []
batch_size = 64 # Define the number of frames to process in parallel
for i in range(0, len(frames_next), batch_size):
batch_frames = frames[i:i + batch_size]
batch_rgb_frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in batch_frames]
batch_frames_next = frames_next[i:i + batch_size]
batch_tensors = torch.stack([transforms.ToTensor()(frame) for frame in batch_frames]).to(device)
batch_rgb_tensors = torch.stack([transforms.ToTensor()(frame_rgb) for frame_rgb in batch_rgb_frames]).to(device)
batch_tensors_next = torch.stack([transforms.ToTensor()(frame_next) for frame_next in batch_frames_next]).to(device)
# compute residuals
residuals = torch.abs(batch_tensors_next - batch_tensors)
# calculate optical flows
batch_gray_frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) for frame in batch_frames]
batch_gray_frames_next = [cv2.cvtColor(frame_next, cv2.COLOR_BGR2GRAY) for frame_next in batch_frames_next]
batch_gray_frames = [frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame for frame in batch_gray_frames]
batch_gray_frames_next = [frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame for frame in batch_gray_frames_next]
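        # Farneback dense optical flow (pyr_scale=0.5, levels=3, winsize=15, iterations=3, poly_n=5, poly_sigma=1.2, flags=0)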
        flows = [cv2.calcOpticalFlowFarneback(batch_gray_frames[j], batch_gray_frames_next[j], None, 0.5, 3, 15, 3, 5, 1.2, 0) for j in range(len(batch_gray_frames))]
for j in range(batch_tensors.size(0)):
            # sampled video frames
frame_tensor = batch_tensors[j].unsqueeze(0)
frame_rgb_tensor = batch_rgb_tensors[j].unsqueeze(0)
# frame_next_tensor = batch_tensors_next[j].unsqueeze(0)
frame_number = i + j + 1
# ResNet50 layer-stack features
activations_dict_resnet, _, _ = get_deep_feature('resnet50', video_name, frame_rgb_tensor, frame_number, resnet50, device, 'layerstack')
all_frame_activations_resnet.append(activations_dict_resnet)
# ViT pooling features
activations_dict_vit, _, _ = get_deep_feature('vit', video_name, frame_rgb_tensor, frame_number, vit, device, 'pool')
all_frame_activations_vit.append(activations_dict_vit)
            # residual video frames
residual = residuals[j].unsqueeze(0)
flow = flows[j]
original_path = os.path.join(sampled_fragment_path, f'{video_name}_{frame_number}.png')
# Frame Differencing
residual_frag_path, diff_frag, positions = process_patches(original_path, 'frame_diff', residual, patch_size, target_size, top_n)
# Frame fragment
frame_patches = get_frame_patches(frame_tensor, positions, patch_size, target_size)
# Optical Flow
opticalflow_rgb = flow_to_rgb(flow)
opticalflow_rgb_tensor = transforms.ToTensor()(opticalflow_rgb).unsqueeze(0).to(device)
opticalflow_frag_path, flow_frag, _ = process_patches(original_path, 'optical_flow', opticalflow_rgb_tensor, patch_size, target_size, top_n)
merged_frag = merge_fragments(diff_frag, flow_frag)
# fragments ResNet50 features
sampled_frag_activations_resnet, _, _ = get_deep_feature('resnet50', video_name, frame_patches, frame_number, resnet50, device, 'layerstack')
merged_frag_activations_resnet, _, _ = get_deep_feature('resnet50', video_name, merged_frag, frame_number, resnet50, device, 'pool')
all_frame_activations_sampled_resnet.append(sampled_frag_activations_resnet)
all_frame_activations_merged_resnet.append(merged_frag_activations_resnet)
# fragments ViT features
sampled_frag_activations_vit,_, _ = get_deep_feature('vit', video_name, frame_patches, frame_number, vit, device, 'pool')
merged_frag_activations_vit, _, _ = get_deep_feature('vit', video_name, merged_frag, frame_number, vit, device, 'pool')
all_frame_activations_sampled_vit.append(sampled_frag_activations_vit)
all_frame_activations_merged_vit.append(merged_frag_activations_vit)
print(f'video frame number: {len(all_frame_activations_resnet)}')
averaged_frames_resnet = process_video_feature(all_frame_activations_resnet, 'resnet50', 'layerstack')
averaged_frames_vit = process_video_feature(all_frame_activations_vit, 'vit', 'pool')
# print("ResNet50 layer-stacking feature shape:", averaged_frames_resnet.shape)
# print("ViT pooling feature shape:", averaged_frames_vit.shape)
averaged_frames_sampled_resnet = process_video_feature(all_frame_activations_sampled_resnet, 'resnet50', 'layerstack')
averaged_frames_merged_resnet = process_video_feature(all_frame_activations_merged_resnet, 'resnet50', 'pool')
averaged_combined_feature_resnet = concatenate_features(averaged_frames_sampled_resnet, averaged_frames_merged_resnet)
# print("Sampled fragments ResNet50 features shape:", averaged_frames_sampled_resnet.shape)
# print("Merged fragments ResNet50 features shape:", averaged_frames_merged_resnet.shape)
averaged_frames_sampled_vit = process_video_feature(all_frame_activations_sampled_vit, 'vit', 'pool')
averaged_frames_merged_vit = process_video_feature(all_frame_activations_merged_vit, 'vit', 'pool')
averaged_combined_feature_vit = concatenate_features(averaged_frames_sampled_vit, averaged_frames_merged_vit)
# print("Sampled fragments ViT features shape:", averaged_frames_sampled_vit.shape)
# print("Merged fragments ResNet50 features shape:", averaged_frames_merged_vit.shape)
# remove tmp folders
shutil.rmtree(sampled_fragment_path)
# concatenate features
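    # mean-pool each feature stream over frames and concatenate into a single 1 x D video-level vector
    # (D should match the regressor's input_features, 35203 by default)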
combined_features = torch.cat([torch.mean(averaged_frames_resnet, dim=0), torch.mean(averaged_frames_vit, dim=0),
torch.mean(averaged_combined_feature_resnet, dim=0), torch.mean(averaged_combined_feature_vit, dim=0)], dim=0).view(1, -1)
imputer = load(f'{save_path}/scaler/{video_type}_imputer.pkl')
scaler = load(f'{save_path}/scaler/{video_type}_scaler.pkl')
X_test_processed, _, _, _ = preprocess_data(combined_features, None, imp=imputer, scaler=scaler)
feature_tensor = X_test_processed
# evaluation for test video
model_mlp.eval()
with torch.no_grad():
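        # mixed-precision inference for the MLP head (effectively disabled when CUDA is unavailable)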
with torch.cuda.amp.autocast():
prediction = model_mlp(feature_tensor)
predicted_score = prediction.item()
# print(f"Raw Predicted Quality Score: {predicted_score}")
run_time = time.time() - start_time
if not is_finetune:
if video_type in ['konvid_1k', 'youtube_ugc']:
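            # linearly rescale the prediction from the training range (1-100) to the 1-5 MOS scale used by KoNViD-1k / YouTube-UGC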
scaled_prediction = ((predicted_score - 1) / (99 / 4)) + 1.0
# print(f"Scaled Predicted Quality Score (1-5): {scaled_prediction}")
return scaled_prediction, run_time
else:
scaled_prediction = predicted_score
return scaled_prediction, run_time
else:
return predicted_score, run_time
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('-device', type=str, default='gpu', help='cpu or gpu')
parser.add_argument('-model_name', type=str, default='Mlp', help='Name of the regression model')
parser.add_argument('-select_criteria', type=str, default='byrmse', help='Selection criteria')
parser.add_argument('-train_data_name', type=str, default='lsvq_train', help='Name of the training data')
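    # note: with type=bool, any non-empty string (including "False") parses as True; omit the flag to keep the default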
parser.add_argument('-is_finetune', type=bool, default=False, help='With or without finetune')
parser.add_argument('-save_path', type=str, default='model/', help='Path to save models')
parser.add_argument('-video_type', type=str, default='konvid_1k', help='Type of video')
parser.add_argument('-video_name', type=str, default='5636101558_540p', help='Name of the video')
parser.add_argument('-framerate', type=float, default=24, help='Frame rate of the video')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_arguments()
config = vars(args)
if config['device'] == "gpu":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
device = torch.device("cpu")
print(f"Running on {'GPU' if device.type == 'cuda' else 'CPU'}")
# load models to device
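    # 'pretrained=True' is deprecated in torchvision >= 0.13; the weights= API (e.g. models.ResNet50_Weights.DEFAULT) is preferred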
resnet50 = models.resnet50(pretrained=True).to(device)
vit = VitGenerator('vit_base', 16, device, evaluate=True, random=False, verbose=True)
model_mlp = load_model(config, device)
total_time = 0
num_runs = 1
for i in range(num_runs):
quality_prediction, run_time = evaluate_video_quality(config, resnet50, vit, model_mlp, device)
print(f"Run {i + 1} - Time taken: {run_time:.4f} seconds")
total_time += run_time
average_time = total_time / num_runs
print(f"Average running time over {num_runs} runs: {average_time:.4f} seconds")
print("Predicted Quality Score:", quality_prediction)