Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import os | |
import cv2 | |
import numpy as np | |
from extractor import visualise_resnet, visualise_resnet_layer, visualise_vit_layer | |
def get_deep_feature(network_name, video_name, frame, frame_number, model, device, layer_name):
    """Run ``model`` on one video frame and return its activations and cost.

    Dispatches to the project's ``extractor`` helpers:

    * ``'resnet50'`` + ``layer_name == 'layerstack'``: activations of conv1 and
      the listed bottleneck blocks, via ``visualise_resnet``.
    * ``'resnet50'`` + ``layer_name == 'pool'``: activations of
      ``resnet50.avgpool``, via ``visualise_resnet_layer``.
    * ``'vit'``: activations via ``visualise_vit_layer`` with 16-pixel patches;
      ``layer_name`` is ignored.

    Args:
        network_name: ``'resnet50'`` or ``'vit'``.
        video_name: Forwarded unchanged to the extractor helper.
        frame: The frame to process, forwarded unchanged.
        frame_number: Index of ``frame``, forwarded unchanged.
        model: The backbone network instance.
        device: Device forwarded to the extractor helper.
        layer_name: ``'layerstack'`` or ``'pool'`` (only read for ResNet-50).

    Returns:
        ``(activations_dict, total_flops, total_params)`` as reported by the
        extractor helper.

    Raises:
        ValueError: If ``network_name`` is unsupported, or if ``layer_name`` is
            unsupported for ``'resnet50'``. Validation happens before any
            model work.
    """
    if network_name == 'resnet50':
        if layer_name == 'layerstack':
            # NOTE(review): torchvision's ResNet-50 has six blocks in layer3, but
            # only layer3[0]..[3] are listed. Editing this list changes the
            # feature dimension seen downstream, so confirm before changing it.
            all_layers = ['resnet50.conv1',
                          'resnet50.layer1[0]', 'resnet50.layer1[1]', 'resnet50.layer1[2]',
                          'resnet50.layer2[0]', 'resnet50.layer2[1]', 'resnet50.layer2[2]', 'resnet50.layer2[3]',
                          'resnet50.layer3[0]', 'resnet50.layer3[1]', 'resnet50.layer3[2]', 'resnet50.layer3[3]',
                          'resnet50.layer4[0]', 'resnet50.layer4[1]', 'resnet50.layer4[2]']
            activations_dict, _, total_flops, total_params = visualise_resnet.process_video_frame(
                video_name, frame, frame_number, all_layers, model, device)
        elif layer_name == 'pool':
            # Original note said "before avg_pool"; whether the captured tensor is
            # avgpool's input or output is decided inside visualise_resnet_layer.
            # TODO confirm.
            visual_layer = 'resnet50.avgpool'
            activations_dict, _, total_flops, total_params = visualise_resnet_layer.process_video_frame(
                video_name, frame, frame_number, visual_layer, model, device)
        else:
            raise ValueError(
                f"unsupported layer_name {layer_name!r} for resnet50; "
                "expected 'layerstack' or 'pool'")
    elif network_name == 'vit':
        patch_size = 16
        activations_dict, _, total_flops, total_params = visualise_vit_layer.process_video_frame(
            video_name, frame, frame_number, model, patch_size, device)
    else:
        raise ValueError(
            f"unsupported network_name {network_name!r}; expected 'resnet50' or 'vit'")
    return activations_dict, total_flops, total_params
def _global_stats(features):
    """Return ``[mean, max, std]`` of ``features`` reduced over dim 0."""
    return [
        torch.mean(features, dim=0),
        torch.max(features, dim=0)[0],
        torch.std(features, dim=0),
    ]


def process_video_feature(video_feature, network_name, layer_name):
    """Pool per-frame activations into one feature vector per frame.

    * ``'vit'``: each frame is reduced over dim 0 (presumably the token axis;
      TODO confirm) to ``[mean, max, std]``, which are concatenated.
    * ``'resnet50'`` + ``'layerstack'``: each frame is a mapping of layer
      tensors. Every tensor is averaged over dims 1 and 2 (presumably the
      spatial axes of a ``(C, H, W)`` activation), and the results are
      concatenated in the mapping's iteration order.
    * ``'resnet50'`` + ``'pool'``: each frame is squeezed (presumably an
      avgpool output such as ``(1, C, 1, 1)`` -> ``(C,)``). The vector is then
      concatenated with its own scalar mean, max and std.

    Args:
        video_feature: Iterable of per-frame activations, as described above.
        network_name: ``'vit'`` or ``'resnet50'``.
        layer_name: ``'layerstack'`` or ``'pool'``; only read for ResNet-50.

    Returns:
        Tensor stacking one 1-D feature vector per frame.

    Raises:
        ValueError: On an unsupported ``network_name``/``layer_name``, or if
            ``video_feature`` yields no frames.
    """
    if network_name not in ('vit', 'resnet50'):
        raise ValueError(
            f"unsupported network_name {network_name!r}; expected 'vit' or 'resnet50'")
    if network_name == 'resnet50' and layer_name not in ('layerstack', 'pool'):
        raise ValueError(
            f"unsupported layer_name {layer_name!r} for resnet50; "
            "expected 'layerstack' or 'pool'")

    processed = []
    for frame in video_feature:
        if network_name == 'vit':
            parts = _global_stats(frame)
        elif layer_name == 'layerstack':
            parts = [torch.mean(layer, dim=(1, 2)) for layer in frame.values()]
        else:  # resnet50 / pool
            # as_tensor avoids torch.tensor's copy-construct warning on tensors;
            # detach keeps the original's "no autograd history" result.
            pooled = torch.squeeze(torch.as_tensor(frame).detach())
            parts = [pooled, *_global_stats(pooled)]
        # hstack promotes 0-dim statistics to length-1 vectors.
        processed.append(torch.hstack(parts))

    if not processed:
        raise ValueError("video_feature contains no frames")
    return torch.stack(processed)
