ReLaX-VQA / relax_vqa.py
xinyiW915's picture
Upload 5 files
17f753b verified
raw
history blame contribute delete
7.58 kB
import torch
import os
import cv2
import numpy as np
from extractor import visualise_resnet, visualise_resnet_layer, visualise_vit_layer
def get_deep_feature(network_name, video_name, frame, frame_number, model, device, layer_name):
    """Extract deep activations for a single video frame.

    Dispatches to the extractor helper that matches ``network_name`` /
    ``layer_name`` and forwards the frame, model and device to it.

    Args:
        network_name: ``'resnet50'`` or ``'vit'``.
        video_name: Video identifier, forwarded to the extractor.
        frame: Frame to process, forwarded unchanged.
        frame_number: Frame index, forwarded unchanged.
        model: Pre-loaded network to run.
        device: Device forwarded to the extractor.
        layer_name: For ``'resnet50'`` either ``'layerstack'`` (hook the listed
            blocks) or ``'pool'`` (hook ``resnet50.avgpool``). Ignored for
            ``'vit'``.

    Returns:
        Tuple ``(activations_dict, total_flops, total_params)`` as reported by
        the chosen extractor.

    Raises:
        ValueError: if the network / layer combination is not supported.
    """
    if network_name == 'resnet50':
        if layer_name == 'layerstack':
            # NOTE(review): ResNet-50 defines six blocks in layer3, but only
            # layer3[0]..layer3[3] are hooked here. Kept as-is because the list
            # fixes the downstream feature dimensionality -- confirm intent.
            all_layers = ['resnet50.conv1',
                          'resnet50.layer1[0]', 'resnet50.layer1[1]', 'resnet50.layer1[2]',
                          'resnet50.layer2[0]', 'resnet50.layer2[1]', 'resnet50.layer2[2]', 'resnet50.layer2[3]',
                          'resnet50.layer3[0]', 'resnet50.layer3[1]', 'resnet50.layer3[2]', 'resnet50.layer3[3]',
                          'resnet50.layer4[0]', 'resnet50.layer4[1]', 'resnet50.layer4[2]']
            activations_dict, _, total_flops, total_params = visualise_resnet.process_video_frame(
                video_name, frame, frame_number, all_layers, model, device)
        elif layer_name == 'pool':
            # single layer handed to the extractor (author's note: "before avg_pool")
            visual_layer = 'resnet50.avgpool'
            activations_dict, _, total_flops, total_params = visualise_resnet_layer.process_video_frame(
                video_name, frame, frame_number, visual_layer, model, device)
        else:
            # previously fell through and failed with UnboundLocalError at the return
            raise ValueError(f"unsupported layer_name for resnet50: {layer_name!r}")
    elif network_name == 'vit':
        patch_size = 16  # patch size forwarded to the ViT extractor
        activations_dict, _, total_flops, total_params = visualise_vit_layer.process_video_frame(
            video_name, frame, frame_number, model, patch_size, device)
    else:
        raise ValueError(f"unsupported network_name: {network_name!r}")
    return activations_dict, total_flops, total_params
def _global_stats(features):
    """Return mean, max and (unbiased) std over dim 0, concatenated with hstack."""
    return torch.hstack([
        torch.mean(features, dim=0),
        torch.max(features, dim=0)[0],
        torch.std(features, dim=0),
    ])


def process_video_feature(video_feature, network_name, layer_name):
    """Reduce per-frame deep features to one feature vector per frame.

    Args:
        video_feature: Iterable of per-frame features.
            * ``'vit'``: each item is a tensor; mean/max/std are taken over
              dim 0 and concatenated.
            * ``'resnet50'`` + ``'layerstack'``: each item is a mapping whose
              values are tensors averaged over dims (1, 2); the per-layer means
              are concatenated.
            * ``'resnet50'`` + ``'pool'``: each item is array-like; it is
              squeezed, then concatenated with its mean/max/std over dim 0.
        network_name: ``'vit'`` or ``'resnet50'``.
        layer_name: ``'layerstack'`` or ``'pool'``; only used for ``'resnet50'``.

    Returns:
        Tensor stacking the processed frame vectors along a new dim 0.

    Raises:
        ValueError: if the network / layer combination is not supported.
    """
    supported = network_name == 'vit' or (
        network_name == 'resnet50' and layer_name in ('layerstack', 'pool'))
    if not supported:
        raise ValueError(
            f"unsupported network/layer combination: {network_name!r}/{layer_name!r}")

    processed_frames = []
    for frame in video_feature:
        if network_name == 'vit':
            processed = _global_stats(frame)
        elif layer_name == 'layerstack':
            # spatial mean of every hooked layer, concatenated layer after layer
            processed = torch.hstack([torch.mean(layer, dim=(1, 2)) for layer in frame.values()])
        else:  # resnet50 / pool
            # as_tensor avoids the copy (and PyTorch's copy-construct warning)
            # that torch.tensor() incurs when frame is already a tensor
            pooled = torch.squeeze(torch.as_tensor(frame))
            processed = torch.hstack([pooled, _global_stats(pooled)])
        processed_frames.append(processed)
    return torch.stack(processed_frames)
