lorocksUMD committed
Commit 3d9aa0c · verified · 1 Parent(s): 64fda16

Update DenseAV/denseav/plotting.py

Files changed (1)
  1. DenseAV/denseav/plotting.py +246 -244
DenseAV/denseav/plotting.py CHANGED
@@ -1,244 +1,246 @@
- import os
- from collections import defaultdict
-
- import matplotlib.colors as mcolors
- import matplotlib.pyplot as plt
- import numpy as np
- import scipy.io.wavfile as wavfile
- import torch
- import torch.nn.functional as F
- import torchvision
- from moviepy.editor import VideoFileClip, AudioFileClip
- from base64 import b64encode
- from DenseAV.denseav.shared import pca
-
-
- def write_video_with_audio(video_frames, audio_array, video_fps, audio_fps, output_path):
-     """
-     Writes video frames and audio to a specified path.
-
-     Parameters:
-     - video_frames: torch.Tensor of shape (num_frames, height, width, channels)
-     - audio_array: torch.Tensor of shape (num_samples, num_channels)
-     - video_fps: int, frames per second of the video
-     - audio_fps: int, sample rate of the audio
-     - output_path: str, path to save the final video with audio
-     """
-     os.makedirs(os.path.dirname(output_path), exist_ok=True)
-
-     temp_video_path = output_path.replace('.mp4', '_temp.mp4')
-     temp_audio_path = output_path.replace('.mp4', '_temp_audio.wav')
-     video_options = {
-         'crf': '23',
-         'preset': 'slow',
-         'bit_rate': '1000k'}
-
-     if audio_array is not None:
-         torchvision.io.write_video(
-             filename=temp_video_path,
-             video_array=video_frames,
-             fps=video_fps,
-             options=video_options
-         )
-
-         wavfile.write(temp_audio_path, audio_fps, audio_array.cpu().to(torch.float64).permute(1, 0).numpy())
-         video_clip = VideoFileClip(temp_video_path)
-         audio_clip = AudioFileClip(temp_audio_path)
-         final_clip = video_clip.set_audio(audio_clip)
-         final_clip.write_videofile(output_path, codec='libx264', verbose=False)
-         os.remove(temp_video_path)
-         os.remove(temp_audio_path)
-     else:
-         torchvision.io.write_video(
-             filename=output_path,
-             video_array=video_frames,
-             fps=video_fps,
-             options=video_options
-         )
-
-
- def alpha_blend_layers(layers):
-     blended_image = layers[0]
-     for layer in layers[1:]:
-         rgb1, alpha1 = blended_image[:, :3, :, :], blended_image[:, 3:4, :, :]
-         rgb2, alpha2 = layer[:, :3, :, :], layer[:, 3:4, :, :]
-         alpha_out = alpha2 + alpha1 * (1 - alpha2)
-         rgb_out = (rgb2 * alpha2 + rgb1 * alpha1 * (1 - alpha2)) / alpha_out.clamp(min=1e-7)
-         blended_image = torch.cat([rgb_out, alpha_out], dim=1)
-     return (blended_image[:, :3] * 255).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1)
-
-
- def _prep_sims_for_plotting(sim_by_head, frames):
-     with torch.no_grad():
-         results = defaultdict(list)
-         n_frames, _, vh, vw = frames.shape
-
-         sims = sim_by_head.max(dim=1).values
-
-         n_audio_feats = sims.shape[-1]
-         for frame_num in range(n_frames):
-             selected_audio_feat = int((frame_num / n_frames) * n_audio_feats)
-
-             selected_sim = F.interpolate(
-                 sims[frame_num, :, :, selected_audio_feat].unsqueeze(0).unsqueeze(0),
-                 size=(vh, vw),
-                 mode="bicubic")
-
-             results["sims_all"].append(selected_sim)
-
-             for head in range(sim_by_head.shape[1]):
-                 selected_sim = F.interpolate(
-                     sim_by_head[frame_num, head, :, :, selected_audio_feat].unsqueeze(0).unsqueeze(0),
-                     size=(vh, vw),
-                     mode="bicubic")
-                 results[f"sims_{head + 1}"].append(selected_sim)
-
-         results = {k: torch.cat(v, dim=0) for k, v in results.items()}
-         return results
-
-
- def get_plasma_with_alpha():
-     plasma = plt.cm.plasma(np.linspace(0, 1, 256))
-     alphas = np.linspace(0, 1, 256)
-     plasma_with_alpha = np.zeros((256, 4))
-     plasma_with_alpha[:, 0:3] = plasma[:, 0:3]
-     plasma_with_alpha[:, 3] = alphas
-     return mcolors.ListedColormap(plasma_with_alpha)
-
-
- def get_inferno_with_alpha_2(alpha=0.5, k=30):
-     k_fraction = k / 100.0
-     custom_cmap = np.zeros((256, 4))
-     threshold_index = int(k_fraction * 256)
-     custom_cmap[:threshold_index, :3] = 0  # RGB values for black
-     custom_cmap[:threshold_index, 3] = alpha  # Alpha value
-     remaining_inferno = plt.cm.inferno(np.linspace(0, 1, 256 - threshold_index))
-     custom_cmap[threshold_index:, :3] = remaining_inferno[:, :3]
-     custom_cmap[threshold_index:, 3] = alpha  # Alpha value
-     return mcolors.ListedColormap(custom_cmap)
-
-
- def get_inferno_with_alpha():
-     plasma = plt.cm.inferno(np.linspace(0, 1, 256))
-     alphas = np.linspace(0, 1, 256)
-     plasma_with_alpha = np.zeros((256, 4))
-     plasma_with_alpha[:, 0:3] = plasma[:, 0:3]
-     plasma_with_alpha[:, 3] = alphas
-     return mcolors.ListedColormap(plasma_with_alpha)
-
-
- red_cmap = mcolors.LinearSegmentedColormap('RedMap', segmentdata={
-     'red': [(0.0, 1.0, 1.0), (1.0, 1.0, 1.0)],
-     'green': [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
-     'blue': [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
-     'alpha': [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)]
- })
-
- blue_cmap = mcolors.LinearSegmentedColormap('BlueMap', segmentdata={
-     'red': [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
-     'green': [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
-     'blue': [(0.0, 1.0, 1.0), (1.0, 1.0, 1.0)],
-     'alpha': [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)]
- })
-
-
- def plot_attention_video(sims_by_head, frames, audio, video_fps, audio_fps, output_filename):
-     prepped_sims = _prep_sims_for_plotting(sims_by_head, frames)
-     n_frames, _, vh, vw = frames.shape
-     sims_all = prepped_sims["sims_all"].clamp_min(0)
-     sims_all -= sims_all.min()
-     sims_all = sims_all / sims_all.max()
-     cmap = get_inferno_with_alpha()
-     layer1 = torch.cat([frames, torch.ones(n_frames, 1, vh, vw)], axis=1)
-     layer2 = torch.tensor(cmap(sims_all.squeeze().detach().cpu())).permute(0, 3, 1, 2)
-     write_video_with_audio(
-         alpha_blend_layers([layer1, layer2]),
-         audio,
-         video_fps,
-         audio_fps,
-         output_filename)
-
-
- def plot_2head_attention_video(sims_by_head, frames, audio, video_fps, audio_fps, output_filename):
-     prepped_sims = _prep_sims_for_plotting(sims_by_head, frames)
-     sims_1 = prepped_sims["sims_1"]
-     sims_2 = prepped_sims["sims_2"]
-
-     n_frames, _, vh, vw = frames.shape
-
-     mask = sims_1 > sims_2
-     sims_1 *= mask
-     sims_2 *= (~mask)
-
-     sims_1 = sims_1.clamp_min(0)
-     sims_1 -= sims_1.min()
-     sims_1 = sims_1 / sims_1.max()
-
-     sims_2 = sims_2.clamp_min(0)
-     sims_2 -= sims_2.min()
-     sims_2 = sims_2 / sims_2.max()
-
-     layer1 = torch.cat([frames, torch.ones(n_frames, 1, vh, vw)], axis=1)
-     layer2_head1 = torch.tensor(red_cmap(sims_1.squeeze().detach().cpu())).permute(0, 3, 1, 2)
-     layer2_head2 = torch.tensor(blue_cmap(sims_2.squeeze().detach().cpu())).permute(0, 3, 1, 2)
-
-     write_video_with_audio(
-         alpha_blend_layers([layer1, layer2_head1, layer2_head2]),
-         audio,
-         video_fps,
-         audio_fps,
-         output_filename)
-
-
- def plot_feature_video(image_feats,
-                        audio_feats,
-                        frames,
-                        audio,
-                        video_fps,
-                        audio_fps,
-                        video_filename,
-                        audio_filename):
-     with torch.no_grad():
-         image_feats_ = image_feats.cpu()
-         audio_feats_ = audio_feats.cpu()
-         [red_img_feats, red_audio_feats], _ = pca([
-             image_feats_,
-             audio_feats_,  # .tile(image_feats_.shape[0], 1, 1, 1)
-         ])
-         _, _, vh, vw = frames.shape
-         red_img_feats = F.interpolate(red_img_feats, size=(vh, vw), mode="bicubic")
-         red_audio_feats = red_audio_feats[0].unsqueeze(0)
-         red_audio_feats = F.interpolate(red_audio_feats, size=(50, red_img_feats.shape[0]), mode="bicubic")
-
-         write_video_with_audio(
-             (red_img_feats.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8),
-             audio,
-             video_fps,
-             audio_fps,
-             video_filename)
-
-         red_audio_feats_expanded = red_audio_feats.tile(red_img_feats.shape[0], 1, 1, 1)
-         red_audio_feats_expanded = F.interpolate(red_audio_feats_expanded, scale_factor=6, mode="bicubic")
-         for i in range(red_img_feats.shape[0]):
-             center_index = i * 6
-             min_index = max(center_index - 2, 0)
-             max_index = min(center_index + 2, red_audio_feats_expanded.shape[-1])
-             red_audio_feats_expanded[i, :, :, min_index:max_index] = 1
-
-         write_video_with_audio(
-             (red_audio_feats_expanded.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8),
-             audio,
-             video_fps,
-             audio_fps,
-             audio_filename)
-
-
- def display_video_in_notebook(path):
-     from IPython.display import HTML, display
-     mp4 = open(path, 'rb').read()
-     data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
-     display(HTML("""
-     <video width=400 controls>
-         <source src="%s" type="video/mp4">
-     </video>
-     """ % data_url))
+ import os
+ from collections import defaultdict
+
+ import matplotlib.colors as mcolors
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import scipy.io.wavfile as wavfile
+ import torch
+ import torch.nn.functional as F
+ import torchvision
+ from moviepy import *
+ from moviepy.editor import VideoFileClip, AudioFileClip
+ from base64 import b64encode
+ from DenseAV.denseav.shared import pca
+
+
+
+ def write_video_with_audio(video_frames, audio_array, video_fps, audio_fps, output_path):
+     """
+     Writes video frames and audio to a specified path.
+
+     Parameters:
+     - video_frames: torch.Tensor of shape (num_frames, height, width, channels)
+     - audio_array: torch.Tensor of shape (num_samples, num_channels)
+     - video_fps: int, frames per second of the video
+     - audio_fps: int, sample rate of the audio
+     - output_path: str, path to save the final video with audio
+     """
+     os.makedirs(os.path.dirname(output_path), exist_ok=True)
+
+     temp_video_path = output_path.replace('.mp4', '_temp.mp4')
+     temp_audio_path = output_path.replace('.mp4', '_temp_audio.wav')
+     video_options = {
+         'crf': '23',
+         'preset': 'slow',
+         'bit_rate': '1000k'}
+
+     if audio_array is not None:
+         torchvision.io.write_video(
+             filename=temp_video_path,
+             video_array=video_frames,
+             fps=video_fps,
+             options=video_options
+         )
+
+         wavfile.write(temp_audio_path, audio_fps, audio_array.cpu().to(torch.float64).permute(1, 0).numpy())
+         video_clip = VideoFileClip(temp_video_path)
+         audio_clip = AudioFileClip(temp_audio_path)
+         final_clip = video_clip.set_audio(audio_clip)
+         final_clip.write_videofile(output_path, codec='libx264', verbose=False)
+         os.remove(temp_video_path)
+         os.remove(temp_audio_path)
+     else:
+         torchvision.io.write_video(
+             filename=output_path,
+             video_array=video_frames,
+             fps=video_fps,
+             options=video_options
+         )
+
+
+ def alpha_blend_layers(layers):
+     blended_image = layers[0]
+     for layer in layers[1:]:
+         rgb1, alpha1 = blended_image[:, :3, :, :], blended_image[:, 3:4, :, :]
+         rgb2, alpha2 = layer[:, :3, :, :], layer[:, 3:4, :, :]
+         alpha_out = alpha2 + alpha1 * (1 - alpha2)
+         rgb_out = (rgb2 * alpha2 + rgb1 * alpha1 * (1 - alpha2)) / alpha_out.clamp(min=1e-7)
+         blended_image = torch.cat([rgb_out, alpha_out], dim=1)
+     return (blended_image[:, :3] * 255).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1)
+
+
+ def _prep_sims_for_plotting(sim_by_head, frames):
+     with torch.no_grad():
+         results = defaultdict(list)
+         n_frames, _, vh, vw = frames.shape
+
+         sims = sim_by_head.max(dim=1).values
+
+         n_audio_feats = sims.shape[-1]
+         for frame_num in range(n_frames):
+             selected_audio_feat = int((frame_num / n_frames) * n_audio_feats)
+
+             selected_sim = F.interpolate(
+                 sims[frame_num, :, :, selected_audio_feat].unsqueeze(0).unsqueeze(0),
+                 size=(vh, vw),
+                 mode="bicubic")
+
+             results["sims_all"].append(selected_sim)
+
+             for head in range(sim_by_head.shape[1]):
+                 selected_sim = F.interpolate(
+                     sim_by_head[frame_num, head, :, :, selected_audio_feat].unsqueeze(0).unsqueeze(0),
+                     size=(vh, vw),
+                     mode="bicubic")
+                 results[f"sims_{head + 1}"].append(selected_sim)
+
+         results = {k: torch.cat(v, dim=0) for k, v in results.items()}
+         return results
+
+
+ def get_plasma_with_alpha():
+     plasma = plt.cm.plasma(np.linspace(0, 1, 256))
+     alphas = np.linspace(0, 1, 256)
+     plasma_with_alpha = np.zeros((256, 4))
+     plasma_with_alpha[:, 0:3] = plasma[:, 0:3]
+     plasma_with_alpha[:, 3] = alphas
+     return mcolors.ListedColormap(plasma_with_alpha)
+
+
+ def get_inferno_with_alpha_2(alpha=0.5, k=30):
+     k_fraction = k / 100.0
+     custom_cmap = np.zeros((256, 4))
+     threshold_index = int(k_fraction * 256)
+     custom_cmap[:threshold_index, :3] = 0  # RGB values for black
+     custom_cmap[:threshold_index, 3] = alpha  # Alpha value
+     remaining_inferno = plt.cm.inferno(np.linspace(0, 1, 256 - threshold_index))
+     custom_cmap[threshold_index:, :3] = remaining_inferno[:, :3]
+     custom_cmap[threshold_index:, 3] = alpha  # Alpha value
+     return mcolors.ListedColormap(custom_cmap)
+
+
+ def get_inferno_with_alpha():
+     plasma = plt.cm.inferno(np.linspace(0, 1, 256))
+     alphas = np.linspace(0, 1, 256)
+     plasma_with_alpha = np.zeros((256, 4))
+     plasma_with_alpha[:, 0:3] = plasma[:, 0:3]
+     plasma_with_alpha[:, 3] = alphas
+     return mcolors.ListedColormap(plasma_with_alpha)
+
+
+ red_cmap = mcolors.LinearSegmentedColormap('RedMap', segmentdata={
+     'red': [(0.0, 1.0, 1.0), (1.0, 1.0, 1.0)],
+     'green': [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
+     'blue': [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
+     'alpha': [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)]
+ })
+
+ blue_cmap = mcolors.LinearSegmentedColormap('BlueMap', segmentdata={
+     'red': [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
+     'green': [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
+     'blue': [(0.0, 1.0, 1.0), (1.0, 1.0, 1.0)],
+     'alpha': [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)]
+ })
+
+
+ def plot_attention_video(sims_by_head, frames, audio, video_fps, audio_fps, output_filename):
+     prepped_sims = _prep_sims_for_plotting(sims_by_head, frames)
+     n_frames, _, vh, vw = frames.shape
+     sims_all = prepped_sims["sims_all"].clamp_min(0)
+     sims_all -= sims_all.min()
+     sims_all = sims_all / sims_all.max()
+     cmap = get_inferno_with_alpha()
+     layer1 = torch.cat([frames, torch.ones(n_frames, 1, vh, vw)], axis=1)
+     layer2 = torch.tensor(cmap(sims_all.squeeze().detach().cpu())).permute(0, 3, 1, 2)
+     write_video_with_audio(
+         alpha_blend_layers([layer1, layer2]),
+         audio,
+         video_fps,
+         audio_fps,
+         output_filename)
+
+
+ def plot_2head_attention_video(sims_by_head, frames, audio, video_fps, audio_fps, output_filename):
+     prepped_sims = _prep_sims_for_plotting(sims_by_head, frames)
+     sims_1 = prepped_sims["sims_1"]
+     sims_2 = prepped_sims["sims_2"]
+
+     n_frames, _, vh, vw = frames.shape
+
+     mask = sims_1 > sims_2
+     sims_1 *= mask
+     sims_2 *= (~mask)
+
+     sims_1 = sims_1.clamp_min(0)
+     sims_1 -= sims_1.min()
+     sims_1 = sims_1 / sims_1.max()
+
+     sims_2 = sims_2.clamp_min(0)
+     sims_2 -= sims_2.min()
+     sims_2 = sims_2 / sims_2.max()
+
+     layer1 = torch.cat([frames, torch.ones(n_frames, 1, vh, vw)], axis=1)
+     layer2_head1 = torch.tensor(red_cmap(sims_1.squeeze().detach().cpu())).permute(0, 3, 1, 2)
+     layer2_head2 = torch.tensor(blue_cmap(sims_2.squeeze().detach().cpu())).permute(0, 3, 1, 2)
+
+     write_video_with_audio(
+         alpha_blend_layers([layer1, layer2_head1, layer2_head2]),
+         audio,
+         video_fps,
+         audio_fps,
+         output_filename)
+
+
+ def plot_feature_video(image_feats,
+                        audio_feats,
+                        frames,
+                        audio,
+                        video_fps,
+                        audio_fps,
+                        video_filename,
+                        audio_filename):
+     with torch.no_grad():
+         image_feats_ = image_feats.cpu()
+         audio_feats_ = audio_feats.cpu()
+         [red_img_feats, red_audio_feats], _ = pca([
+             image_feats_,
+             audio_feats_,  # .tile(image_feats_.shape[0], 1, 1, 1)
+         ])
+         _, _, vh, vw = frames.shape
+         red_img_feats = F.interpolate(red_img_feats, size=(vh, vw), mode="bicubic")
+         red_audio_feats = red_audio_feats[0].unsqueeze(0)
+         red_audio_feats = F.interpolate(red_audio_feats, size=(50, red_img_feats.shape[0]), mode="bicubic")
+
+         write_video_with_audio(
+             (red_img_feats.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8),
+             audio,
+             video_fps,
+             audio_fps,
+             video_filename)
+
+         red_audio_feats_expanded = red_audio_feats.tile(red_img_feats.shape[0], 1, 1, 1)
+         red_audio_feats_expanded = F.interpolate(red_audio_feats_expanded, scale_factor=6, mode="bicubic")
+         for i in range(red_img_feats.shape[0]):
+             center_index = i * 6
+             min_index = max(center_index - 2, 0)
+             max_index = min(center_index + 2, red_audio_feats_expanded.shape[-1])
+             red_audio_feats_expanded[i, :, :, min_index:max_index] = 1
+
+         write_video_with_audio(
+             (red_audio_feats_expanded.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8),
+             audio,
+             video_fps,
+             audio_fps,
+             audio_filename)
+
+
+ def display_video_in_notebook(path):
+     from IPython.display import HTML, display
+     mp4 = open(path, 'rb').read()
+     data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
+     display(HTML("""
+     <video width=400 controls>
+         <source src="%s" type="video/mp4">
+     </video>
+     """ % data_url))