khang119966 committed on
Commit
1a04091
·
verified ·
1 Parent(s): 4fc9e5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -3
app.py CHANGED
@@ -28,6 +28,8 @@ import spaces
28
  import subprocess
29
  import os
30
  from moviepy.editor import VideoFileClip, AudioFileClip
 
 
31
 
32
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
33
 
@@ -446,7 +448,14 @@ model = AutoModel.from_pretrained(
446
  trust_remote_code=True,
447
  ).eval().cuda()
448
  tokenizer = AutoTokenizer.from_pretrained("khang119966/Vintern-1B-v3_5-explainableAI", trust_remote_code=True, use_fast=False)
449
-
 
 
 
 
 
 
 
450
  @spaces.GPU
451
  def generate_video(image, prompt, max_tokens):
452
  print(image)
@@ -517,11 +526,52 @@ def generate_video(image, prompt, max_tokens):
517
 
518
  input_token = predict_token_text
519
  heatmap_imgs.append(overlay)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
520
 
521
- return "path_to_generated_video.mp4"
 
 
 
 
 
 
 
 
 
 
 
 
522
 
523
  with gr.Blocks() as demo:
524
- gr.Markdown("### Simple VLM Demo")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
 
526
  with gr.Row():
527
  with gr.Column():
 
28
  import subprocess
29
  import os
30
  from moviepy.editor import VideoFileClip, AudioFileClip
31
+ import multiprocessing
32
+ import imageio
33
 
34
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
35
 
 
448
  trust_remote_code=True,
449
  ).eval().cuda()
450
  tokenizer = AutoTokenizer.from_pretrained("khang119966/Vintern-1B-v3_5-explainableAI", trust_remote_code=True, use_fast=False)
451
+
452
+ # Wrapper functions so the renderers can be passed to multiprocessing
453
+ def generate_text_img_wrapper(args):
454
+ return generate_text_image_with_html2image(*args, image_width=500, min_height=1000)
455
+
456
+ def generate_hidden_img_wrapper(args):
457
+ return render_next_token_table_image(*args)
458
+
459
  @spaces.GPU
460
  def generate_video(image, prompt, max_tokens):
461
  print(image)
 
526
 
527
  input_token = predict_token_text
528
  heatmap_imgs.append(overlay)
529
+
530
+ # Use multiprocessing
531
+ with multiprocessing.Pool(processes=20) as pool:
532
+ text_imgs = pool.map(generate_text_img_wrapper, params_for_text)
533
+ hidden_imgs = pool.map(generate_hidden_img_wrapper, params_for_hidden)
534
+
535
+ for i in range(len(text_imgs)):
536
+ overlay = heatmap_imgs[i]
537
+ text_img = text_imgs[i]
538
+ predict_hidden_states = hidden_imgs[i]
539
+ overlay_adjusted = adjust_overlay(overlay, text_img)
540
+ predict_hidden_states = adjust_overlay(predict_hidden_states, text_img)
541
+ combined_image = np.hstack((overlay_adjusted, text_img, predict_hidden_states))
542
+ visualization_frames.append(combined_image)
543
 
544
+ resized_visualization_frames = []
545
+ for frame in visualization_frames:
546
+ frame = cv2.resize(frame,(visualization_frames[0].shape[1],visualization_frames[0].shape[0]))
547
+ resized_visualization_frames.append(frame)
548
+
549
+ # Save as an MP4 video using imageio
550
+ imageio.mimsave(
551
+ 'heatmap_animation.mp4',
552
+ resized_visualization_frames, # dạng RGB
553
+ fps=5
554
+ )
555
+
556
+ return "heatmap_animation.mp4"
557
 
558
  with gr.Blocks() as demo:
559
+ gr.Markdown("""## 🎥 Visualizing How Multimodal Models Think
560
+ This tool generates a video to **visualize how a multimodal model (image + text)** attends to different parts of an image while generating text.
561
+ ### 📌 What it does:
562
+ - Takes an input image and a text prompt.
563
+ - Shows how the model’s attention shifts on the image for each generated token.
564
+ - Helps explain the model’s behavior and decision-making.
565
+ ### 🖼️ Video layout (per frame):
566
+ Each frame in the video includes:
567
+ 1. 🔥 **Heatmap over image**: Shows which area the model focuses on.
568
+ 2. 📝 **Generated text**: With old context, current token highlighted.
569
+ 3. 📊 **Token prediction table**: Shows the model’s top next-token guesses.
570
+ ### 🎯 Use cases:
571
+ - Research explainability of vision-language models.
572
+ - Debugging or interpreting model outputs.
573
+ - Creating educational visualizations.
574
+ """)
575
 
576
  with gr.Row():
577
  with gr.Column():