AustingDong committed
Commit 8235fd2 · 1 Parent(s): 4db7aa5

modified visual encoder

Files changed (2):
  1. app.py +70 -28
  2. demo/cam.py +137 -61
app.py CHANGED
@@ -25,6 +25,7 @@ model_utils, vl_gpt, tokenizer = None, None, None
 model_name = "Clip"
 language_model_max_layer = 24
 language_model_best_layer = 8
+vision_model_best_layer = 24

 def clean():
     global model_utils, vl_gpt, tokenizer, clip_utils
@@ -116,7 +117,10 @@ def multimodal_understanding(model_type,
     if activation_map_method == "GradCAM":
         # target_layers = vl_gpt.vision_model.vision_tower.blocks
         if focus == "Visual Encoder":
-            all_layers = [block.norm1 for block in vl_gpt.vision_model.vision_tower.blocks]
+            if model_name.split('-')[0] == "Janus":
+                all_layers = [block.norm1 for block in vl_gpt.vision_model.vision_tower.blocks]
+            else:
+                all_layers = [block.layer_norm1 for block in vl_gpt.vision_tower.vision_model.encoder.layers]
         else:
             all_layers = [layer.self_attn for layer in vl_gpt.language_model.model.layers]

@@ -137,17 +141,33 @@ def multimodal_understanding(model_type,
             gradcam = AttentionGuidedCAMChartGemma(vl_gpt, target_layers)

         start = 0
+        cam = []
         if focus == "Visual Encoder":
-            cam_tensors, grid_size = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, target_token_idx, visual_pooling_method, focus)
+            if target_token_idx != -1:
+                cam_tensors, grid_size, start = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, target_token_idx, visual_pooling_method, focus)
+                cam_grid = cam_tensors.reshape(grid_size, grid_size)
+                cam_i = generate_gradcam(cam_grid, image)
+                cam_i = add_title_to_image(cam_i, input_ids_decoded[start + target_token_idx])
+                cam = [cam_i]
+            else:
+                i = 0
+                cam = []
+                while start + i < len(input_ids_decoded):
+                    if model_name.split('-')[0] == "Janus":
+                        gradcam = AttentionGuidedCAMJanus(vl_gpt, target_layers)
+                    elif model_name.split('-')[0] == "LLaVA":
+                        gradcam = AttentionGuidedCAMLLaVA(vl_gpt, target_layers)
+                    elif model_name.split('-')[0] == "ChartGemma":
+                        gradcam = AttentionGuidedCAMChartGemma(vl_gpt, target_layers)
+                    cam_tensors, grid_size, start = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, i, visual_pooling_method, focus)
+                    cam_grid = cam_tensors.reshape(grid_size, grid_size)
+                    cam_i = generate_gradcam(cam_grid, image)
+                    cam_i = add_title_to_image(cam_i, input_ids_decoded[start + i])
+                    cam.append(cam_i)
+                    gradcam.remove_hooks()
+                    i += 1
         else:
             cam_tensors, grid_size, start = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, target_token_idx, visual_pooling_method, focus)
-        gradcam.remove_hooks()
-
-
-        if focus == "Visual Encoder":
-            cam_grid = cam_tensors.reshape(grid_size, grid_size)
-            cam = [generate_gradcam(cam_grid, image)]
-        else:
             if target_token_idx != -1:
                 input_text_decoded = input_ids_decoded[start + target_token_idx]
             for i, cam_tensor in enumerate(cam_tensors):
@@ -164,6 +184,9 @@ def multimodal_understanding(model_type,
                 cam_i = add_title_to_image(cam_i, input_ids_decoded[start + i])

                 cam.append(cam_i)
+
+        gradcam.remove_hooks()
+

     # Collect Results
     RESULTS_ROOT = "./results"
@@ -193,7 +216,7 @@ def multimodal_understanding(model_type,
 # Gradio interface

 def model_slider_change(model_type):
-    global model_utils, vl_gpt, tokenizer, clip_utils, model_name, language_model_max_layer, language_model_best_layer
+    global model_utils, vl_gpt, tokenizer, clip_utils, model_name, language_model_max_layer, language_model_best_layer, vision_model_best_layer
     model_name = model_type
     if model_type == "Clip":
         clean()
@@ -251,13 +274,14 @@ def model_slider_change(model_type):
         model_utils = ChartGemma_Utils()
         vl_gpt, tokenizer = model_utils.init_ChartGemma()
         language_model_max_layer = 18
+        vision_model_best_layer = 19
         language_model_best_layer = 15

         res = (
             gr.Dropdown(choices=["Visualization only", "answer + visualization"], value="answer + visualization", label="response_type"),
             gr.Slider(minimum=1, maximum=language_model_best_layer, value=language_model_best_layer, step=1, label="visualization layers min"),
             gr.Slider(minimum=1, maximum=language_model_best_layer, value=language_model_best_layer, step=1, label="visualization layers max"),
-            gr.Dropdown(choices=["Language Model"], value="Language Model", label="focus"),
+            gr.Dropdown(choices=["Visual Encoder", "Language Model"], value="Language Model", label="focus"),
             gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type")
         )
         return res
@@ -292,12 +316,21 @@ def focus_change(focus):
         return res

     else:
-        res = (
-            gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type"),
-            gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers min"),
-            gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers max")
-        )
-        return res
+        if model_name.split('-')[0] == "ChartGemma":
+            res = (
+                gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type"),
+                gr.Slider(minimum=1, maximum=26, value=vision_model_best_layer, step=1, label="visualization layers min"),
+                gr.Slider(minimum=1, maximum=26, value=vision_model_best_layer, step=1, label="visualization layers max")
+            )
+            return res
+
+        else:
+            res = (
+                gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type"),
+                gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers min"),
+                gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers max")
+            )
+            return res



@@ -305,27 +338,37 @@ def focus_change(focus):

 with gr.Blocks() as demo:
     gr.Markdown(value="# Multimodal Understanding")
+
+    with gr.Row():
+        image_input = gr.Image(height=500, label="Image")
+        activation_map_output = gr.Gallery(label="Visualization", height=500, columns=1, preview=True)
+
+    with gr.Row():
+        chart_type = gr.Textbox(label="Chart Type")
+        understanding_output = gr.Textbox(label="Answer")
+
     with gr.Row():
-        with gr.Column():
-            image_input = gr.Image()
-            activation_map_output = gr.Gallery(label="activation Map", height=300, columns=1)

         with gr.Column():
             model_selector = gr.Dropdown(choices=["Clip", "ChartGemma-3B", "Janus-Pro-1B", "Janus-Pro-7B", "LLaVA-1.5-7B"], value="Clip", label="model")
+            question_input = gr.Textbox(label="Input Prompt")
+            und_seed_input = gr.Number(label="Seed", precision=0, value=42)
+            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
+            temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
+            target_token_idx = gr.Number(label="target_token_idx (-1 means all)", precision=0, value=-1)
+
+
+        with gr.Column():
             response_type = gr.Dropdown(choices=["Visualization only"], value="Visualization only", label="response_type")
             focus = gr.Dropdown(choices=["Visual Encoder"], value="Visual Encoder", label="focus")
-            activation_map_method = gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type")
+            activation_map_method = gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="visualization type")
             visual_pooling_method = gr.Dropdown(choices=["CLS", "max", "avg"], value="CLS", label="visual pooling method")


             visualization_layers_min = gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers min")
             visualization_layers_max = gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers max")
+

-            question_input = gr.Textbox(label="Question")
-            und_seed_input = gr.Number(label="Seed", precision=0, value=42)
-            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
-            temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
-            target_token_idx = gr.Number(label="target_token_idx (-1 means all)", precision=0, value=-1)



@@ -360,8 +403,7 @@ with gr.Blocks() as demo:


             understanding_button = gr.Button("Submit")
-            chart_type = gr.Textbox(label="Chart Type")
-            understanding_output = gr.Textbox(label="Answer")
+
             understanding_target_token_decoded_output = gr.Textbox(label="Target Token Decoded")

demo/cam.py CHANGED
@@ -85,10 +85,11 @@ class AttentionGuidedCAMClip(AttentionGuidedCAM):
             print("act shape", act.shape)
             print("grad_weights shape", grad_weights.shape)

-            # cam = (act * grad_weights).sum(dim=-1)  # Weighted activation map
+            # cam = (act * grad_weights).sum(dim=-1)
             cam, _ = (act * grad_weights).max(dim=-1)
+            # cam, _ = act.max(dim=-1)
+            # cam = cam.unsqueeze(0)
             # cam, _ = grad_weights.max(dim=-1)
-            # cam = self.normalize(cam)
             print("cam_shape: ", cam.shape)

             # Sum across all layers
@@ -166,20 +167,23 @@ class AttentionGuidedCAMJanus(AttentionGuidedCAM):

         if focus == "Visual Encoder":
             # Pooling
-            if visual_pooling_method == "CLS":
-                image_embeddings_pooled = image_embeddings[:, 0, :]
-            elif visual_pooling_method == "avg":
-                image_embeddings_pooled = image_embeddings[:, 1:, :].mean(dim=1)  # end of image: 618
-            elif visual_pooling_method == "max":
-                image_embeddings_pooled, _ = image_embeddings[:, 1:, :].max(dim=1)
-
-            print("image_embeddings_shape: ", image_embeddings_pooled.shape)
-
+            # if visual_pooling_method == "CLS":
+            #     image_embeddings_pooled = image_embeddings[:, 0, :]
+            # elif visual_pooling_method == "avg":
+            #     image_embeddings_pooled = image_embeddings[:, 1:, :].mean(dim=1)
+            # elif visual_pooling_method == "max":
+            #     image_embeddings_pooled, _ = image_embeddings[:, 1:, :].max(dim=1)

-            inputs_embeddings_pooled = inputs_embeddings[:, 620: -4].mean(dim=1)
+            # print("image_embeddings_shape: ", image_embeddings_pooled.shape)
+
+            start_idx = 620
+            # inputs_embeddings_pooled = inputs_embeddings[:, start_idx: -4].mean(dim=1)
             self.model.zero_grad()
-            image_embeddings_pooled.backward(inputs_embeddings_pooled, retain_graph=True)
+            # image_embeddings_pooled.backward(inputs_embeddings_pooled, retain_graph=True)
+
+            loss = outputs.logits.max(dim=-1).values[0, start_idx + class_idx]
+            loss.backward()

         cam_sum = None
         for act, grad in zip(self.activations, self.gradients):
@@ -195,6 +199,7 @@ class AttentionGuidedCAMJanus(AttentionGuidedCAM):
             print("grad_weights shape", grad_weights.shape)

             cam, _ = (act * grad_weights).max(dim=-1)
+            # cam, _ = grad_weights.max(dim=-1)
             print(cam.shape)

             # Sum across all layers
@@ -224,7 +229,7 @@ class AttentionGuidedCAMJanus(AttentionGuidedCAM):
         cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
         cam_sum = cam_sum.detach().to("cpu")

-        return cam_sum, grid_size
+        return cam_sum, grid_size, start_idx



@@ -407,7 +412,7 @@ class AttentionGuidedCAMLLaVA(AttentionGuidedCAM):
 class AttentionGuidedCAMChartGemma(AttentionGuidedCAM):
     def __init__(self, model, target_layers):
         self.target_layers = target_layers
-        super().__init__(model, register=False)
+        super().__init__(model, register=True)
         self._modify_layers()
         self._register_hooks_activations()

@@ -445,12 +450,9 @@ class AttentionGuidedCAMChartGemma(AttentionGuidedCAM):
             for param in layer.parameters():
                 param.requires_grad = True

-        outputs_raw = self.model(**inputs)
+        outputs_raw = self.model(**inputs, output_hidden_states=True)
+

-        self.model.zero_grad()
-        # print(outputs_raw)
-        loss = outputs_raw.logits.max(dim=-1).values.sum()
-        loss.backward()

         # get image masks
         image_mask = []
@@ -462,61 +464,135 @@ class AttentionGuidedCAMChartGemma(AttentionGuidedCAM):
                 last = i
             else:
                 image_mask.append(False)
+        start_idx = last + 1


-        # Aggregate activations and gradients from ALL layers
-        self.activations = [layer.get_attn_map() for layer in self.target_layers]
-        self.gradients = [layer.get_attn_gradients() for layer in self.target_layers]
-        print(f"layers shape: {len(self.target_layers)}")
-        print("activations & gradients shape", len(self.activations), len(self.gradients))
-
-        cams = []
-
-        # Ver 2
-        for act, grad in zip(self.activations, self.gradients):
-
-            print("act shape", act.shape)
-            print("grad shape", grad.shape)
-
-            grad = F.relu(grad)
-
-            cam = act * grad  # shape: [1, heads, seq_len, seq_len]
-            cam = cam.sum(dim=1)  # shape: [1, seq_len, seq_len]
-            cam = cam.to(torch.float32).detach().cpu()
-            cams.append(cam)
-
-        # cam_sum = F.relu(cam_sum)
-        # cam_sum = cam_sum.to(torch.float32)
-
-        # cams shape: [layers, 1, seq_len, seq_len]
-        cam_sum_lst = []
-
-        start_idx = last + 1
-        for i in range(start_idx, cams[0].shape[1]):
-            cam_sum = None
-            for layer, cam_l in enumerate(cams):
-                cam_l_i = cam_l[0, i, :]  # shape: [1: seq_len]
-
-                cam_l_i = cam_l_i[image_mask].unsqueeze(0)  # shape: [1, img_seq_len]
-                # print(f"layer: {layer}, token index: {i}")
-                # print("cam_sum shape: ", cam_l_i.shape)
-                num_patches = cam_l_i.shape[-1]  # Last dimension of CAM output
-                grid_size = int(num_patches ** 0.5)
-                # print(f"Detected grid size: {grid_size}x{grid_size}")
-
-                # Fix the reshaping step dynamically
-                cam_reshaped = cam_l_i.view(grid_size, grid_size)
-                # print(f"max: {cam_reshaped.max()}, min: {cam_reshaped.min()}")
-                cam_normalized = (cam_reshaped - cam_reshaped.min()) / (cam_reshaped.max() - cam_reshaped.min())
-                if cam_sum == None:
-                    cam_sum = cam_normalized
-                else:
-                    cam_sum += cam_normalized
-                # print(f"normalized: max: {cam_normalized.max()}, min: {cam_normalized.min()}")
-
-            # print(f"sum: max: {cam_sum.max()}, min: {cam_sum.min()}")
-            cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
-            cam_sum_lst.append(cam_sum)
+        if focus == "Visual Encoder":
+            # image_embeddings = outputs_raw.image_hidden_states
+            # inputs_embeddings = outputs_raw.hidden_states[0]
+            # # Pooling
+            # if visual_pooling_method == "avg":
+            #     image_embeddings_pooled = image_embeddings.mean(dim=1)  # end of image: 618
+            # elif visual_pooling_method == "max":
+            #     image_embeddings_pooled, _ = image_embeddings.max(dim=1)
+
+            # print("image_embeddings_shape: ", image_embeddings_pooled.shape)
+
+            # inputs_embeddings_pooled = inputs_embeddings[:, start_idx:].mean(dim=1)
+            self.model.zero_grad()
+            # image_embeddings_pooled.backward(inputs_embeddings_pooled, retain_graph=True)
+
+            loss = outputs_raw.logits.max(dim=-1).values[0, start_idx + class_idx]
+            loss.backward()
+
+            cam_sum = None
+            for act, grad in zip(self.activations, self.gradients):
+                # act = torch.sigmoid(act)
+                act = F.relu(act[0])
+
+                # Compute mean of gradients
+                print("grad shape:", grad.shape)
+                grad_weights = grad.mean(dim=-1, keepdim=True)
+
+                print("act shape", act.shape)
+                print("grad_weights shape", grad_weights.shape)
+
+                cam = (act * grad_weights).sum(dim=-1)
+                # cam, _ = (act * grad_weights).max(dim=-1)
+                # cam, _ = grad_weights.max(dim=-1)
+                print(cam.shape)
+
+                # Sum across all layers
+                if cam_sum is None:
+                    cam_sum = cam
+                else:
+                    cam_sum += cam
+
+            # Normalize
+            cam_sum = F.relu(cam_sum)
+
+            # thresholding
+            cam_sum = cam_sum.to(torch.float32).detach().cpu()
+            percentile = torch.quantile(cam_sum, 0.2)  # Adjust threshold dynamically
+            cam_sum[cam_sum < percentile] = 0
+
+            # Reshape
+            print("cam_sum shape: ", cam_sum.shape)
+            num_patches = cam_sum.shape[-1]  # Last dimension of CAM output
+            grid_size = int(num_patches ** 0.5)
+            print(f"Detected grid size: {grid_size}x{grid_size}")
+
+            cam_sum = cam_sum.view(grid_size, grid_size)
+            cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
+
+            return cam_sum, grid_size, start_idx
+
+        elif focus == "Language Model":
+            self.model.zero_grad()
+            # print(outputs_raw)
+            loss = outputs_raw.logits.max(dim=-1).values.sum()
+            loss.backward()
+
+            # Aggregate activations and gradients from ALL layers
+            self.activations = [layer.get_attn_map() for layer in self.target_layers]
+            self.gradients = [layer.get_attn_gradients() for layer in self.target_layers]
+            print(f"layers shape: {len(self.target_layers)}")
+            print("activations & gradients shape", len(self.activations), len(self.gradients))
+
+            cams = []
+
+            # Ver 2
+            for act, grad in zip(self.activations, self.gradients):
+
+                print("act shape", act.shape)
+                print("grad shape", grad.shape)
+
+                grad = F.relu(grad)
+
+                cam = act * grad  # shape: [1, heads, seq_len, seq_len]
+                cam = cam.sum(dim=1)  # shape: [1, seq_len, seq_len]
+                cam = cam.to(torch.float32).detach().cpu()
+                cams.append(cam)
+
+            # cam_sum = F.relu(cam_sum)
+            # cam_sum = cam_sum.to(torch.float32)
+
+            # cams shape: [layers, 1, seq_len, seq_len]
+            cam_sum_lst = []
+
+            start_idx = last + 1
+            for i in range(start_idx, cams[0].shape[1]):
+                cam_sum = None
+                for layer, cam_l in enumerate(cams):
+                    cam_l_i = cam_l[0, i, :]  # shape: [1: seq_len]
+
+                    cam_l_i = cam_l_i[image_mask].unsqueeze(0)  # shape: [1, img_seq_len]
+                    # print(f"layer: {layer}, token index: {i}")
+                    # print("cam_sum shape: ", cam_l_i.shape)
+                    num_patches = cam_l_i.shape[-1]  # Last dimension of CAM output
+                    grid_size = int(num_patches ** 0.5)
+                    # print(f"Detected grid size: {grid_size}x{grid_size}")
+
+                    # Fix the reshaping step dynamically
+                    cam_reshaped = cam_l_i.view(grid_size, grid_size)
+                    # print(f"max: {cam_reshaped.max()}, min: {cam_reshaped.min()}")
+                    # cam_reshaped = (cam_reshaped - cam_reshaped.min()) / (cam_reshaped.max() - cam_reshaped.min())
+                    if cam_sum == None:
+                        cam_sum = cam_reshaped
+                    else:
+                        cam_sum += cam_reshaped
+                    # print(f"normalized: max: {cam_normalized.max()}, min: {cam_normalized.min()}")
+
+                # print(f"sum: max: {cam_sum.max()}, min: {cam_sum.min()}")
+                cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
+                cam_sum_lst.append(cam_sum)


         return cam_sum_lst, grid_size, start_idx
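
Note: a minimal standalone sketch of the patch-level CAM aggregation that the new Visual Encoder branch of AttentionGuidedCAMChartGemma.generate_cam performs (ReLU'd activations weighted by the per-patch mean gradient, summed over layers, thresholded at the 0.2 quantile, then reshaped to the patch grid). The helper name aggregate_cam and the tensor shapes are assumptions for illustration, not part of the commit:

import torch
import torch.nn.functional as F

def aggregate_cam(activations, gradients, quantile=0.2):
    # activations/gradients: per-layer tensors of shape [1, num_patches, hidden_dim],
    # as captured by hooks on the vision encoder's layer norms.
    cam_sum = None
    for act, grad in zip(activations, gradients):
        act = F.relu(act[0])                        # [num_patches, hidden_dim]
        weights = grad.mean(dim=-1, keepdim=True)   # mean gradient per patch, [1, num_patches, 1]
        cam = (act * weights).sum(dim=-1)           # [1, num_patches]
        cam_sum = cam if cam_sum is None else cam_sum + cam

    cam_sum = F.relu(cam_sum).to(torch.float32).detach().cpu()
    cam_sum[cam_sum < torch.quantile(cam_sum, quantile)] = 0  # drop weak responses

    grid_size = int(cam_sum.shape[-1] ** 0.5)       # assume a square patch grid
    cam_sum = cam_sum.view(grid_size, grid_size)
    cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
    return cam_sum, grid_size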