def flow_to_rgb(flow):
    """Render a dense optical-flow field as a colour image.

    Flow direction is encoded as hue and flow magnitude as value (min-max
    normalised to 0..255), at full saturation.

    Args:
        flow: Array indexed as ``flow[..., 0]`` (dx) and ``flow[..., 1]`` (dy)
            with spatial shape ``(H, W)``, i.e. ``(H, W, 2)``; must be a dtype
            ``cv2.cartToPolar`` accepts (float32/float64).

    Returns:
        ``uint8`` image of shape ``(H, W, 3)`` in **BGR** channel order
        (produced by ``cv2.COLOR_HSV2BGR``) despite the function's name, which
        is the order ``cv2.imwrite`` expects.
    """
    mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])

    hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8)
    # 8-bit OpenCV hue spans [0, 180): radians -> degrees, then halved.
    hsv[..., 0] = ang * 180 / np.pi / 2
    hsv[..., 1] = 255
    # A single min-max normalisation maps magnitude onto 0..255.
    hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)

    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
def get_patch_diff(residual_frame, patch_size):
    """Score every non-overlapping patch by its total absolute residual.

    The frame is cropped to a whole number of patches (any bottom/right
    remainder is dropped). Each patch's score is the sum of ``|residual|``
    over the patch and over every leading dimension (e.g. batch and channel
    for ``(B, C, H, W)`` input).

    Args:
        residual_frame: Tensor whose last two dimensions are ``(H, W)``, e.g.
            ``(1, C, H, W)``; any number of leading dimensions is accepted.
        patch_size: Patch edge length in pixels.

    Returns:
        Tensor of shape ``(H // patch_size, W // patch_size)`` on
        ``residual_frame``'s device, in the default float dtype.
    """
    h, w = residual_frame.shape[-2:]
    n_rows, n_cols = h // patch_size, w // patch_size
    cropped = residual_frame[..., :n_rows * patch_size, :n_cols * patch_size]

    lead = cropped.shape[:-2]
    # Split H and W into (patch index, offset inside patch) so a single
    # reduction replaces a Python loop with one kernel launch per patch.
    blocks = cropped.abs().reshape(*lead, n_rows, patch_size, n_cols, patch_size)
    n_lead = len(lead)
    reduce_dims = tuple(range(n_lead)) + (n_lead + 1, n_lead + 3)
    return blocks.sum(dim=reduce_dims).to(torch.get_default_dtype())
def extract_important_patches(residual_frame, diff, patch_size=16, target_size=224, top_n=196):
    """Pick the ``top_n`` highest-scoring patches and tile them into a square canvas.

    Patches are ranked by ``diff``, highest first. Ties go to the lower
    row-major index, so the selection is deterministic. The chosen patches are
    then laid out in their original raster order (top-to-bottom,
    left-to-right), filling canvas cells row by row.

    Args:
        residual_frame: Tensor indexed as ``[:, :, rows, cols]``; the leading
            dimension is squeezed, so presumably ``(1, C, H, W)``.
        diff: 2-D per-patch scores of shape ``(grid_rows, grid_cols)``, e.g.
            from ``get_patch_diff``.
        patch_size: Patch edge length in pixels.
        target_size: Edge length of the output canvas.
        top_n: Maximum number of patches to keep.

    Returns:
        ``(canvas, positions)``. ``canvas`` has shape
        ``(C, target_size, target_size)`` with ``residual_frame``'s dtype and
        device, and unused cells stay zero. ``positions`` is a list of
        ``(row, col)`` int grid coordinates in placement order.

    Raises:
        ValueError: If more patches would be placed than the canvas holds,
            i.e. more than ``(target_size // patch_size) ** 2``.
    """
    n_cols = diff.shape[1]
    scores = diff.reshape(-1)  # reshape tolerates non-contiguous diff
    k = min(top_n, scores.numel())
    per_row = target_size // patch_size
    if k > per_row * per_row:
        raise ValueError(
            f"cannot place {k} patches of size {patch_size} on a "
            f"{target_size}x{target_size} canvas (capacity {per_row * per_row})")

    ranked = torch.sort(scores, descending=True, stable=True).indices[:k]
    # Flat index = row * n_cols + col, so ascending flat order is raster order.
    chosen = torch.sort(ranked).values.tolist()

    canvas = torch.zeros((residual_frame.shape[1], target_size, target_size),
                         dtype=residual_frame.dtype, device=residual_frame.device)
    positions = []
    for slot, flat in enumerate(chosen):
        y, x = divmod(flat, n_cols)
        patch = residual_frame[:, :, y * patch_size:(y + 1) * patch_size,
                               x * patch_size:(x + 1) * patch_size]
        dst_row, dst_col = divmod(slot, per_row)
        canvas[:, dst_row * patch_size:(dst_row + 1) * patch_size,
               dst_col * patch_size:(dst_col + 1) * patch_size] = patch.squeeze(0)
        positions.append((y, x))
    return canvas, positions
