Spaces: Running on Zero
Update app.py
Browse files
app.py
CHANGED
@@ -285,13 +285,13 @@ def visualize_attention_hiddenstate(attention_tensor, head=None, start_img_token
|
|
285 |
averaged_layer = np.mean(last_8_layers,axis=0) # Trung bình 8 layer cuối
|
286 |
|
287 |
if head is None:
|
288 |
-
averaged_attention = averaged_layer.mean(axis=1)
|
289 |
else:
|
290 |
-
averaged_attention = averaged_layer[:, head, :, :]
|
291 |
|
292 |
heat_maps = []
|
293 |
top_5_tokens = []
|
294 |
-
|
295 |
for i in range(len(averaged_attention)): # Duyệt qua các beam
|
296 |
h_target_aspect_ratio = target_aspect_ratio[1] if target_aspect_ratio[1] != 0 else 1
|
297 |
w_target_aspect_ratio = target_aspect_ratio[0] if target_aspect_ratio[0] != 0 else 1
|
@@ -306,11 +306,13 @@ def visualize_attention_hiddenstate(attention_tensor, head=None, start_img_token
|
|
306 |
|
307 |
|
308 |
# Reshape the attention scores so they can be drawn as a heatmap
|
|
|
309 |
img_atten_score = img_atten_score.reshape(h_target_aspect_ratio, w_target_aspect_ratio, 16, 16)
|
310 |
img_atten_score = np.transpose(img_atten_score, (0, 2, 1, 3)).reshape(h_target_aspect_ratio * 16, w_target_aspect_ratio * 16)
|
311 |
|
312 |
img_atten_score = np.power(img_atten_score, 0.9)
|
313 |
-
|
|
|
314 |
|
315 |
return heat_maps, top_5_tokens
|
316 |
|
@@ -379,6 +381,21 @@ def generate_video(image, prompt, max_tokens):
|
|
379 |
response, query = model.chat(tokenizer, pixel_values, '<image>\n'+prompt, generation_config, return_history=False, \
|
380 |
attention_visualize=True,last_visualize_layers=7,raw_image_path=image,target_aspect_ratio=target_aspect_ratio)
|
381 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
382 |
generation_output = response
|
383 |
raw_image_path = image
|
384 |
|
|
|
285 |
averaged_layer = np.mean(last_8_layers,axis=0) # Trung bình 8 layer cuối
|
286 |
|
287 |
if head is None:
|
288 |
+
averaged_attention = averaged_layer.mean(axis=1) # Trung bình qua các head
|
289 |
else:
|
290 |
+
averaged_attention = averaged_layer[:, head, :, :] # Chọn head cụ thể
|
291 |
|
292 |
heat_maps = []
|
293 |
top_5_tokens = []
|
294 |
+
|
295 |
for i in range(len(averaged_attention)): # Duyệt qua các beam
|
296 |
h_target_aspect_ratio = target_aspect_ratio[1] if target_aspect_ratio[1] != 0 else 1
|
297 |
w_target_aspect_ratio = target_aspect_ratio[0] if target_aspect_ratio[0] != 0 else 1
|
|
|
306 |
|
307 |
|
308 |
# Reshape the attention scores so they can be drawn as a heatmap
|
309 |
+
|
310 |
img_atten_score = img_atten_score.reshape(h_target_aspect_ratio, w_target_aspect_ratio, 16, 16)
|
311 |
img_atten_score = np.transpose(img_atten_score, (0, 2, 1, 3)).reshape(h_target_aspect_ratio * 16, w_target_aspect_ratio * 16)
|
312 |
|
313 |
img_atten_score = np.power(img_atten_score, 0.9)
|
314 |
+
|
315 |
+
heat_maps.append(img_atten_score)
|
316 |
|
317 |
return heat_maps, top_5_tokens
|
318 |
|
|
|
381 |
response, query = model.chat(tokenizer, pixel_values, '<image>\n'+prompt, generation_config, return_history=False, \
|
382 |
attention_visualize=True,last_visualize_layers=7,raw_image_path=image,target_aspect_ratio=target_aspect_ratio)
|
383 |
|
384 |
+
###### GET GOOD BEAM #####
|
385 |
+
response_attentions_list = []
|
386 |
+
response_hidden_states_list = []
|
387 |
+
for index in range(len(response.beam_indices[0])):
|
388 |
+
beam_indice = response.beam_indices[0][index]
|
389 |
+
layer_response_attentions_list = []
|
390 |
+
layer_response_hidden_states_list = []
|
391 |
+
for layer_index in range(len(response.attentions[index])):
|
392 |
+
layer_response_attentions_list.append(torch.unsqueeze(response.attentions[index][layer_index][beam_indice],0))
|
393 |
+
layer_response_hidden_states_list.append(torch.unsqueeze(response.hidden_states[index][layer_index][beam_indice],0))
|
394 |
+
response_attentions_list.append(layer_response_attentions_list)
|
395 |
+
response_hidden_states_list.append(layer_response_hidden_states_list)
|
396 |
+
response.attentions = response_attentions_list
|
397 |
+
response.hidden_states = response_hidden_states_list
|
398 |
+
|
399 |
generation_output = response
|
400 |
raw_image_path = image
|
401 |
|