import torch
import cv2
import numpy as np

from extractor import visualise_resnet, visualise_resnet_layer, visualise_vit_layer


def get_deep_feature(network_name, video_name, frame, frame_number, model, device, layer_name):
    if network_name == 'resnet50':
        if layer_name == 'layerstack':
            all_layers = ['resnet50.conv1',
                          'resnet50.layer1[0]', 'resnet50.layer1[1]', 'resnet50.layer1[2]',
                          'resnet50.layer2[0]', 'resnet50.layer2[1]', 'resnet50.layer2[2]', 'resnet50.layer2[3]',
                          'resnet50.layer3[0]', 'resnet50.layer3[1]', 'resnet50.layer3[2]', 'resnet50.layer3[3]',
                          'resnet50.layer4[0]', 'resnet50.layer4[1]', 'resnet50.layer4[2]']
            resnet50 = model
            activations_dict, _, total_flops, total_params = visualise_resnet.process_video_frame(
                video_name, frame, frame_number, all_layers, resnet50, device)
        elif layer_name == 'pool':
            visual_layer = 'resnet50.avgpool'  # activation taken at the global average pool
            resnet50 = model
            activations_dict, _, total_flops, total_params = visualise_resnet_layer.process_video_frame(
                video_name, frame, frame_number, visual_layer, resnet50, device)
        else:
            raise ValueError(f'unsupported layer_name: {layer_name}')
    elif network_name == 'vit':
        patch_size = 16
        activations_dict, _, total_flops, total_params = visualise_vit_layer.process_video_frame(
            video_name, frame, frame_number, model, patch_size, device)
    else:
        raise ValueError(f'unsupported network_name: {network_name}')
    return activations_dict, total_flops, total_params
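

# Hedged usage sketch (not part of the original file): wiring a torchvision
# ResNet-50 into get_deep_feature. The extractor modules are assumed to return
# (activations_dict, visualisation, total_flops, total_params) as unpacked
# above; the frame path is hypothetical.
def _example_get_deep_feature():
    from torchvision.models import resnet50, ResNet50_Weights
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = resnet50(weights=ResNet50_Weights.DEFAULT).to(device).eval()
    frame = cv2.imread('frames/frame_0001.png')  # hypothetical example frame
    # 'pool' extracts a single 2048-d activation from resnet50.avgpool
    return get_deep_feature('resnet50', 'demo_video', frame, 0, model, device, 'pool')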


def process_video_feature(video_feature, network_name, layer_name):
    # processed (pooled) feature vectors, one per frame
    averaged_frames = []
    for frame in video_feature:
        frame_features = []
        if network_name == 'vit':
            # global mean, max, and std pooling over the token dimension
            global_mean = torch.mean(frame, dim=0)
            global_max = torch.max(frame, dim=0)[0]
            global_std = torch.std(frame, dim=0)
            # concatenate all pooled statistics
            combined_features = torch.hstack([global_mean, global_max, global_std])
            frame_features.append(combined_features)
        elif network_name == 'resnet50':
            if layer_name == 'layerstack':
                # iterate over each hooked layer in the current frame
                for layer_array in frame.values():
                    # spatial mean over the H and W axes of each layer
                    layer_mean = torch.mean(layer_array, dim=(1, 2))
                    frame_features.append(layer_mean)
            elif layer_name == 'pool':
                frame = torch.squeeze(torch.as_tensor(frame))
                # global mean, max, and std of the pooled activation vector
                global_mean = torch.mean(frame, dim=0)
                global_max = torch.max(frame, dim=0)[0]
                global_std = torch.std(frame, dim=0)
                # keep the raw vector and append the pooled statistics
                combined_features = torch.hstack([frame, global_mean, global_max, global_std])
                frame_features.append(combined_features)
        # concatenate the per-layer features horizontally to form the frame feature
        processed_frame = torch.hstack(frame_features)
        averaged_frames.append(processed_frame)
    averaged_frames = torch.stack(averaged_frames)
    return averaged_frames
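

# Hedged shape check: assuming the 'pool' branch receives one (1, 2048, 1, 1)
# avgpool activation per frame, the output keeps the raw 2048-d vector plus
# its mean/max/std, i.e. 2051 dims per frame. The tensors here are synthetic.
def _example_process_pool_features():
    fake_video = [torch.randn(1, 2048, 1, 1) for _ in range(8)]  # 8 synthetic frames
    feats = process_video_feature(fake_video, 'resnet50', 'pool')
    assert feats.shape == (8, 2048 + 3)
    return feats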


def flow_to_rgb(flow):
    mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
    # convert angle (radians) to OpenCV hue in [0, 180)
    hue = ang * 180 / np.pi / 2
    # build the HSV image: hue = flow direction, value = flow magnitude
    hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8)
    hsv[..., 0] = hue
    hsv[..., 1] = 255
    hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
    # convert HSV to a BGR image (OpenCV channel order)
    rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    return rgb
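

# Hedged usage sketch: dense Farneback optical flow between two consecutive
# BGR frames, rendered with flow_to_rgb (hue encodes direction, brightness
# encodes magnitude). The Farneback parameters are common defaults, not tuned.
def _example_flow_visualisation(prev_bgr, next_bgr):
    prev_gray = cv2.cvtColor(prev_bgr, cv2.COLOR_BGR2GRAY)
    next_gray = cv2.cvtColor(next_bgr, cv2.COLOR_BGR2GRAY)
    flow = cv2.calcOpticalFlowFarneback(prev_gray, next_gray, None,
                                        0.5, 3, 15, 3, 5, 1.2, 0)
    return flow_to_rgb(flow)  # (H, W, 3) uint8 image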


def get_patch_diff(residual_frame, patch_size):
    h, w = residual_frame.shape[2:]  # assuming (1, C, H, W) shape
    # crop to a whole number of patches
    h_adj = (h // patch_size) * patch_size
    w_adj = (w // patch_size) * patch_size
    residual_frame_adj = residual_frame[:, :, :h_adj, :w_adj]
    # absolute sum of the residual within each patch
    diff = torch.zeros((h_adj // patch_size, w_adj // patch_size), device=residual_frame.device)
    for i in range(0, h_adj, patch_size):
        for j in range(0, w_adj, patch_size):
            patch = residual_frame_adj[:, :, i:i + patch_size, j:j + patch_size]
            diff[i // patch_size, j // patch_size] = torch.sum(torch.abs(patch))
    return diff
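

# Hedged alternative (assumed equivalent): the same per-patch absolute sums
# computed without Python loops via Tensor.unfold, which is noticeably faster
# on large frames.
def _patch_diff_unfold(residual_frame, patch_size):
    _, _, h, w = residual_frame.shape
    h_adj = (h // patch_size) * patch_size
    w_adj = (w // patch_size) * patch_size
    x = residual_frame[:, :, :h_adj, :w_adj].abs()
    # carve (patch_size x patch_size) tiles out of the H and W axes
    x = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    return x.sum(dim=(0, 1, 4, 5))  # (h_adj // patch_size, w_adj // patch_size)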


def extract_important_patches(residual_frame, diff, patch_size=16, target_size=224, top_n=196):
    # indices of the top_n patches by absolute patch difference
    patch_idx = torch.argsort(-diff.view(-1))
    top_patches = [(idx // diff.shape[1], idx % diff.shape[1]) for idx in patch_idx[:top_n]]
    # keep the selected patches in their original raster (row, column) order
    sorted_idx = sorted(top_patches, key=lambda x: (x[0], x[1]))
    imp_patches_img = torch.zeros((residual_frame.shape[1], target_size, target_size),
                                  dtype=residual_frame.dtype, device=residual_frame.device)
    patches_per_row = target_size // patch_size  # 14 for 224 / 16
    positions = []
    for idx, (y, x) in enumerate(sorted_idx):
        patch = residual_frame[:, :, y * patch_size:(y + 1) * patch_size, x * patch_size:(x + 1) * patch_size]
        # destination of this patch in the mosaic
        row_idx = idx // patches_per_row
        col_idx = idx % patches_per_row
        start_y = row_idx * patch_size
        start_x = col_idx * patch_size
        imp_patches_img[:, start_y:start_y + patch_size, start_x:start_x + patch_size] = patch
        positions.append((y.item(), x.item()))
    return imp_patches_img, positions
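

# Hedged sketch with synthetic data: score a 448x448 residual per 16x16 patch
# and keep the 196 most-changed patches as a 224x224 mosaic; `positions`
# records where each kept patch came from in the source grid.
def _example_important_patches():
    residual = torch.rand(1, 3, 448, 448)
    diff = get_patch_diff(residual, patch_size=16)
    mosaic, positions = extract_important_patches(residual, diff)
    assert mosaic.shape == (3, 224, 224) and len(positions) == 196
    return mosaic, positions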


def get_frame_patches(frame, positions, patch_size, target_size):
    imp_patches_img = torch.zeros((frame.shape[1], target_size, target_size),
                                  dtype=frame.dtype, device=frame.device)
    patches_per_row = target_size // patch_size
    for idx, (y, x) in enumerate(positions):
        # source patch location in the full frame
        start_y = y * patch_size
        start_x = x * patch_size
        end_y = start_y + patch_size
        end_x = start_x + patch_size
        patch = frame[:, :, start_y:end_y, start_x:end_x]
        # destination location in the mosaic
        row_idx = idx // patches_per_row
        col_idx = idx % patches_per_row
        target_start_y = row_idx * patch_size
        target_start_x = col_idx * patch_size
        imp_patches_img[:, target_start_y:target_start_y + patch_size,
                        target_start_x:target_start_x + patch_size] = patch.squeeze(0)
    return imp_patches_img
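

# Hedged sketch: reuse the positions selected on the residual to crop the
# same locations from the current raw frame, keeping content and residual
# fragments spatially aligned. Tensors are synthetic (1, C, H, W) frames.
def _example_aligned_patches(prev_t, curr_t):
    residual = curr_t - prev_t  # frame difference
    diff = get_patch_diff(residual, 16)
    _, positions = extract_important_patches(residual, diff)
    return get_frame_patches(curr_t, positions, 16, 224)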


def process_patches(original_path, frag_name, residual, patch_size, target_size, top_n):
    diff = get_patch_diff(residual, patch_size)
    imp_patches, positions = extract_important_patches(residual, diff, patch_size, target_size, top_n)
    if frag_name == 'frame_diff':
        frag_path = original_path.replace('.png', '_residual_imp.png')
    elif frag_name == 'optical_flow':
        frag_path = original_path.replace('.png', '_residual_of_imp.png')
    else:
        raise ValueError(f'unsupported frag_name: {frag_name}')
    # cv2.imwrite(frag_path, imp_patches)
    return frag_path, imp_patches, positions


def merge_fragments(diff_fragment, flow_fragment):
    # equal-weight blend of the frame-difference and optical-flow fragments
    alpha = 0.5
    merged_fragment = diff_fragment * alpha + flow_fragment * (1 - alpha)
    return merged_fragment
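

# Hedged end-to-end sketch: fragment both the frame difference and the
# optical-flow image, then blend the two fragments. The .png path is
# illustrative only; flow_rgb_t is assumed to be a (1, C, H, W) float tensor
# of the rendered flow.
def _example_fragment_pipeline(prev_t, curr_t, flow_rgb_t):
    _, diff_frag, positions = process_patches(
        'frames/frame_0001.png', 'frame_diff', curr_t - prev_t, 16, 224, 196)
    _, flow_frag, _ = process_patches(
        'frames/frame_0001.png', 'optical_flow', flow_rgb_t, 16, 224, 196)
    return merge_fragments(diff_frag, flow_frag), positions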


def concatenate_features(frame_feature, residual_feature):
    # fuse semantic and residual features along the last (feature) dimension
    return torch.cat((frame_feature, residual_feature), dim=-1)
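

# Hedged shape sketch: per-frame semantic features and residual-fragment
# features are fused along the last dimension; the 2051-d size matches the
# 'pool' branch of process_video_feature, the batch of 8 frames is synthetic.
def _example_concatenate_features():
    frame_feature = torch.randn(8, 2051)
    residual_feature = torch.randn(8, 2051)
    fused = concatenate_features(frame_feature, residual_feature)
    assert fused.shape == (8, 4102)
    return fused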