AustingDong committed on
Commit 7e57874 · 1 Parent(s): 217eab6

improved Janus

app.py CHANGED
@@ -27,7 +27,8 @@ clip_utils.init_Clip()
 model_utils, vl_gpt, tokenizer = None, None, None
 model_name = "Clip"
 language_model_max_layer = 24
-language_model_best_layer = 8
+language_model_best_layer_min = 8
+language_model_best_layer_max = 8
 vision_model_best_layer = 24
 
 def clean():
@@ -215,7 +216,7 @@ def multimodal_understanding(model_type,
 # Gradio interface
 
 def model_slider_change(model_type):
-    global model_utils, vl_gpt, tokenizer, clip_utils, model_name, language_model_max_layer, language_model_best_layer, vision_model_best_layer
+    global model_utils, vl_gpt, tokenizer, clip_utils, model_name, language_model_max_layer, language_model_best_layer_min, language_model_best_layer_max, vision_model_best_layer
     model_name = model_type
 
 
@@ -226,13 +227,6 @@ def model_slider_change(model_type):
             gr.Dropdown(choices=["CLS", "max", "avg"], value="CLS", label="visual pooling method")
         ]
 
-        visual_res = [
-            gr.Dropdown(choices=["Visualization only", "answer + visualization"], value="Visualization only", label="response_type"),
-            gr.Dropdown(choices=["Visual Encoder"], value="Visual Encoder", label="focus"),
-            gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type"),
-            gr.Dropdown(choices=["softmax", "sigmoid"], value="softmax", label="activation function")
-        ]
-
         language_res = [
             gr.Dropdown(choices=["Visualization only", "answer + visualization"], value="answer + visualization", label="response_type"),
             gr.Dropdown(choices=["Language Model"], value="Language Model", label="focus"),
@@ -253,7 +247,7 @@ def model_slider_change(model_type):
         return tuple(encoder_only_res + sliders)
 
     elif model_type.split('-')[0] == "Janus":
-
+        # best seed: 70
         clean()
         set_seed()
         model_utils = Janus_Utils()
@@ -262,13 +256,14 @@ def model_slider_change(model_type):
             layer.self_attn = ModifiedLlamaAttention(layer.self_attn)
 
         language_model_max_layer = 24
-        language_model_best_layer = 8
+        language_model_best_layer_min = 8
+        language_model_best_layer_max = 10
 
         sliders = [
-            gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers min"),
-            gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers max"),
+            gr.Slider(minimum=1, maximum=24, value=language_model_best_layer_min, step=1, label="visualization layers min"),
+            gr.Slider(minimum=1, maximum=24, value=language_model_best_layer_max, step=1, label="visualization layers max"),
         ]
-        return tuple(visual_res + sliders)
+        return tuple(language_res + sliders)
 
     elif model_type.split('-')[0] == "LLaVA":
 
@@ -278,11 +273,12 @@ def model_slider_change(model_type):
         version = model_type.split('-')[1]
         vl_gpt, tokenizer = model_utils.init_LLaVA(version=version)
         language_model_max_layer = 32 if version == "1.5" else 28
-        language_model_best_layer = 10
+        language_model_best_layer_min = 10
+        language_model_best_layer_max = 10
 
         sliders = [
-            gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer, step=1, label="visualization layers min"),
-            gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer, step=1, label="visualization layers max"),
+            gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer_min, step=1, label="visualization layers min"),
+            gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer_max, step=1, label="visualization layers max"),
         ]
         return tuple(language_res + sliders)
 
@@ -295,11 +291,12 @@ def model_slider_change(model_type):
             layer.self_attn = ModifiedGemmaAttention(layer.self_attn)
         language_model_max_layer = 18
         vision_model_best_layer = 19
-        language_model_best_layer = 15
+        language_model_best_layer_min = 11
+        language_model_best_layer_max = 15
 
         sliders = [
-            gr.Slider(minimum=1, maximum=language_model_best_layer, value=language_model_best_layer, step=1, label="visualization layers min"),
-            gr.Slider(minimum=1, maximum=language_model_best_layer, value=language_model_best_layer, step=1, label="visualization layers max"),
+            gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer_min, step=1, label="visualization layers min"),
+            gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer_max, step=1, label="visualization layers max"),
        ]
         return tuple(language_res + sliders)
 
@@ -320,15 +317,15 @@ def focus_change(focus):
     if response_type.value == "answer + visualization":
         res = (
             gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type"),
-            gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer, step=1, label="visualization layers min"),
-            gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer, step=1, label="visualization layers max")
+            gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer_min, step=1, label="visualization layers min"),
+            gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer_max, step=1, label="visualization layers max")
         )
         return res
     else:
         res = (
             gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type"),
-            gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer, step=1, label="visualization layers min"),
-            gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer, step=1, label="visualization layers max")
+            gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer_min, step=1, label="visualization layers min"),
+            gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer_max, step=1, label="visualization layers max")
         )
         return res
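The recurring pattern in this file — a per-model (min, max) pair of best-layer defaults feeding two gr.Slider widgets — is easiest to see in isolation. A minimal sketch follows; the global names mirror app.py, but the standalone make_layer_sliders helper and the Janus values used here are illustrative:

import gradio as gr

# Per-model "best layer" defaults, split into a (min, max) pair as in this
# commit; the values below mirror the Janus branch of app.py.
language_model_max_layer = 24
language_model_best_layer_min = 8
language_model_best_layer_max = 10

def make_layer_sliders():
    # Two sliders share the [1, max_layer] range but start at different
    # defaults, so a layer *interval* is pre-selected for the user.
    return [
        gr.Slider(minimum=1, maximum=language_model_max_layer,
                  value=language_model_best_layer_min, step=1,
                  label="visualization layers min"),
        gr.Slider(minimum=1, maximum=language_model_max_layer,
                  value=language_model_best_layer_max, step=1,
                  label="visualization layers max"),
    ]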
 
demo/cam.py DELETED
@@ -1,674 +0,0 @@
1
- import cv2
2
- import numpy as np
3
- import types
4
- import torch
5
- import torch.nn.functional as F
6
- import matplotlib.pyplot as plt
7
- from PIL import Image
8
- from torch import nn
9
- import spaces
10
- from demo.modify_llama import *
11
-
12
-
13
- class AttentionGuidedCAM:
14
- def __init__(self, model, register=True):
15
- self.model = model
16
- self.gradients = []
17
- self.activations = []
18
- self.hooks = []
19
- if register:
20
- self._register_hooks()
21
-
22
- def _register_hooks(self):
23
- for layer in self.target_layers:
24
- self.hooks.append(layer.register_forward_hook(self._forward_hook))
25
- self.hooks.append(layer.register_backward_hook(self._backward_hook))
26
-
27
- def _forward_hook(self, module, input, output):
28
- self.activations.append(output)
29
-
30
- def _backward_hook(self, module, grad_in, grad_out):
31
- self.gradients.append(grad_out[0])
32
-
33
-
34
- def remove_hooks(self):
35
- for hook in self.hooks:
36
- hook.remove()
37
-
38
- @spaces.GPU(duration=120)
39
- def generate_cam(self, input_tensor, class_idx=None):
40
- raise NotImplementedError
41
-
42
-
43
-
44
-
45
- class AttentionGuidedCAMClip(AttentionGuidedCAM):
46
- def __init__(self, model, target_layers):
47
- self.target_layers = target_layers
48
- super().__init__(model)
49
-
50
- @spaces.GPU(duration=120)
51
- def generate_cam(self, input_tensor, class_idx=None, visual_pooling_method="CLS"):
52
- """ Generates Grad-CAM heatmap for ViT. """
53
-
54
- # Forward pass
55
- output_full = self.model(**input_tensor)
56
-
57
- if class_idx is None:
58
- class_idx = torch.argmax(output_full.logits, dim=1).item()
59
-
60
- if visual_pooling_method == "CLS":
61
- output = output_full.image_embeds
62
- elif visual_pooling_method == "avg":
63
- output = self.model.visual_projection(output_full.vision_model_output.last_hidden_state).mean(dim=1)
64
- else:
65
- # project -> pooling
66
- output, _ = self.model.visual_projection(output_full.vision_model_output.last_hidden_state).max(dim=1)
67
-
68
- # pooling -> project
69
- # output_mx, _ = output_full.vision_model_output.last_hidden_state.max(dim=1)
70
- # output = self.model.visual_projection(output_mx)
71
-
72
- output.backward(output_full.text_embeds[class_idx:class_idx+1], retain_graph=True)
73
-
74
- # Aggregate activations and gradients from ALL layers
75
- self.model.zero_grad()
76
- cam_sum = None
77
- for act, grad in zip(self.activations, self.gradients):
78
-
79
- # act = torch.sigmoid(act[0])
80
- act = F.relu(act[0])
81
-
82
- grad_weights = grad.mean(dim=-1, keepdim=True)
83
-
84
-
85
- print("act shape", act.shape)
86
- print("grad_weights shape", grad_weights.shape)
87
-
88
- # cam = (act * grad_weights).sum(dim=-1)
89
- cam, _ = (act * grad_weights).max(dim=-1)
90
- # cam, _ = act.max(dim=-1)
91
- # cam = cam.unsqueeze(0)
92
- # cam, _ = grad_weights.max(dim=-1)
93
- print("cam_shape: ", cam.shape)
94
-
95
- # Sum across all layers
96
- if cam_sum is None:
97
- cam_sum = cam
98
- else:
99
- cam_sum += cam
100
-
101
-
102
- # Normalize
103
- cam_sum = F.relu(cam_sum)
104
-
105
- # thresholding
106
- cam_sum = cam_sum.to(torch.float32)
107
- percentile = torch.quantile(cam_sum, 0.2) # Adjust threshold dynamically
108
- cam_sum[cam_sum < percentile] = 0
109
-
110
- # Reshape
111
- print("cam_sum shape: ", cam_sum.shape)
112
- cam_sum = cam_sum[0, 1:]
113
-
114
- num_patches = cam_sum.shape[-1] # Last dimension of CAM output
115
- grid_size = int(num_patches ** 0.5)
116
- print(f"Detected grid size: {grid_size}x{grid_size}")
117
-
118
- cam_sum = cam_sum.view(grid_size, grid_size).detach()
119
- cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
120
-
121
- return cam_sum, output_full, grid_size
122
-
123
-
124
- class AttentionGuidedCAMJanus(AttentionGuidedCAM):
125
- def __init__(self, model, target_layers):
126
- self.target_layers = target_layers
127
- super().__init__(model)
128
- self._modify_layers()
129
- self._register_hooks_activations()
130
-
131
- def _modify_layers(self):
132
- for layer in self.target_layers:
133
- setattr(layer, "attn_gradients", None)
134
- setattr(layer, "attention_map", None)
135
-
136
- layer.save_attn_gradients = types.MethodType(save_attn_gradients, layer)
137
- layer.get_attn_gradients = types.MethodType(get_attn_gradients, layer)
138
- layer.save_attn_map = types.MethodType(save_attn_map, layer)
139
- layer.get_attn_map = types.MethodType(get_attn_map, layer)
140
-
141
- def _forward_activate_hooks(self, module, input, output):
142
- attn_output, attn_weights = output # Unpack outputs
143
- module.save_attn_map(attn_weights)
144
- attn_weights.register_hook(module.save_attn_gradients)
145
-
146
- def _register_hooks_activations(self):
147
- for layer in self.target_layers:
148
- if hasattr(layer, "q_proj"): # is an attention layer
149
- self.hooks.append(layer.register_forward_hook(self._forward_activate_hooks))
150
-
151
- @spaces.GPU(duration=120)
152
- def generate_cam(self, input_tensor, tokenizer, temperature, top_p, class_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
153
-
154
- torch.autograd.set_detect_anomaly(True)
155
- for param in self.model.parameters():
156
- param.requires_grad = False
157
-
158
- for layer in self.target_layers:
159
- for param in layer.parameters():
160
- param.requires_grad = True
161
-
162
- # Forward pass
163
- image_embeddings, inputs_embeddings, outputs = self.model(input_tensor, tokenizer, temperature, top_p)
164
-
165
-
166
- input_ids = input_tensor.input_ids
167
-
168
- if focus == "Visual Encoder":
169
- # Pooling
170
- # if visual_pooling_method == "CLS":
171
- # image_embeddings_pooled = image_embeddings[:, 0, :]
172
- # elif visual_pooling_method == "avg":
173
- # image_embeddings_pooled = image_embeddings[:, 1:, :].mean(dim=1)
174
- # elif visual_pooling_method == "max":
175
- # image_embeddings_pooled, _ = image_embeddings[:, 1:, :].max(dim=1)
176
-
177
- # print("image_embeddings_shape: ", image_embeddings_pooled.shape)
178
-
179
-
180
- start_idx = 620
181
- # inputs_embeddings_pooled = inputs_embeddings[:, start_idx: -4].mean(dim=1)
182
- self.model.zero_grad()
183
- # image_embeddings_pooled.backward(inputs_embeddings_pooled, retain_graph=True)
184
-
185
- loss = outputs.logits.max(dim=-1).values[0, start_idx + class_idx]
186
- loss.backward()
187
-
188
- cam_sum = None
189
- for act, grad in zip(self.activations, self.gradients):
190
- # act = torch.sigmoid(act)
191
- act = F.relu(act[0])
192
-
193
-
194
- # Compute mean of gradients
195
- print("grad shape:", grad.shape)
196
- grad_weights = grad.mean(dim=-1, keepdim=True)
197
-
198
- print("act shape", act.shape)
199
- print("grad_weights shape", grad_weights.shape)
200
-
201
- cam, _ = (act * grad_weights).max(dim=-1)
202
- # cam, _ = grad_weights.max(dim=-1)
203
- print(cam.shape)
204
-
205
- # Sum across all layers
206
- if cam_sum is None:
207
- cam_sum = cam
208
- else:
209
- cam_sum += cam
210
-
211
- # Normalize
212
- cam_sum = F.relu(cam_sum)
213
-
214
-
215
- # thresholding
216
- cam_sum = cam_sum.to(torch.float32)
217
- percentile = torch.quantile(cam_sum, 0.2) # Adjust threshold dynamically
218
- cam_sum[cam_sum < percentile] = 0
219
-
220
- # Reshape
221
- # if visual_pooling_method == "CLS":
222
- cam_sum = cam_sum[0, 1:]
223
- print("cam_sum shape: ", cam_sum.shape)
224
- num_patches = cam_sum.shape[-1] # Last dimension of CAM output
225
- grid_size = int(num_patches ** 0.5)
226
- print(f"Detected grid size: {grid_size}x{grid_size}")
227
-
228
- cam_sum = cam_sum.view(grid_size, grid_size)
229
- cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
230
- cam_sum = cam_sum.detach().to("cpu")
231
-
232
- return cam_sum, grid_size, start_idx
233
-
234
-
235
-
236
-
237
-
238
-
239
- elif focus == "Language Model":
240
- self.model.zero_grad()
241
- loss = outputs.logits.max(dim=-1).values.sum()
242
- loss.backward()
243
-
244
- self.activations = [layer.get_attn_map() for layer in self.target_layers]
245
- self.gradients = [layer.get_attn_gradients() for layer in self.target_layers]
246
-
247
- cam_sum = None
248
- for act, grad in zip(self.activations, self.gradients):
249
- # act = torch.sigmoid(act)
250
- print("act_shape:", act.shape)
251
- # print("act1_shape:", act[1].shape)
252
-
253
- act = act.mean(dim=1)
254
-
255
- # Compute mean of gradients
256
- print("grad_shape:", grad.shape)
257
- grad_weights = F.relu(grad.mean(dim=1))
258
-
259
- cam = act * grad_weights
260
- print(cam.shape)
261
-
262
- # Sum across all layers
263
- if cam_sum is None:
264
- cam_sum = cam
265
- else:
266
- cam_sum += cam
267
-
268
- # Normalize
269
- cam_sum = F.relu(cam_sum)
270
-
271
- # thresholding
272
- cam_sum = cam_sum.to(torch.float32)
273
- percentile = torch.quantile(cam_sum, 0.2) # Adjust threshold dynamically
274
- cam_sum[cam_sum < percentile] = 0
275
-
276
-
277
- # cam_sum shape: [1, seq_len, seq_len]
278
- cam_sum_lst = []
279
- cam_sum_raw = cam_sum
280
- start = 620
281
- for i in range(start, cam_sum_raw.shape[1]):
282
- cam_sum = cam_sum_raw[:, i, :] # shape: [1: seq_len]
283
- cam_sum = cam_sum[input_tensor.images_seq_mask].unsqueeze(0) # shape: [1, 576]
284
- print("cam_sum shape: ", cam_sum.shape)
285
- num_patches = cam_sum.shape[-1] # Last dimension of CAM output
286
- grid_size = int(num_patches ** 0.5)
287
- print(f"Detected grid size: {grid_size}x{grid_size}")
288
-
289
- # Fix the reshaping step dynamically
290
-
291
- cam_sum = cam_sum.view(grid_size, grid_size)
292
- cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
293
- cam_sum = cam_sum.detach().to("cpu")
294
- cam_sum_lst.append(cam_sum)
295
-
296
-
297
- return cam_sum_lst, grid_size, start
298
-
299
-
300
-
301
- class AttentionGuidedCAMLLaVA(AttentionGuidedCAM):
302
- def __init__(self, model, target_layers):
303
- self.target_layers = target_layers
304
- super().__init__(model, register=False)
305
- self._modify_layers()
306
- self._register_hooks_activations()
307
-
308
- def _modify_layers(self):
309
- for layer in self.target_layers:
310
- setattr(layer, "attn_gradients", None)
311
- setattr(layer, "attention_map", None)
312
-
313
- layer.save_attn_gradients = types.MethodType(save_attn_gradients, layer)
314
- layer.get_attn_gradients = types.MethodType(get_attn_gradients, layer)
315
- layer.save_attn_map = types.MethodType(save_attn_map, layer)
316
- layer.get_attn_map = types.MethodType(get_attn_map, layer)
317
-
318
- def _forward_activate_hooks(self, module, input, output):
319
- attn_output, attn_weights = output # Unpack outputs
320
- attn_weights.requires_grad_()
321
- module.save_attn_map(attn_weights)
322
- attn_weights.register_hook(module.save_attn_gradients)
323
-
324
- def _register_hooks_activations(self):
325
- for layer in self.target_layers:
326
- if hasattr(layer, "q_proj"): # is an attention layer
327
- self.hooks.append(layer.register_forward_hook(self._forward_activate_hooks))
328
-
329
- @spaces.GPU(duration=120)
330
- def generate_cam(self, inputs, tokenizer, temperature, top_p, class_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
331
-
332
- # Forward pass
333
- torch.autograd.set_detect_anomaly(True)
334
- for param in self.model.parameters():
335
- param.requires_grad = False
336
-
337
- for layer in self.target_layers:
338
- for param in layer.parameters():
339
- param.requires_grad = True
340
-
341
- outputs_raw = self.model(**inputs)
342
-
343
- self.model.zero_grad()
344
- print("outputs_raw", outputs_raw)
345
-
346
- loss = outputs_raw.logits.max(dim=-1).values.sum()
347
- loss.backward()
348
-
349
- # get image masks
350
- image_mask = []
351
- last = 0
352
- for i in range(inputs["input_ids"].shape[1]):
353
- decoded_token = tokenizer.decode(inputs["input_ids"][0][i].item())
354
- if (decoded_token == "<image>"):
355
- image_mask.append(True)
356
- last = i
357
- else:
358
- image_mask.append(False)
359
-
360
-
361
- # Aggregate activations and gradients from ALL layers
362
- self.activations = [layer.get_attn_map() for layer in self.target_layers]
363
- self.gradients = [layer.get_attn_gradients() for layer in self.target_layers]
364
- cam_sum = None
365
-
366
- for act, grad in zip(self.activations, self.gradients):
367
-
368
- print("act shape", act.shape)
369
- print("grad shape", grad.shape)
370
-
371
- grad = F.relu(grad)
372
-
373
-
374
- cam = act * grad # shape: [1, heads, seq_len, seq_len]
375
- cam = cam.sum(dim=1) # shape: [1, seq_len, seq_len]
376
-
377
- # Sum across all layers
378
- if cam_sum is None:
379
- cam_sum = cam
380
- else:
381
- cam_sum += cam
382
-
383
- cam_sum = F.relu(cam_sum)
384
- cam_sum = cam_sum.to(torch.float32)
385
-
386
-
387
- # cam_sum shape: [1, seq_len, seq_len]
388
- cam_sum_lst = []
389
- cam_sum_raw = cam_sum
390
- start_idx = last + 1
391
- for i in range(start_idx, cam_sum_raw.shape[1]):
392
- cam_sum = cam_sum_raw[0, i, :] # shape: [1: seq_len]
393
-
394
- cam_sum = cam_sum[image_mask].unsqueeze(0) # shape: [1, img_seq_len]
395
- print("cam_sum shape: ", cam_sum.shape)
396
- num_patches = cam_sum.shape[-1] # Last dimension of CAM output
397
- grid_size = int(num_patches ** 0.5)
398
- print(f"Detected grid size: {grid_size}x{grid_size}")
399
-
400
- cam_sum = cam_sum.view(grid_size, grid_size)
401
- cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
402
- cam_sum_lst.append(cam_sum)
403
-
404
-
405
- return cam_sum_lst, grid_size, start_idx
406
-
407
-
408
-
409
-
410
-
411
-
412
- class AttentionGuidedCAMChartGemma(AttentionGuidedCAM):
413
- def __init__(self, model, target_layers):
414
- self.target_layers = target_layers
415
- super().__init__(model, register=True)
416
- self._modify_layers()
417
- self._register_hooks_activations()
418
-
419
- def _modify_layers(self):
420
- for layer in self.target_layers:
421
- setattr(layer, "attn_gradients", None)
422
- setattr(layer, "attention_map", None)
423
-
424
- layer.save_attn_gradients = types.MethodType(save_attn_gradients, layer)
425
- layer.get_attn_gradients = types.MethodType(get_attn_gradients, layer)
426
- layer.save_attn_map = types.MethodType(save_attn_map, layer)
427
- layer.get_attn_map = types.MethodType(get_attn_map, layer)
428
-
429
- def _forward_activate_hooks(self, module, input, output):
430
- attn_output, attn_weights = output # Unpack outputs
431
- print("attn_output shape:", attn_output.shape)
432
- print("attn_weights shape:", attn_weights.shape)
433
- module.save_attn_map(attn_weights)
434
- attn_weights.register_hook(module.save_attn_gradients)
435
-
436
- def _register_hooks_activations(self):
437
- for layer in self.target_layers:
438
- if hasattr(layer, "q_proj"): # is an attention layer
439
- self.hooks.append(layer.register_forward_hook(self._forward_activate_hooks))
440
-
441
- @spaces.GPU(duration=120)
442
- def generate_cam(self, inputs, tokenizer, temperature, top_p, class_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
443
-
444
- # Forward pass
445
- torch.autograd.set_detect_anomaly(True)
446
- for param in self.model.parameters():
447
- param.requires_grad = False
448
-
449
- for layer in self.target_layers:
450
- for param in layer.parameters():
451
- param.requires_grad = True
452
-
453
- outputs_raw = self.model(**inputs, output_hidden_states=True)
454
-
455
-
456
-
457
- # get image masks
458
- image_mask = []
459
- last = 0
460
- for i in range(inputs["input_ids"].shape[1]):
461
- decoded_token = tokenizer.decode(inputs["input_ids"][0][i].item())
462
- if (decoded_token == "<image>"):
463
- image_mask.append(True)
464
- last = i
465
- else:
466
- image_mask.append(False)
467
- start_idx = last + 1
468
-
469
-
470
-
471
- if focus == "Visual Encoder":
472
- # image_embeddings = outputs_raw.image_hidden_states
473
- # inputs_embeddings = outputs_raw.hidden_states[0]
474
- # # Pooling
475
- # if visual_pooling_method == "avg":
476
- # image_embeddings_pooled = image_embeddings.mean(dim=1) # end of image: 618
477
- # elif visual_pooling_method == "max":
478
- # image_embeddings_pooled, _ = image_embeddings.max(dim=1)
479
-
480
- # print("image_embeddings_shape: ", image_embeddings_pooled.shape)
481
-
482
-
483
-
484
- # inputs_embeddings_pooled = inputs_embeddings[:, start_idx:].mean(dim=1)
485
- self.model.zero_grad()
486
- # image_embeddings_pooled.backward(inputs_embeddings_pooled, retain_graph=True)
487
-
488
- loss = outputs_raw.logits.max(dim=-1).values[0, start_idx + class_idx]
489
- loss.backward()
490
-
491
- cam_sum = None
492
- for act, grad in zip(self.activations, self.gradients):
493
- # act = torch.sigmoid(act)
494
- act = F.relu(act[0])
495
-
496
-
497
- # Compute mean of gradients
498
- print("grad shape:", grad.shape)
499
- grad_weights = grad.mean(dim=-1, keepdim=True)
500
-
501
- print("act shape", act.shape)
502
- print("grad_weights shape", grad_weights.shape)
503
-
504
- cam = (act * grad_weights).sum(dim=-1)
505
- # cam, _ = (act * grad_weights).max(dim=-1)
506
- # cam, _ = grad_weights.max(dim=-1)
507
- print(cam.shape)
508
-
509
- # Sum across all layers
510
- if cam_sum is None:
511
- cam_sum = cam
512
- else:
513
- cam_sum += cam
514
-
515
- # Normalize
516
- cam_sum = F.relu(cam_sum)
517
-
518
-
519
- # thresholding
520
- cam_sum = cam_sum.to(torch.float32).detach().cpu()
521
- percentile = torch.quantile(cam_sum, 0.2) # Adjust threshold dynamically
522
- cam_sum[cam_sum < percentile] = 0
523
-
524
- # Reshape
525
- print("cam_sum shape: ", cam_sum.shape)
526
- num_patches = cam_sum.shape[-1] # Last dimension of CAM output
527
- grid_size = int(num_patches ** 0.5)
528
- print(f"Detected grid size: {grid_size}x{grid_size}")
529
-
530
- cam_sum = cam_sum.view(grid_size, grid_size)
531
- cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
532
-
533
- return cam_sum, grid_size, start_idx
534
-
535
- elif focus == "Language Model":
536
- self.model.zero_grad()
537
- print("logits shape:", outputs_raw.logits.shape)
538
- # loss = outputs_raw.logits.max(dim=-1).values.sum()
539
- if class_idx == -1:
540
- loss = outputs_raw.logits.max(dim=-1).values.sum()
541
- else:
542
- loss = outputs_raw.logits.max(dim=-1).values[0, start_idx + class_idx]
543
- loss.backward()
544
-
545
-
546
-
547
- # Aggregate activations and gradients from ALL layers
548
- self.activations = [layer.get_attn_map() for layer in self.target_layers]
549
- self.gradients = [layer.get_attn_gradients() for layer in self.target_layers]
550
- print(f"layers shape: {len(self.target_layers)}")
551
- print("activations & gradients shape", len(self.activations), len(self.gradients))
552
-
553
- cams = []
554
-
555
- # Ver 2
556
- for act, grad in zip(self.activations, self.gradients):
557
-
558
- print("act shape", act.shape)
559
- print("grad shape", grad.shape)
560
-
561
- grad = F.relu(grad)
562
-
563
- # cam = grad
564
- cam = act * grad # shape: [1, heads, seq_len, seq_len]
565
- cam = cam.sum(dim=1) # shape: [1, seq_len, seq_len]
566
- cam = cam.to(torch.float32).detach().cpu()
567
- cams.append(cam)
568
-
569
- # cam_sum = F.relu(cam_sum)
570
- # cam_sum = cam_sum.to(torch.float32)
571
-
572
- # cams shape: [layers, 1, seq_len, seq_len]
573
- cam_sum_lst = []
574
-
575
- start_idx = last + 1
576
- for i in range(start_idx, cams[0].shape[1]):
577
- cam_sum = None
578
- for layer, cam_l in enumerate(cams):
579
- cam_l_i = cam_l[0, i, :] # shape: [1: seq_len]
580
-
581
- cam_l_i = cam_l_i[image_mask].unsqueeze(0) # shape: [1, img_seq_len]
582
- # print(f"layer: {layer}, token index: {i}")
583
- # print("cam_sum shape: ", cam_l_i.shape)
584
- num_patches = cam_l_i.shape[-1] # Last dimension of CAM output
585
- grid_size = int(num_patches ** 0.5)
586
- # print(f"Detected grid size: {grid_size}x{grid_size}")
587
-
588
- # Fix the reshaping step dynamically
589
- cam_reshaped = cam_l_i.view(grid_size, grid_size)
590
- # print(f"max: {cam_reshaped.max()}, min: {cam_reshaped.min()}")
591
- # cam_reshaped = (cam_reshaped - cam_reshaped.min()) / (cam_reshaped.max() - cam_reshaped.min())
592
- if cam_sum == None:
593
- cam_sum = cam_reshaped
594
- else:
595
- cam_sum += cam_reshaped
596
- # print(f"normalized: max: {cam_normalized.max()}, min: {cam_normalized.min()}")
597
-
598
- # print(f"sum: max: {cam_sum.max()}, min: {cam_sum.min()}")
599
- cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
600
- cam_sum_lst.append(cam_sum)
601
-
602
-
603
- return cam_sum_lst, grid_size, start_idx
604
-
605
-
606
-
607
-
608
-
609
-
610
-
611
-
612
-
613
-
614
-
615
-
616
-
617
-
618
-
619
-
620
- def generate_gradcam(
621
- cam,
622
- image,
623
- size = (384, 384),
624
- alpha=0.5,
625
- colormap=cv2.COLORMAP_JET,
626
- aggregation='mean',
627
- normalize=False
628
- ):
629
- """
630
- Generates a Grad-CAM heatmap overlay on top of the input image.
631
-
632
- Parameters:
633
- attributions (torch.Tensor): A tensor of shape (C, H, W) representing the
634
- intermediate activations or gradients at the target layer.
635
- image (PIL.Image): The original image.
636
- alpha (float): The blending factor for the heatmap overlay (default 0.5).
637
- colormap (int): OpenCV colormap to apply (default cv2.COLORMAP_JET).
638
- aggregation (str): How to aggregate across channels; either 'mean' or 'sum'.
639
-
640
- Returns:
641
- PIL.Image: The image overlaid with the Grad-CAM heatmap.
642
- """
643
- # print("Generating Grad-CAM with shape:", cam.shape)
644
-
645
- if normalize:
646
- cam_min, cam_max = cam.min(), cam.max()
647
- cam = cam - cam_min
648
- cam = cam / (cam_max - cam_min)
649
- # Convert tensor to numpy array
650
- cam = torch.nn.functional.interpolate(cam.unsqueeze(0).unsqueeze(0), size=size, mode='bilinear').squeeze()
651
- cam_np = cam.squeeze().detach().cpu().numpy()
652
-
653
- # Apply Gaussian blur for smoother heatmaps
654
- cam_np = cv2.GaussianBlur(cam_np, (5,5), sigmaX=0.8)
655
-
656
- # Resize the cam to match the image size
657
- width, height = size
658
- cam_resized = cv2.resize(cam_np, (width, height))
659
-
660
- # Convert the normalized map to a heatmap (0-255 uint8)
661
- heatmap = np.uint8(255 * cam_resized)
662
- heatmap = cv2.applyColorMap(heatmap, colormap)
663
- # OpenCV produces heatmaps in BGR, so convert to RGB for consistency
664
- heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
665
-
666
- # Convert original image to a numpy array
667
- image_np = np.array(image)
668
- image_np = cv2.resize(image_np, (width, height))
669
-
670
- # Blend the heatmap with the original image
671
- overlay = cv2.addWeighted(image_np, 1 - alpha, heatmap, alpha, 0)
672
-
673
- return Image.fromarray(overlay)
674
-
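The heart of this deleted module is the final overlay step: upsample a normalized CAM, colorize it, and alpha-blend it onto the source image. A condensed sketch of that step follows, assuming cam is a 2-D tensor already scaled to [0, 1] and image is an RGB PIL image; overlay_cam is a hypothetical name, and the deleted code spread the same calls across generate_gradcam:

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

def overlay_cam(cam: torch.Tensor, image: Image.Image, size=(384, 384), alpha=0.5) -> Image.Image:
    # Upsample the 2-D CAM (values in [0, 1]) to the target resolution
    cam = F.interpolate(cam[None, None, ...], size=size, mode="bilinear").squeeze()
    cam_np = cv2.GaussianBlur(cam.detach().cpu().numpy(), (5, 5), sigmaX=0.8)

    # Colorize (OpenCV yields BGR; convert to RGB) and alpha-blend onto the image
    heatmap = cv2.applyColorMap(np.uint8(255 * cam_np), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    image_np = cv2.resize(np.array(image), size)
    return Image.fromarray(cv2.addWeighted(image_np, 1 - alpha, heatmap, alpha, 0))
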
demo/visualization.py CHANGED
@@ -145,7 +145,7 @@ class Visualization:
         return cams
 
 
-    def process(self, cam_sum, thresholding=True, remove_cls=True, normalize=True):
+    def process(self, cam_sum, thresholding=True, remove_cls=False, normalize=True):
 
         cam_sum = cam_sum.to(torch.float32)
 
janus/models/modeling_vlm.py CHANGED
@@ -256,7 +256,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
 
         # replace with the image embeddings
-        images_embeds = images_embeds[:, 1:, :]
+        # images_embeds = images_embeds[:, 1:, :]
         inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
 
         return inputs_embeds
@@ -293,7 +293,8 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         inputs_embeds = self.language_model.get_input_embeddings()(input_tensor.input_ids)
         # print("input_embeddings: ", inputs_embeds)
 
-        images_embeds_rest = images_embeds[:, 1:, :]
+        # images_embeds_rest = images_embeds[:, 1:, :]
+        images_embeds_rest = images_embeds[:, :, :]
 
         # images_embeds_pooled = images_embeds.mean(dim=1)
 
janus/models/siglip_vit.py CHANGED
@@ -655,9 +655,9 @@ def create_siglip_vit(
     else:
         layers = min(vision_cfg.layers, select_layer)
 
-    # Requre CLS token
-    vision_cfg.class_token = True
-    print("Usage Class Token: ", vision_cfg.class_token)
+    # Require CLS token
+    # vision_cfg.class_token = True
+    # print("Usage Class Token: ", vision_cfg.class_token)
 
     model = VisionTransformer(
         img_size=image_size,
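Why this matters downstream: with a forced CLS token the visual sequence has length 1 + N^2, so the CAM code must drop index 0 before reshaping to an N x N grid; with this commit's setting (and remove_cls=False above) the length is already a perfect square. A quick sanity check for the 24 x 24 grid (576 patches) used in this repo:

# With a CLS token: 1 + 24*24 = 577 positions -> requires cam_sum[1:]
num_patches_with_cls = 1 + 24 * 24
# Without (this commit's setting): 576 positions -> reshapes directly
num_patches_without = 24 * 24
grid = int(num_patches_without ** 0.5)
assert grid * grid == num_patches_without
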
questions/VLAT.py CHANGED
@@ -49,7 +49,7 @@ VLAT_questions=[
 
     [
         "StackedArea",
-        "The number of girls named 'Olivia' was raising or falling from 2009 to 2012?",
+        "The number of girls named 'Olivia' was increasing or decreasing from 2009 to 2012?",
         "images/mini-VLAT/StackedArea.png"
     ],
 
@@ -115,7 +115,7 @@ VLAT_questions=[
 
     [
         "LineChart",
-        "Over the course of the first quarter of 2020, the price of a barrel of oil was rising or falling?",
+        "Over the course of the first quarter of 2020, the price of a barrel of oil was increasing or decreasing?",
         "images/mini-VLAT/LineChart.png"
     ],
 
@@ -175,7 +175,7 @@ VLAT_questions=[
 
     [
         "AreaChart",
-        "Over the first six months of 2018, the price of a pound of coffee beans was roughly falling or rising?",
+        "Over the first six months of 2018, the price of a pound of coffee beans was roughly decreasing or increasing?",
         "images/mini-VLAT/AreaChart.png"
     ],