import torch
import cv2
import numpy as np
from extractor import visualise_resnet, visualise_resnet_layer, visualise_vit_layer


def get_deep_feature(network_name, video_name, frame, frame_number, model, device, layer_name):
    """Run one frame through the chosen backbone and return its activations, FLOPs and parameter count."""
    if network_name == 'resnet50':
        if layer_name == 'layerstack':
            all_layers = ['resnet50.conv1',
                          'resnet50.layer1[0]', 'resnet50.layer1[1]', 'resnet50.layer1[2]',
                          'resnet50.layer2[0]', 'resnet50.layer2[1]', 'resnet50.layer2[2]', 'resnet50.layer2[3]',
                          'resnet50.layer3[0]', 'resnet50.layer3[1]', 'resnet50.layer3[2]', 'resnet50.layer3[3]',
                          'resnet50.layer4[0]', 'resnet50.layer4[1]', 'resnet50.layer4[2]']
            resnet50 = model
            activations_dict, _, total_flops, total_params = visualise_resnet.process_video_frame(video_name, frame, frame_number, all_layers, resnet50, device)

        elif layer_name == 'pool':
            visual_layer = 'resnet50.avgpool' # before avg_pool
            resnet50 = model
            activations_dict, _, total_flops, total_params = visualise_resnet_layer.process_video_frame(video_name, frame, frame_number, visual_layer, resnet50, device)

    elif network_name == 'vit':
        patch_size = 16
        activations_dict, _, total_flops, total_params = visualise_vit_layer.process_video_frame(video_name, frame, frame_number, model, patch_size, device)

    return activations_dict, total_flops, total_params
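
# Hedged usage sketch (not from the original repo): how a caller might request pooled
# ResNet-50 features for one frame. The torchvision weights enum and the video name are
# assumptions for illustration; the expected format of `frame` is defined by the
# extractor module, not by this sketch.
def _example_get_deep_feature(frame, frame_number, device):
    from torchvision import models  # assumption: the backbone comes from torchvision
    resnet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT).to(device).eval()
    activations, flops, params = get_deep_feature(
        'resnet50', 'demo_video', frame, frame_number, resnet50, device, 'pool')
    return activations, flops, params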


def process_video_feature(video_feature, network_name, layer_name):
    """Pool per-frame activations into a single feature vector per frame and stack them."""
    # initialize an empty list to store processed frames
    averaged_frames = []
    # iterate through each frame in the video_feature
    for frame in video_feature:
        frame_features = []

        if network_name == 'vit':
            # global mean and std
            global_mean = torch.mean(frame, dim=0)
            global_max = torch.max(frame, dim=0)[0]
            global_std = torch.std(frame, dim=0)
            # concatenate all pooling
            combined_features = torch.hstack([global_mean, global_max, global_std])
            frame_features.append(combined_features)

        elif network_name == 'resnet50':
            if layer_name == 'layerstack':
                # iterate through each layer in the current frame
                for layer_array in frame.values():
                    # calculate the mean along the specified axes (1 and 2) for each layer
                    layer_mean = torch.mean(layer_array, dim=(1, 2))
                    # append the calculated mean to the list for the current frame
                    frame_features.append(layer_mean)
            elif layer_name == 'pool':
                frame = torch.squeeze(torch.as_tensor(frame))  # as_tensor avoids an extra copy if frame is already a tensor
                # global mean and std
                global_mean = torch.mean(frame, dim=0)
                global_max = torch.max(frame, dim=0)[0]
                global_std = torch.std(frame, dim=0)
                # concatenate all pooling
                combined_features = torch.hstack([frame, global_mean, global_max, global_std])
                frame_features.append(combined_features)

        # concatenate the layer means horizontally to form the processed frame
        processed_frame = torch.hstack(frame_features)
        averaged_frames.append(processed_frame)

    averaged_frames = torch.stack(averaged_frames)
    return averaged_frames
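
# Minimal shape sketch (an assumption for illustration, not repo behaviour): for the ViT
# branch each element of video_feature is taken to be a (num_tokens, dim) token tensor;
# mean/max/std pooling then yields one (3 * dim) vector per frame.
def _example_process_video_feature():
    dummy_video = [torch.rand(197, 768) for _ in range(4)]  # 4 frames of dummy ViT tokens
    pooled = process_video_feature(dummy_video, 'vit', layer_name=None)
    assert pooled.shape == (4, 3 * 768)
    return pooled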


def flow_to_rgb(flow):
    # map a dense optical-flow field (H, W, 2) to a colour image:
    # hue encodes flow direction, value encodes flow magnitude
    mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
    mag = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
    # convert angle (radians) to hue in OpenCV's 0-179 range
    hue = ang * 180 / np.pi / 2

    # create HSV image
    hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8)
    hsv[..., 0] = hue
    hsv[..., 1] = 255
    hsv[..., 2] = mag
    # convert HSV to BGR (OpenCV's default channel order)
    rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    return rgb
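
# Hedged usage sketch: compute dense optical flow between two consecutive BGR frames
# (e.g. read with cv2.VideoCapture) and visualise it with flow_to_rgb. The Farneback
# parameters below are common defaults, not values taken from this repository.
def _example_flow_visualisation(prev_bgr, next_bgr):
    prev_gray = cv2.cvtColor(prev_bgr, cv2.COLOR_BGR2GRAY)
    next_gray = cv2.cvtColor(next_bgr, cv2.COLOR_BGR2GRAY)
    flow = cv2.calcOpticalFlowFarneback(prev_gray, next_gray, None,
                                        0.5, 3, 15, 3, 5, 1.2, 0)
    return flow_to_rgb(flow)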

def get_patch_diff(residual_frame, patch_size):
    h, w = residual_frame.shape[2:]  # Assuming (1, C, H, W) shape
    h_adj = (h // patch_size) * patch_size
    w_adj = (w // patch_size) * patch_size
    residual_frame_adj = residual_frame[:, :, :h_adj, :w_adj]
    # calculate absolute patch difference
    diff = torch.zeros((h_adj // patch_size, w_adj // patch_size), device=residual_frame.device)
    for i in range(0, h_adj, patch_size):
        for j in range(0, w_adj, patch_size):
            patch = residual_frame_adj[:, :, i:i + patch_size, j:j + patch_size]
            # absolute sum
            diff[i // patch_size, j // patch_size] = torch.sum(torch.abs(patch))
    return diff

def extract_important_patches(residual_frame, diff, patch_size=16, target_size=224, top_n=196):
    # find top n patches indices
    patch_idx = torch.argsort(-diff.view(-1))
    top_patches = [(idx // diff.shape[1], idx % diff.shape[1]) for idx in patch_idx[:top_n]]
    sorted_idx = sorted(top_patches, key=lambda x: (x[0], x[1]))

    imp_patches_img = torch.zeros((residual_frame.shape[1], target_size, target_size), dtype=residual_frame.dtype, device=residual_frame.device)
    patches_per_row = target_size // patch_size  # 14
    # place the selected patches in the raster order of their original locations
    positions = []
    for idx, (y, x) in enumerate(sorted_idx):
        patch = residual_frame[:, :, y * patch_size:(y + 1) * patch_size, x * patch_size:(x + 1) * patch_size]
        # new patch location
        row_idx = idx // patches_per_row
        col_idx = idx % patches_per_row
        start_y = row_idx * patch_size
        start_x = col_idx * patch_size
        imp_patches_img[:, start_y:start_y + patch_size, start_x:start_x + patch_size] = patch.squeeze(0)
        positions.append((y.item(), x.item()))
    return imp_patches_img, positions
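
# Hedged pipeline sketch: use a simple frame difference as the residual (an assumption;
# the real pipeline may build residuals differently), score 16x16 patches with
# get_patch_diff, keep the 196 most active ones, then cut the same positions out of the
# current frame with get_frame_patches (defined below).
def _example_fragment_selection():
    prev_frame = torch.rand(1, 3, 448, 448)  # placeholder (1, C, H, W) frames
    curr_frame = torch.rand(1, 3, 448, 448)
    residual = curr_frame - prev_frame
    diff = get_patch_diff(residual, patch_size=16)
    residual_frag, positions = extract_important_patches(
        residual, diff, patch_size=16, target_size=224, top_n=196)
    frame_frag = get_frame_patches(curr_frame, positions, patch_size=16, target_size=224)
    return residual_frag, frame_frag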

def get_frame_patches(frame, positions, patch_size, target_size):
    imp_patches_img = torch.zeros((frame.shape[1], target_size, target_size), dtype=frame.dtype, device=frame.device)
    patches_per_row = target_size // patch_size

    for idx, (y, x) in enumerate(positions):
        start_y = y * patch_size
        start_x = x * patch_size
        end_y = start_y + patch_size
        end_x = start_x + patch_size

        patch = frame[:, :, start_y:end_y, start_x:end_x]
        row_idx = idx // patches_per_row
        col_idx = idx % patches_per_row
        target_start_y = row_idx * patch_size
        target_start_x = col_idx * patch_size

        imp_patches_img[:, target_start_y:target_start_y + patch_size,
                        target_start_x:target_start_x + patch_size] = patch.squeeze(0)
    return imp_patches_img

def process_patches(original_path, frag_name, residual, patch_size, target_size, top_n):
    diff = get_patch_diff(residual, patch_size)
    imp_patches, positions = extract_important_patches(residual, diff, patch_size, target_size, top_n)
    if frag_name == 'frame_diff':
        frag_path = original_path.replace('.png', '_residual_imp.png')
    elif frag_name == 'optical_flow':
        frag_path = original_path.replace('.png', '_residual_of_imp.png')
    else:
        raise ValueError(f"Unknown frag_name: {frag_name}")
    # cv2.imwrite(frag_path, imp_patches)
    return frag_path, imp_patches, positions

def merge_fragments(diff_fragment, flow_fragment):
    # equal-weight blend of the frame-difference and optical-flow fragments
    alpha = 0.5
    merged_fragment = diff_fragment * alpha + flow_fragment * (1 - alpha)
    return merged_fragment

def concatenate_features(frame_feature, residual_feature):
    return torch.cat((frame_feature, residual_feature), dim=-1)
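
# Hedged end-to-end sketch: blend the frame-difference and optical-flow fragments and
# concatenate per-frame appearance features with residual features. All tensors below
# are placeholders; real shapes depend on the chosen backbone and pooling.
def _example_merge_and_concatenate():
    diff_fragment = torch.rand(3, 224, 224)
    flow_fragment = torch.rand(3, 224, 224)
    merged = merge_fragments(diff_fragment, flow_fragment)            # (3, 224, 224)
    frame_feature = torch.rand(8, 2048)                               # e.g. 8 frames of pooled features
    residual_feature = torch.rand(8, 2048)
    combined = concatenate_features(frame_feature, residual_feature)  # (8, 4096)
    return merged, combined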