def get_frame_patches(frame, positions, patch_size, target_size):
    """Re-pack the patches of ``frame`` at ``positions`` into a square mosaic.

    Patches are copied in the order given by ``positions`` and laid out
    row-major on a ``target_size x target_size`` canvas; cells that receive no
    patch stay zero.

    Args:
        frame: Tensor indexed as ``frame[:, :, rows, cols]``; the leading
            dimension is squeezed, so presumably ``(1, C, H, W)``.
        positions: ``(row, col)`` patch-grid coordinates, e.g. the positions
            returned by ``extract_important_patches``.
        patch_size: Patch edge length in pixels.
        target_size: Edge length of the output canvas.

    Returns:
        Tensor of shape ``(frame.shape[1], target_size, target_size)`` with
        ``frame``'s dtype and device.
    """
    p = patch_size
    per_row = target_size // p
    canvas = frame.new_zeros((frame.shape[1], target_size, target_size))

    for slot, (row, col) in enumerate(positions):
        dst_row, dst_col = divmod(slot, per_row)
        src = frame[:, :, row * p:(row + 1) * p, col * p:(col + 1) * p]
        canvas[:, dst_row * p:(dst_row + 1) * p, dst_col * p:(dst_col + 1) * p] = src.squeeze(0)

    return canvas
def process_patches(original_path, frag_name, residual, patch_size, target_size, top_n):
    """Select the most active residual patches and derive the fragment's path.

    Scores the patches of ``residual`` with ``get_patch_diff`` and tiles the
    best ``top_n`` with ``extract_important_patches``. Nothing is written to
    disk; the caller receives the path it may save the fragment to.

    Args:
        original_path: Path of the source frame. A trailing ``.png`` is
            replaced by the fragment suffix; any other path is returned
            unchanged as ``frag_path``.
        frag_name: ``'frame_diff'`` or ``'optical_flow'``; selects the suffix
            (``_residual_imp.png`` or ``_residual_of_imp.png``).
        residual: Residual tensor passed to both helpers.
        patch_size: Patch edge length in pixels.
        target_size: Edge length of the output canvas.
        top_n: Maximum number of patches to keep.

    Returns:
        ``(frag_path, imp_patches, positions)``.

    Raises:
        ValueError: If ``frag_name`` is not supported (checked before any
            tensor work).
    """
    suffixes = {
        'frame_diff': '_residual_imp.png',
        'optical_flow': '_residual_of_imp.png',
    }
    if frag_name not in suffixes:
        raise ValueError(
            f"unsupported frag_name {frag_name!r}; expected one of {sorted(suffixes)}")

    diff = get_patch_diff(residual, patch_size)
    imp_patches, positions = extract_important_patches(residual, diff, patch_size, target_size, top_n)

    # Swap only the extension, so a '.png' elsewhere in the path is left intact.
    if original_path.endswith('.png'):
        frag_path = original_path[:-len('.png')] + suffixes[frag_name]
    else:
        frag_path = original_path
    return frag_path, imp_patches, positions
def merge_fragments(diff_fragment, flow_fragment, alpha=0.5):
    """Linearly blend two fragments.

    Computes ``diff_fragment * alpha + flow_fragment * (1 - alpha)``.

    Args:
        diff_fragment: Frame-difference fragment (any operands supporting
            ``*`` and ``+``, e.g. tensors or arrays of matching shape).
        flow_fragment: Optical-flow fragment, same kind as ``diff_fragment``.
        alpha: Weight of ``diff_fragment``; ``flow_fragment`` gets
            ``1 - alpha``. Defaults to an equal 0.5/0.5 blend.

    Returns:
        The blended fragment.
    """
    return diff_fragment * alpha + flow_fragment * (1 - alpha)
def concatenate_features(frame_feature, residual_feature):
    """Join frame and residual features along their last dimension.

    Args:
        frame_feature: Tensor of frame features.
        residual_feature: Tensor of residual features; all dimensions except
            the last must match ``frame_feature``.

    Returns:
        ``frame_feature`` followed by ``residual_feature`` along dim ``-1``.
    """
    pieces = [frame_feature, residual_feature]
    return torch.cat(pieces, dim=-1)