AustingDong committed
Commit f59a9b2 · Parent(s): 369c141

Add visual method switching: activation functions and pooling aggregated into a single option

Files changed (4)
  1. app.py +60 -41
  2. demo/model_utils.py +1 -1
  3. demo/modified_attn.py +221 -0
  4. demo/visualization.py +25 -22
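Editor's note (not part of the commit): the old `visual_pooling_method` input becomes a single `visual_method` input. For the CLIP visual encoder it still selects the pooling strategy (CLS / max / avg); for the language-model visualizations it now selects the attention activation (softmax / sigmoid), with sigmoid maps captured by the new `ModifiedLlamaAttention` / `ModifiedGemmaAttention` wrappers in demo/modified_attn.py. A minimal sketch of that activation choice, assuming attention logits of shape (batch, heads, q_len, k_len); the helper name `attention_maps` is hypothetical:

import torch
import torch.nn.functional as F
from typing import Optional, Tuple

def attention_maps(scores: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Derive both attention maps from the same pre-activation scores
    (mirrors eager_attention_forward in demo/modified_attn.py below)."""
    if mask is not None:
        scores = scores + mask                   # additive causal mask
    softmax_map = F.softmax(scores, dim=-1)      # rows sum to 1; this is what drives the model output
    sigmoid_map = torch.sigmoid(scores)          # elementwise activation; kept only for visualization
    return softmax_map, sigmoid_map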
app.py CHANGED
@@ -5,6 +5,7 @@ from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
from demo.visualization import generate_gradcam, VisualizationJanus, VisualizationClip, VisualizationChartGemma, VisualizationLLaVA
from demo.model_utils import Clip_Utils, Janus_Utils, LLaVA_Utils, ChartGemma_Utils, add_title_to_image
+ from demo.modified_attn import ModifiedLlamaAttention, ModifiedGemmaAttention

import numpy as np
import matplotlib.pyplot as plt
@@ -53,7 +54,7 @@ def clean():
@spaces.GPU(duration=120)
def multimodal_understanding(model_type,
activation_map_method,
- visual_pooling_method,
+ visual_method,
image, question, seed, top_p, temperature, target_token_idx,
visualization_layer_min, visualization_layer_max, focus, response_type, chart_type):
# Clear CUDA cache before generating
@@ -83,7 +84,7 @@ def multimodal_understanding(model_type,
else:
target_layers = [all_layers[visualization_layer_min-1]]
grad_cam = VisualizationClip(clip_utils.model, target_layers)
- cam, outputs, grid_size = grad_cam.generate_cam(inputs, target_token_idx=0, visual_pooling_method=visual_pooling_method)
+ cam, outputs, grid_size = grad_cam.generate_cam(inputs, target_token_idx=0, visual_method=visual_method)
cam = cam.to("cpu")
cam = [generate_gradcam(cam, image, size=(224, 224))]
grad_cam.remove_hooks()
@@ -144,7 +145,7 @@ def multimodal_understanding(model_type,
cam = []
if focus == "Visual Encoder":
if target_token_idx != -1:
- cam_tensors, grid_size, start = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, target_token_idx, visual_pooling_method, focus)
+ cam_tensors, grid_size, start = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, target_token_idx, visual_method, focus)
cam_grid = cam_tensors.reshape(grid_size, grid_size)
cam_i = generate_gradcam(cam_grid, image)
cam_i = add_title_to_image(cam_i, input_ids_decoded[start + target_token_idx])
@@ -159,7 +160,7 @@ def multimodal_understanding(model_type,
gradcam = VisualizationLLaVA(vl_gpt, target_layers)
elif model_name.split('-')[0] == "ChartGemma":
gradcam = VisualizationChartGemma(vl_gpt, target_layers)
- cam_tensors, grid_size, start = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, i, visual_pooling_method, focus)
+ cam_tensors, grid_size, start = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, i, visual_method, focus)
cam_grid = cam_tensors.reshape(grid_size, grid_size)
cam_i = generate_gradcam(cam_grid, image)
cam_i = add_title_to_image(cam_i, input_ids_decoded[start + i])
@@ -167,7 +168,7 @@ def multimodal_understanding(model_type,
gradcam.remove_hooks()
i += 1
else:
- cam_tensors, grid_size, start = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, target_token_idx, visual_pooling_method, focus)
+ cam_tensors, grid_size, start = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, target_token_idx, visual_method, focus)
if target_token_idx != -1:
input_text_decoded = input_ids_decoded[start + target_token_idx]
for i, cam_tensor in enumerate(cam_tensors):
@@ -190,13 +191,11 @@

# Collect Results
RESULTS_ROOT = "./results"
- FILES_ROOT = f"{RESULTS_ROOT}/{model_name}/{focus}/{chart_type}/layer{visualization_layer_min}-{visualization_layer_max}"
+ FILES_ROOT = f"{RESULTS_ROOT}/{model_name}/{focus}/{visual_method}/{chart_type}/layer{visualization_layer_min}-{visualization_layer_max}/{'all_tokens' if target_token_idx == -1 else f'--{input_ids_decoded[start + target_token_idx]}--'}"
os.makedirs(FILES_ROOT, exist_ok=True)
- if focus == "Visual Encoder":
- cam[0].save(f"{FILES_ROOT}/{visual_pooling_method}.png")
- else:
- for i, cam_p in enumerate(cam):
- cam_p.save(f"{FILES_ROOT}/{i}.png")
+
+ for i, cam_p in enumerate(cam):
+ cam_p.save(f"{FILES_ROOT}/{i}.png")

with open(f"{FILES_ROOT}/input_text_decoded.txt", "w") as f:
f.write(input_text_decoded)
@@ -218,36 +217,58 @@
def model_slider_change(model_type):
global model_utils, vl_gpt, tokenizer, clip_utils, model_name, language_model_max_layer, language_model_best_layer, vision_model_best_layer
model_name = model_type
+
+
+ encoder_only_res = [
+ gr.Dropdown(choices=["Visualization only"], value="Visualization only", label="response_type"),
+ gr.Dropdown(choices=["Visual Encoder"], value="Visual Encoder", label="focus"),
+ gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type"),
+ gr.Dropdown(choices=["CLS", "max", "avg"], value="CLS", label="visual pooling method")
+ ]
+
+ visual_res = [
+ gr.Dropdown(choices=["Visualization only", "answer + visualization"], value="Visualization only", label="response_type"),
+ gr.Dropdown(choices=["Visual Encoder"], value="Visual Encoder", label="focus"),
+ gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type"),
+ gr.Dropdown(choices=["softmax", "sigmoid"], value="softmax", label="activation function")
+ ]
+
+ language_res = [
+ gr.Dropdown(choices=["Visualization only", "answer + visualization"], value="answer + visualization", label="response_type"),
+ gr.Dropdown(choices=["Language Model"], value="Language Model", label="focus"),
+ gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type"),
+ gr.Dropdown(choices=["softmax", "sigmoid"], value="softmax", label="activation function")
+ ]
+
+
if model_type == "Clip":
clean()
set_seed()
clip_utils = Clip_Utils()
clip_utils.init_Clip()
- res = (
- gr.Dropdown(choices=["Visualization only"], value="Visualization only", label="response_type"),
+ sliders = [
gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers min"),
gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers max"),
- gr.Dropdown(choices=["Visual Encoder"], value="Visual Encoder", label="focus"),
- gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type")
- )
- return res
+ ]
+ return tuple(encoder_only_res + sliders)
+
elif model_type.split('-')[0] == "Janus":

clean()
set_seed()
model_utils = Janus_Utils()
vl_gpt, tokenizer = model_utils.init_Janus(model_type.split('-')[-1])
+ for layer in vl_gpt.language_model.model.layers:
+ layer.self_attn = ModifiedLlamaAttention(layer.self_attn)
+
language_model_max_layer = 24
language_model_best_layer = 8
-
- res = (
- gr.Dropdown(choices=["Visualization only", "answer + visualization"], value="answer + visualization", label="response_type"),
+
+ sliders = [
gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers min"),
gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers max"),
- gr.Dropdown(choices=["Visual Encoder", "Language Model"], value="Visual Encoder", label="focus"),
- gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type")
- )
- return res
+ ]
+ return tuple(visual_res + sliders)

elif model_type.split('-')[0] == "LLaVA":

@@ -259,32 +280,28 @@ def model_slider_change(model_type):
language_model_max_layer = 32 if version == "1.5" else 28
language_model_best_layer = 10

- res = (
- gr.Dropdown(choices=["Visualization only", "answer + visualization"], value="answer + visualization", label="response_type"),
+ sliders = [
gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer, step=1, label="visualization layers min"),
gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer, step=1, label="visualization layers max"),
- gr.Dropdown(choices=["Language Model"], value="Language Model", label="focus"),
- gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type")
- )
- return res
+ ]
+ return tuple(language_res + sliders)

elif model_type.split('-')[0] == "ChartGemma":
clean()
set_seed()
model_utils = ChartGemma_Utils()
vl_gpt, tokenizer = model_utils.init_ChartGemma()
+ for layer in vl_gpt.language_model.model.layers:
+ layer.self_attn = ModifiedGemmaAttention(layer.self_attn)
language_model_max_layer = 18
vision_model_best_layer = 19
language_model_best_layer = 15

- res = (
- gr.Dropdown(choices=["Visualization only", "answer + visualization"], value="answer + visualization", label="response_type"),
+ sliders = [
gr.Slider(minimum=1, maximum=language_model_best_layer, value=language_model_best_layer, step=1, label="visualization layers min"),
gr.Slider(minimum=1, maximum=language_model_best_layer, value=language_model_best_layer, step=1, label="visualization layers max"),
- gr.Dropdown(choices=["Visual Encoder", "Language Model"], value="Language Model", label="focus"),
- gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type")
- )
- return res
+ ]
+ return tuple(language_res + sliders)


@@ -362,7 +379,8 @@ with gr.Blocks() as demo:
response_type = gr.Dropdown(choices=["Visualization only"], value="Visualization only", label="response_type")
focus = gr.Dropdown(choices=["Visual Encoder"], value="Visual Encoder", label="focus")
activation_map_method = gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="visualization type")
- visual_pooling_method = gr.Dropdown(choices=["CLS", "max", "avg"], value="CLS", label="visual pooling method")
+ # activation_function = gr.Dropdown(choices=["softmax", "sigmoid"], value="softmax", label="activation function")
+ visual_method = gr.Dropdown(choices=["CLS", "max", "avg"], value="CLS", label="visual pooling method")


visualization_layers_min = gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers min")
@@ -377,10 +395,11 @@
inputs=model_selector,
outputs=[
response_type,
- visualization_layers_min,
- visualization_layers_max,
focus,
- activation_map_method
+ activation_map_method,
+ visual_method,
+ visualization_layers_min,
+ visualization_layers_max
]
)

@@ -492,7 +511,7 @@

understanding_button.click(
multimodal_understanding,
- inputs=[model_selector, activation_map_method, visual_pooling_method, image_input, question_input, und_seed_input, top_p, temperature, target_token_idx,
+ inputs=[model_selector, activation_map_method, visual_method, image_input, question_input, und_seed_input, top_p, temperature, target_token_idx,
visualization_layers_min, visualization_layers_max, focus, response_type, chart_type],
outputs=[understanding_output, activation_map_output, understanding_target_token_decoded_output]
)
demo/model_utils.py CHANGED
@@ -204,7 +204,7 @@ class ChartGemma_Utils(Model_Utils):
self.vl_gpt = PaliGemmaForConditionalGeneration.from_pretrained(
model_path,
torch_dtype=torch.float16,
- attn_implementation="sdpa",
+ attn_implementation="eager",
output_attentions=True
)
self.vl_gpt, self.dtype, self.cuda_device = set_dtype_device(self.vl_gpt)
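Editor's note (not part of the commit): switching `attn_implementation` from `"sdpa"` to `"eager"` is presumably what makes per-head attention probabilities available to the hooks in demo/visualization.py and to the `ModifiedGemmaAttention` wrapper; PyTorch's SDPA kernels do not return attention weights, while the eager attention path in Transformers materializes them explicitly.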
demo/modified_attn.py ADDED
@@ -0,0 +1,221 @@
+ import torch
+ import sys
+ import typing
+ import typing_extensions
+ from torch import nn
+ from typing import Callable, List, Optional, Tuple, Union
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+ from transformers.cache_utils import Cache
+ from transformers.models.llama.configuration_llama import LlamaConfig
+ from transformers.models.llama.modeling_llama import LlamaAttention
+ from transformers.models.gemma.modeling_gemma import GemmaAttention
+
+ # from transformers.models.paligemma.modeling_paligemma import
+
+ if sys.version_info >= (3, 11):
+     Unpack = typing.Unpack
+ else:
+     Unpack = typing_extensions.Unpack
+
+
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+     """Applies Rotary Position Embedding to the query and key tensors.
+
+     Args:
+         q (`torch.Tensor`): The query tensor.
+         k (`torch.Tensor`): The key tensor.
+         cos (`torch.Tensor`): The cosine part of the rotary embedding.
+         sin (`torch.Tensor`): The sine part of the rotary embedding.
+         position_ids (`torch.Tensor`, *optional*):
+             Deprecated and unused.
+         unsqueeze_dim (`int`, *optional*, defaults to 1):
+             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+     Returns:
+         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+     """
+     cos = cos.unsqueeze(unsqueeze_dim)
+     sin = sin.unsqueeze(unsqueeze_dim)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+ def eager_attention_forward(
+     module: nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: Optional[torch.Tensor],
+     scaling: float,
+     dropout: float = 0.0,
+     **kwargs,
+ ):
+     key_states = repeat_kv(key, module.num_key_value_groups)
+     value_states = repeat_kv(value, module.num_key_value_groups)
+
+     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+     if attention_mask is not None:
+         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+         attn_weights = attn_weights + causal_mask
+
+     attn_sigmoid_weights = nn.functional.sigmoid(attn_weights).to(query.dtype)
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+     attn_output = torch.matmul(attn_weights, value_states)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+
+     return attn_output, attn_weights, attn_sigmoid_weights
+
+
+ class ModifiedLlamaAttention(nn.Module):
+     def __init__(self, llama_attention_old: LlamaAttention):
+         super().__init__()
+         self.config = llama_attention_old.config
+         self.layer_idx = llama_attention_old.layer_idx
+         self.head_dim = llama_attention_old.head_dim
+         self.num_key_value_groups = llama_attention_old.num_key_value_groups
+         self.scaling = self.head_dim**-0.5
+         self.attention_dropout = llama_attention_old.attention_dropout
+         self.is_causal = True
+
+         self.q_proj = llama_attention_old.q_proj
+         self.k_proj = llama_attention_old.k_proj
+         self.v_proj = llama_attention_old.v_proj
+         self.o_proj = llama_attention_old.o_proj
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+         attention_mask: Optional[torch.Tensor],
+         past_key_value: Optional[Cache] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs: Unpack[FlashAttentionKwargs],
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         input_shape = hidden_states.shape[:-1]
+         hidden_shape = (*input_shape, -1, self.head_dim)
+
+         query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+         cos, sin = position_embeddings
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+         if past_key_value is not None:
+             # sin and cos are specific to RoPE models; cache_position needed for the static cache
+             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+         attention_interface: Callable = eager_attention_forward
+
+         attn_output, attn_weights, attn_sigmoid_weights = attention_interface(
+             self,
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             dropout=0.0 if not self.training else self.attention_dropout,
+             scaling=self.scaling,
+             **kwargs,
+         )
+
+         self.attn_sigmoid_weights = attn_sigmoid_weights
+
+         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+         attn_output = self.o_proj(attn_output)
+         return attn_output, attn_weights
+
+
+ class ModifiedGemmaAttention(nn.Module):
+     """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+     def __init__(self, gemma_attention_old: GemmaAttention):
+         super().__init__()
+         self.config = gemma_attention_old.config
+         self.layer_idx = gemma_attention_old.layer_idx
+         self.head_dim = gemma_attention_old.head_dim
+         self.num_key_value_groups = gemma_attention_old.num_key_value_groups
+         self.scaling = gemma_attention_old.scaling
+         self.attention_dropout = gemma_attention_old.attention_dropout
+         self.is_causal = True
+
+         self.q_proj = gemma_attention_old.q_proj
+         self.k_proj = gemma_attention_old.k_proj
+         self.v_proj = gemma_attention_old.v_proj
+         self.o_proj = gemma_attention_old.o_proj
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+         attention_mask: Optional[torch.Tensor],
+         past_key_value: Optional[Cache] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs: Unpack[FlashAttentionKwargs],
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         input_shape = hidden_states.shape[:-1]
+         hidden_shape = (*input_shape, -1, self.head_dim)
+
+         query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+         cos, sin = position_embeddings
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+         if past_key_value is not None:
+             # sin and cos are specific to RoPE models; cache_position needed for the static cache
+             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+         attention_interface: Callable = eager_attention_forward
+
+         attn_output, attn_weights, attn_sigmoid_weights = attention_interface(
+             self,
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             dropout=0.0 if not self.training else self.attention_dropout,
+             scaling=self.scaling,
+             **kwargs,
+         )
+
+         self.attn_sigmoid_weights = attn_sigmoid_weights
+
+         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+         attn_output = self.o_proj(attn_output)
+         return attn_output, attn_weights
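Editor's sketch (not part of the commit): app.py above wires these wrappers in by swapping them into every decoder layer; the helper name `patch_llama_attention` is hypothetical, app.py does the loop inline.

from demo.modified_attn import ModifiedLlamaAttention

def patch_llama_attention(vl_gpt):
    """Swap each decoder layer's self-attention for the wrapper so that
    layer.self_attn.attn_sigmoid_weights is populated on the next forward pass."""
    for layer in vl_gpt.language_model.model.layers:
        layer.self_attn = ModifiedLlamaAttention(layer.self_attn)  # reuses the original q/k/v/o projection weights
    return vl_gpt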
demo/visualization.py CHANGED
@@ -8,7 +8,7 @@ from PIL import Image
from torch import nn
import spaces
from demo.modify_llama import *
-
+ from demo.modified_attn import ModifiedLlamaAttention

class Visualization:
def __init__(self, model, register=True):
@@ -25,6 +25,7 @@ class Visualization:
self.hooks.append(layer.register_backward_hook(self._backward_hook))

def _forward_hook(self, module, input, output):
+ print("forward_hook: self_attn_input: ", input)
self.activations.append(output)

def _backward_hook(self, module, grad_in, grad_out):
@@ -41,6 +42,9 @@
layer.get_attn_map = types.MethodType(get_attn_map, layer)

def _forward_activate_hooks(self, module, input, output):
+ print("forward_activate_hook: module: ", module)
+ print("forward_activate_hook: self_attn_input: ", input)
+
attn_output, attn_weights = output # Unpack outputs
print("attn_output shape:", attn_output.shape)
print("attn_weights shape:", attn_weights.shape)
@@ -231,15 +235,15 @@ class VisualizationClip(Visualization):
super().__init__(model)

@spaces.GPU(duration=120)
- def forward_backward(self, input_tensor, visual_pooling_method, target_token_idx):
+ def forward_backward(self, input_tensor, visual_method, target_token_idx):
output_full = self.model(**input_tensor)

if target_token_idx is None:
target_token_idx = torch.argmax(output_full.logits, dim=1).item()

- if visual_pooling_method == "CLS":
+ if visual_method == "CLS":
output = output_full.image_embeds
- elif visual_pooling_method == "avg":
+ elif visual_method == "avg":
output = self.model.visual_projection(output_full.vision_model_output.last_hidden_state).mean(dim=1)
else:
output, _ = self.model.visual_projection(output_full.vision_model_output.last_hidden_state).max(dim=1)
@@ -250,11 +254,11 @@


@spaces.GPU(duration=120)
- def generate_cam(self, input_tensor, target_token_idx=None, visual_pooling_method="CLS"):
+ def generate_cam(self, input_tensor, target_token_idx=None, visual_method="CLS"):
""" Generates Grad-CAM heatmap for ViT. """
self.setup_grads()
# Forward Backward pass
- output_full = self.forward_backward(input_tensor, visual_pooling_method, target_token_idx)
+ output_full = self.forward_backward(input_tensor, visual_method, target_token_idx)

cam_sum = self.grad_cam_vis()
cam_sum, grid_size = self.process(cam_sum)
@@ -291,34 +295,33 @@ class VisualizationJanus(Visualization):
self._modify_layers()
self._register_hooks_activations()

- def forward_backward(self, input_tensor, tokenizer, temperature, top_p, target_token_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
+ def forward_backward(self, input_tensor, tokenizer, temperature, top_p, target_token_idx=None, visual_method="softmax", focus="Visual Encoder"):
# Forward
image_embeddings, inputs_embeddings, outputs = self.model(input_tensor, tokenizer, temperature, top_p)
input_ids = input_tensor.input_ids
-
+ start_idx = 620
+ self.model.zero_grad()
if focus == "Visual Encoder":
-
- start_idx = 620
- self.model.zero_grad()
-
loss = outputs.logits.max(dim=-1).values[0, start_idx + target_token_idx]
loss.backward()

elif focus == "Language Model":
- self.model.zero_grad()
- loss = outputs.logits.max(dim=-1).values.sum()
+ if target_token_idx == -1:
+ loss = outputs.logits.max(dim=-1).values.sum()
+ else:
+ loss = outputs.logits.max(dim=-1).values[0, start_idx + target_token_idx]
loss.backward()

- self.activations = [layer.get_attn_map() for layer in self.target_layers]
+ self.activations = [layer.attn_sigmoid_weights for layer in self.target_layers] if visual_method == "sigmoid" else [layer.get_attn_map() for layer in self.target_layers]
self.gradients = [layer.get_attn_gradients() for layer in self.target_layers]

@spaces.GPU(duration=120)
- def generate_cam(self, input_tensor, tokenizer, temperature, top_p, target_token_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
+ def generate_cam(self, input_tensor, tokenizer, temperature, top_p, target_token_idx=None, visual_method="softmax", focus="Visual Encoder"):

self.setup_grads()

# Forward Backward pass
- self.forward_backward(input_tensor, tokenizer, temperature, top_p, target_token_idx, visual_pooling_method, focus)
+ self.forward_backward(input_tensor, tokenizer, temperature, top_p, target_token_idx, visual_method, focus)

start_idx = 620
if focus == "Visual Encoder":
@@ -365,7 +368,7 @@ class VisualizationLLaVA(Visualization):
self.gradients = [layer.get_attn_gradients() for layer in self.target_layers]

@spaces.GPU(duration=120)
- def generate_cam(self, inputs, tokenizer, temperature, top_p, target_token_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
+ def generate_cam(self, inputs, tokenizer, temperature, top_p, target_token_idx=None, visual_method="softmax", focus="Visual Encoder"):

self.setup_grads()
self.forward_backward(inputs)
@@ -401,7 +404,7 @@ class VisualizationChartGemma(Visualization):
self._modify_layers()
self._register_hooks_activations()

- def forward_backward(self, inputs, focus, start_idx, target_token_idx):
+ def forward_backward(self, inputs, focus, start_idx, target_token_idx, visual_method="softmax"):
outputs_raw = self.model(**inputs, output_hidden_states=True)
if focus == "Visual Encoder":

@@ -417,11 +420,11 @@
else:
loss = outputs_raw.logits.max(dim=-1).values[0, start_idx + target_token_idx]
loss.backward()
- self.activations = [layer.get_attn_map() for layer in self.target_layers]
+ self.activations = [layer.attn_sigmoid_weights for layer in self.target_layers] if visual_method == "sigmoid" else [layer.get_attn_map() for layer in self.target_layers]
self.gradients = [layer.get_attn_gradients() for layer in self.target_layers]

@spaces.GPU(duration=120)
- def generate_cam(self, inputs, tokenizer, temperature, top_p, target_token_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
+ def generate_cam(self, inputs, tokenizer, temperature, top_p, target_token_idx=None, visual_method="softmax", focus="Visual Encoder"):

# Forward pass
self.setup_grads()
@@ -439,7 +442,7 @@
start_idx = last + 1


- self.forward_backward(inputs, focus, start_idx, target_token_idx)
+ self.forward_backward(inputs, focus, start_idx, target_token_idx, visual_method)
if focus == "Visual Encoder":

cam_sum = self.grad_cam_vis()