def flow_to_rgb(flow):
    """Render a dense optical-flow field as a colour image.

    Flow direction is encoded as hue and flow magnitude as value (min-max
    scaled to 0..255), with saturation fixed at 255.

    Args:
        flow: Array whose last axis holds the (dx, dy) components; it is fed
            to ``cv2.cartToPolar``, so presumably float32/float64 of shape
            (H, W, 2) -- confirm against callers.

    Returns:
        ``uint8`` array of shape (H, W, 3) in BGR channel order (produced with
        ``cv2.COLOR_HSV2BGR`` despite the function name).
    """
    mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
    hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8)
    # cartToPolar returns radians in [0, 2*pi); 8-bit OpenCV hue spans [0, 180)
    hsv[..., 0] = ang * 180 / np.pi / 2
    hsv[..., 1] = 255
    # a single min-max normalisation is enough (it was previously applied twice)
    hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
    # NOTE: output is BGR, which is what cv2.imwrite expects
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
def get_patch_diff(residual_frame, patch_size):
    """Score each non-overlapping square patch by its total absolute value.

    The frame is cropped to a whole number of patches (trailing rows/columns
    that do not fill a patch are ignored). Each patch score sums ``|value|``
    over every batch and channel entry inside that patch.

    Args:
        residual_frame: 4-D tensor indexed as (batch, channel, H, W).
        patch_size: Side length of a square patch in pixels.

    Returns:
        Tensor of shape (H // patch_size, W // patch_size) in the default
        floating dtype, on ``residual_frame``'s device.
    """
    h, w = residual_frame.shape[2:]
    rows, cols = h // patch_size, w // patch_size
    cropped = residual_frame[:, :, :rows * patch_size, :cols * patch_size]
    # split H and W into (patch index, offset inside patch), then reduce over
    # batch, channel and both in-patch offsets in one vectorised call
    blocks = cropped.abs().reshape(
        cropped.shape[0], cropped.shape[1], rows, patch_size, cols, patch_size)
    return blocks.sum(dim=(0, 1, 3, 5)).to(torch.get_default_dtype())
def extract_important_patches(residual_frame, diff, patch_size=16, target_size=224, top_n=196):
    """Tile the ``top_n`` highest-scoring patches of a frame into a square canvas.

    Patches are ranked by ``diff`` (one score per patch), the best ``top_n``
    are kept and re-sorted into raster order (row, then column) so their
    relative layout is preserved. They are then copied row-major into a
    ``target_size`` x ``target_size`` canvas.

    Args:
        residual_frame: 4-D tensor indexed as (batch, channel, H, W); each
            patch is squeezed on dim 0, so presumably batch == 1 -- confirm.
        diff: 2-D tensor of patch scores indexed as (patch row, patch col).
        patch_size: Side length of a square patch in pixels.
        target_size: Side length of the output canvas in pixels.
        top_n: Number of patches to keep.

    Returns:
        Tuple ``(imp_patches_img, positions)``. ``imp_patches_img`` is a
        (C, target_size, target_size) tensor with ``residual_frame``'s dtype
        and device. ``positions`` lists the ``(row, col)`` patch coordinates
        (Python ints) in the order they were placed.

    Raises:
        ValueError: if more patches are selected than fit on the canvas.
    """
    num_cols = diff.shape[1]
    # flat patch indices by descending score; argsort is not stable, so ties
    # are broken arbitrarily
    ranked = torch.argsort(-diff.reshape(-1))[:top_n].tolist()
    # convert to (row, col) as plain ints and restore raster order
    positions = sorted(divmod(flat_idx, num_cols) for flat_idx in ranked)

    patches_per_row = target_size // patch_size  # 14 with the defaults
    capacity = patches_per_row * patches_per_row
    if len(positions) > capacity:
        raise ValueError(
            f"{len(positions)} patches do not fit a {target_size}x{target_size} canvas "
            f"of {patch_size}px patches (capacity {capacity})")

    imp_patches_img = torch.zeros((residual_frame.shape[1], target_size, target_size),
                                  dtype=residual_frame.dtype, device=residual_frame.device)
    for idx, (y, x) in enumerate(positions):
        patch = residual_frame[:, :, y * patch_size:(y + 1) * patch_size,
                               x * patch_size:(x + 1) * patch_size]
        # destination slot on the canvas, filled row-major
        start_y = (idx // patches_per_row) * patch_size
        start_x = (idx % patches_per_row) * patch_size
        imp_patches_img[:, start_y:start_y + patch_size, start_x:start_x + patch_size] = patch.squeeze(0)
    return imp_patches_img, positions
def get_frame_patches(frame, positions, patch_size, target_size):
    """Gather patches at given grid coordinates and tile them into a square canvas.

    The patch at grid coordinate ``(row, col)`` of ``frame`` is copied into
    the k-th slot of the canvas, where k is its index in ``positions`` and
    slots are filled row-major.

    Args:
        frame: 4-D tensor indexed as (batch, channel, H, W); each patch is
            squeezed on dim 0, so presumably batch == 1 -- confirm.
        positions: Sequence of ``(row, col)`` patch-grid coordinates.
        patch_size: Side length of a square patch in pixels.
        target_size: Side length of the output canvas in pixels.

    Returns:
        Tensor of shape (C, target_size, target_size) with ``frame``'s dtype
        and device; slots beyond ``len(positions)`` stay zero.
    """
    canvas = torch.zeros((frame.shape[1], target_size, target_size),
                         dtype=frame.dtype, device=frame.device)
    slots_per_row = target_size // patch_size
    for slot, (row, col) in enumerate(positions):
        src_y, src_x = row * patch_size, col * patch_size
        slot_row, slot_col = divmod(slot, slots_per_row)
        dst_y, dst_x = slot_row * patch_size, slot_col * patch_size
        source = frame[:, :, src_y:src_y + patch_size, src_x:src_x + patch_size]
        canvas[:, dst_y:dst_y + patch_size, dst_x:dst_x + patch_size] = source.squeeze(0)
    return canvas
def process_patches(original_path, frag_name, residual, patch_size, target_size, top_n):
    """Select the most active patches of a residual and derive the fragment file name.

    Args:
        original_path: Path of the source frame image; ``'.png'`` in it is
            replaced by a fragment-specific suffix.
        frag_name: ``'frame_diff'`` or ``'optical_flow'``.
        residual: Residual tensor scored by ``get_patch_diff`` and tiled by
            ``extract_important_patches``.
        patch_size: Forwarded to both helpers.
        target_size: Forwarded to ``extract_important_patches``.
        top_n: Forwarded to ``extract_important_patches``.

    Returns:
        Tuple ``(frag_path, imp_patches, positions)``. Nothing is written to
        disk here; ``frag_path`` is only computed for the caller.

    Raises:
        ValueError: if ``frag_name`` is not a known fragment type.
    """
    suffixes = {
        'frame_diff': '_residual_imp.png',
        'optical_flow': '_residual_of_imp.png',
    }
    # validate first: previously an unknown name did all the patch work and
    # then failed with UnboundLocalError on frag_path
    if frag_name not in suffixes:
        raise ValueError(f"unknown frag_name: {frag_name!r}")
    # NOTE(review): str.replace rewrites every '.png' occurrence in the path,
    # not just the extension -- kept for compatibility with existing outputs
    frag_path = original_path.replace('.png', suffixes[frag_name])

    diff = get_patch_diff(residual, patch_size)
    imp_patches, positions = extract_important_patches(residual, diff, patch_size, target_size, top_n)
    return frag_path, imp_patches, positions
def merge_fragments(diff_fragment, flow_fragment, alpha=0.5):
    """Linearly blend the frame-difference and optical-flow fragments.

    Args:
        diff_fragment: Frame-difference fragment (tensor, array or scalar).
        flow_fragment: Optical-flow fragment with a compatible shape.
        alpha: Weight of ``diff_fragment``; ``flow_fragment`` gets
            ``1 - alpha``. The default 0.5 is an equal-weight average, matching
            the previous hard-coded behaviour.

    Returns:
        ``diff_fragment * alpha + flow_fragment * (1 - alpha)``.
    """
    return diff_fragment * alpha + flow_fragment * (1 - alpha)
def concatenate_features(frame_feature, residual_feature):
    """Join frame and residual features along their last dimension.

    Both tensors must agree in every dimension except the last one.
    """
    pair = [frame_feature, residual_feature]
    return torch.cat(pair, dim=-1)