drhead committed
Commit 5d0540f · 1 Parent(s): e6d942c

Add attention visualization

Files changed (1): app.py (+172 -45)

app.py CHANGED
@@ -1,17 +1,17 @@
- import json
-
- import gradio as gr
  from PIL import Image
- import safetensors.torch
- import spaces
- import timm
- from timm.models import VisionTransformer
+ import numpy as np
+ import matplotlib.cm as cm
+ import msgspec
  import torch
  from torchvision.transforms import transforms
  from torchvision.transforms import InterpolationMode
  import torchvision.transforms.functional as TF
-
- torch.set_grad_enabled(False)
+ import timm
+ from timm.models import VisionTransformer
+ import safetensors.torch
+ import gradio as gr
+ import spaces
+ from huggingface_hub import hf_hub_download

  class Fit(torch.nn.Module):
      def __init__(
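
The global `torch.set_grad_enabled(False)` is dropped here because the new CAM path needs gradients; classification now opts out locally instead. A minimal sketch of the distinction, using a toy `torch.nn.Linear` in place of the app's ViT:

```python
import torch

model = torch.nn.Linear(4, 2)
x = torch.randn(1, 4)

# Classification path: no autograd graph, as run_classifier does below.
with torch.no_grad():
    y = model(x)
print(y.requires_grad)  # False

# CAM path: keep the graph so a per-tag score can be backpropagated.
y = model(x)
model.zero_grad()
y.sum().backward()
print(model.weight.grad.shape)  # torch.Size([2, 4])
```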
@@ -137,79 +137,206 @@ class GatedHead(torch.nn.Module):

  model.head = GatedHead(min(model.head.weight.shape), 9083)

- safetensors.torch.load_model(model, "JTP_PILOT2-2-e3-vit_so400m_patch14_siglip_384.safetensors")
+ cached_model = hf_hub_download(
+     repo_id="RedRocket/JointTaggerProject",
+     subfolder="JTP_PILOT2",
+     filename="JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors"
+ )
+
+ safetensors.torch.load_model(model, cached_model)
  model.eval()

- with open("tagger_tags.json", "r") as file:
-     tags = json.load(file)  # type: dict
- allowed_tags = list(tags.keys())
+ with open("tagger_tags.json", "rb") as file:
+     tags = msgspec.json.decode(file.read(), type=dict[str, int])

- for idx, tag in enumerate(allowed_tags):
-     allowed_tags[idx] = tag.replace("_", " ")
+ for tag in list(tags.keys()):
+     tags[tag.replace("_", " ")] = tags.pop(tag)

- sorted_tag_score = {}
+ allowed_tags = list(tags.keys())

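The tag vocabulary now loads with msgspec and keeps its tag-to-index mapping, so the underscore-to-space rename mutates the dict in place (cam_inference relies on those indices later). A minimal sketch of that loading step, assuming a small inline JSON payload in place of `tagger_tags.json`:

```python
import msgspec

payload = b'{"red_fox": 0, "blue_sky": 1}'  # stand-in for tagger_tags.json
tags = msgspec.json.decode(payload, type=dict[str, int])

# Rename keys in place: underscores become spaces, indices are preserved.
for tag in list(tags.keys()):
    tags[tag.replace("_", " ")] = tags.pop(tag)

assert tags == {"red fox": 0, "blue sky": 1}
```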
  @spaces.GPU(duration=5)
- def run_classifier(image, threshold):
-     global sorted_tag_score
+ def run_classifier(image: Image.Image, threshold):
      img = image.convert('RGBA')
      tensor = transform(img).unsqueeze(0)

      with torch.no_grad():
-         probits = model(tensor)[0]
-         values, indices = probits.topk(250)
+         probits = model(tensor)[0]  # type: torch.Tensor
+         values, indices = probits.cpu().topk(250)

-     tag_score = dict()
-     for i in range(indices.size(0)):
-         tag_score[allowed_tags[indices[i]]] = values[i].item()
+     tag_score = {allowed_tags[idx.item()]: val.item() for idx, val in zip(indices, values)}
+
      sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))

-     return create_tags(threshold)
+     return *create_tags(threshold, sorted_tag_score), img, sorted_tag_score

- def create_tags(threshold):
-     global sorted_tag_score
+ def create_tags(threshold, sorted_tag_score: dict):
      filtered_tag_score = {key: value for key, value in sorted_tag_score.items() if value > threshold}
      text_no_impl = ", ".join(filtered_tag_score.keys())
      return text_no_impl, filtered_tag_score

  def clear_image():
-     global sorted_tag_score
-     sorted_tag_score = {}
-     return "", {}
+     return "", {}, None, {}, None

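`run_classifier` no longer writes to a module-level `sorted_tag_score`; it returns everything the UI needs, splicing the helper's tuple into a larger one with a starred expression. A tiny illustration of that return idiom, with toy functions standing in for the app's:

```python
def helper():
    return "a, b", {"a": 0.9, "b": 0.8}   # tuple of two values

def caller():
    extra = {"state": True}
    return *helper(), extra               # splice the tuple, append a third value

print(caller())  # ('a, b', {'a': 0.9, 'b': 0.8}, {'state': True})
```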
+ @spaces.GPU(duration=5)
+ def cam_inference(img, threshold, alpha, evt: gr.SelectData):
+     target_tag_index = tags[evt.value]
+     tensor = transform(img).unsqueeze(0)
+
+     gradients = {}
+     activations = {}
+
+     def hook_forward(module, input, output):
+         activations['value'] = output
+
+     def hook_backward(module, grad_in, grad_out):
+         gradients['value'] = grad_out[0]
+
+     handle_forward = model.norm.register_forward_hook(hook_forward)
+     handle_backward = model.norm.register_full_backward_hook(hook_backward)
+
+     probits = model(tensor)[0]
+
+     model.zero_grad()
+     probits[target_tag_index].backward(retain_graph=True)
+
+     with torch.no_grad():
+         patch_grads = gradients.get('value')
+         patch_acts = activations.get('value')
+
+         weights = torch.mean(patch_grads, dim=1).squeeze(0)
+
+         cam_1d = torch.einsum('pe,e->p', patch_acts.squeeze(0), weights)
+         cam_1d = torch.relu(cam_1d)
+
+         cam = cam_1d.reshape(27, 27).detach().cpu().numpy()
+
+     handle_forward.remove()
+     handle_backward.remove()
+
+     return create_cam_visualization_pil(img, cam, alpha=alpha, vis_threshold=threshold), cam
+
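`cam_inference` is a Grad-CAM-style pass: hooks capture the activations and gradients at `model.norm`, the mean gradient over patch tokens weights each embedding channel, and a weighted sum over channels scores each patch (a 27×27 grid for this backbone). A self-contained sketch of the same pattern on a toy two-layer model, with illustrative names and sizes:

```python
import torch

backbone = torch.nn.Linear(8, 8)   # stands in for the ViT through model.norm
head = torch.nn.Linear(8, 3)       # stands in for the gated classifier head
activations, gradients = {}, {}

def hook_forward(module, inputs, output):
    activations['value'] = output

def hook_backward(module, grad_in, grad_out):
    gradients['value'] = grad_out[0]

h_fwd = backbone.register_forward_hook(hook_forward)
h_bwd = backbone.register_full_backward_hook(hook_backward)

tokens = torch.randn(1, 4, 8)                    # 4 "patch tokens", 8 channels
probits = head(backbone(tokens)).mean(dim=1)[0]  # pooled per-class scores

backbone.zero_grad()
probits[2].backward()                            # backprop from one target "tag"

# Mean gradient over tokens weights each channel; a weighted sum over
# channels then scores each token -- the same recipe as cam_inference.
weights = gradients['value'].mean(dim=1).squeeze(0)
cam = torch.relu(torch.einsum('pe,e->p', activations['value'].squeeze(0), weights))

h_fwd.remove(); h_bwd.remove()
print(cam.shape)  # torch.Size([4]) -> one relevance score per token
```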
+ def create_cam_visualization_pil(image_pil, cam, alpha=0.6, vis_threshold=0.2):
+     """
+     Overlays CAM on image and returns a PIL image.
+     Args:
+         image_pil: PIL Image (RGBA)
+         cam: 2D numpy array (activation map)
+         alpha: float, blending factor
+         vis_threshold: float, minimum normalized CAM value to show color
+     Returns:
+         PIL.Image.Image with overlay
+     """
+     if cam is None:
+         return image_pil
+     w, h = image_pil.size
+     size = max(w, h)
+
+     # Normalize CAM to [0, 1]
+     cam -= cam.min()
+     cam /= cam.max()
+
+     # Create heatmap using matplotlib colormap
+     colormap = cm.get_cmap('inferno')
+     cam_rgb = colormap(cam)[:, :, :3]  # RGB
+
+     # Create alpha channel
+     cam_alpha = (cam >= vis_threshold).astype(np.float32) * alpha  # Alpha mask
+     cam_rgba = np.dstack((cam_rgb, cam_alpha))  # Shape: (H, W, 4)
+
+     # Coarse upscale for CAM output -- keeps "blocky" effect that is truer to what is measured
+     cam_pil = Image.fromarray((cam_rgba * 255).astype(np.uint8), mode="RGBA")
+     cam_pil = cam_pil.resize((216, 216), resample=Image.Resampling.NEAREST)
+
+     # Model uses a padded square image as input; this matches the attention map to the input image's aspect ratio
+     cam_pil = cam_pil.resize((size, size), resample=Image.Resampling.BICUBIC)
+     cam_pil = transforms.CenterCrop((h, w))(cam_pil)
+
+     # Composite over original
+     composite = Image.alpha_composite(image_pil, cam_pil)
+
+     return composite
+
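The overlay itself is plain NumPy + PIL: colorize the normalized map with the inferno colormap, gate the alpha channel with the visibility threshold, and alpha-composite over the source. A standalone sketch with a random map standing in for a real CAM (it uses the newer `matplotlib.colormaps` registry where the app code calls the older `cm.get_cmap`):

```python
import numpy as np
from PIL import Image
from matplotlib import colormaps

cam = np.random.rand(27, 27).astype(np.float32)          # stand-in CAM in [0, 1]
base = Image.new("RGBA", (384, 384), (80, 80, 80, 255))  # stand-in source image

rgb = colormaps["inferno"](cam)[:, :, :3]                # colorize: (27, 27, 3)
alpha = (cam >= 0.4).astype(np.float32) * 0.6            # CAM threshold -> alpha mask
rgba = np.dstack((rgb, alpha))                           # (27, 27, 4)

overlay = Image.fromarray((rgba * 255).astype(np.uint8), mode="RGBA")
overlay = overlay.resize(base.size, Image.Resampling.NEAREST)  # keep the blocky patch grid
print(Image.alpha_composite(base, overlay).size)         # (384, 384)
```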
+ custom_css = """
+ .output-class { display: none; }
+ .inferno-slider input[type=range] {
+     background: linear-gradient(to right,
+         #000004, #1b0c41, #4a0c6b, #781c6d,
+         #a52c60, #cf4446, #ed6925, #fb9b06,
+         #f7d13d, #fcffa4
+     ) !important;
+     background-size: 100% 100% !important;
+ }
+ #image_container-image {
+     width: 100%;
+     aspect-ratio: 1 / 1;
+     max-height: 100%;
+ }
+ #image_container img {
+     object-fit: contain !important;
+ }
+ """
+
- with gr.Blocks(css=".output-class { display: none; }") as demo:
-     gr.Markdown("""
-     ## Joint Tagger Project: JTP-PILOT² Demo **BETA**
-     This tagger is designed for use on furry images (though it may very well work on out-of-distribution images, potentially with funny results). A threshold of 0.2 is recommended. Lower thresholds often turn up more valid tags, but can also result in some hallucinated tags.
-
-     This tagger is the result of joint efforts between members of the RedRocket team, with distinctions given to Thessalo for creating the foundation for this project with his efforts, RedHotTensors for redesigning the process into a second-order method that models information expectation, and drhead for dataset prep, creation of training code, and supervision of training runs.
-
-     Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
-     """)
+ with gr.Blocks(css=custom_css) as demo:
+     gr.Markdown("## Joint Tagger Project: JTP-PILOT² Demo **BETA**")
+     original_image_state = gr.State()  # stash a copy of the input image
+     sorted_tag_score_state = gr.State(value={})  # stash the sorted tag scores
+     cam_state = gr.State()
      with gr.Row():
          with gr.Column():
-             image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
-             threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Threshold")
+             image = gr.Image(label="Source", sources=['upload', 'clipboard'], type='pil', show_label=False, elem_id="image_container")
+             cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.40, label="CAM Threshold", elem_classes="inferno-slider")
+             alpha_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.60, label="CAM Alpha")
          with gr.Column():
+             threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Tag Threshold")
              tag_string = gr.Textbox(label="Tag String")
              label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)

+     gr.Markdown("""
+     This tagger is designed for use on furry images (though it may very well work on out-of-distribution images, potentially with funny results). A threshold of 0.2 is recommended. Lower thresholds often turn up more valid tags, but can also result in some hallucinated tags.
+     This tagger is the result of joint efforts between members of the RedRocket team, with distinctions given to Thessalo for creating the foundation for this project with his efforts, RedHotTensors for redesigning the process into a second-order method that models information expectation, and drhead for dataset prep, creation of training code, and supervision of training runs.
+     Thanks to metal63 for providing initial code for attention visualization (click a tag in the tag list to try it out!).
+     Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
+     """)
+
-     image_input.upload(
+     image.upload(
          fn=run_classifier,
-         inputs=[image_input, threshold_slider],
-         outputs=[tag_string, label_box]
+         inputs=[image, threshold_slider],
+         outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state],
+         show_progress='minimal'
      )

-     image_input.clear(
+     image.clear(
          fn=clear_image,
          inputs=[],
-         outputs=[tag_string, label_box]
+         outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state, cam_state]
      )

      threshold_slider.input(
          fn=create_tags,
-         inputs=[threshold_slider],
-         outputs=[tag_string, label_box]
+         inputs=[threshold_slider, sorted_tag_score_state],
+         outputs=[tag_string, label_box],
+         show_progress='hidden'
+     )
+
+     label_box.select(
+         fn=cam_inference,
+         inputs=[original_image_state, cam_slider, alpha_slider],
+         outputs=[image, cam_state],
+         show_progress='minimal'
+     )
+
+     cam_slider.input(
+         fn=create_cam_visualization_pil,
+         inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
+         outputs=[image],
+         show_progress='hidden'
+     )
+
+     alpha_slider.input(
+         fn=create_cam_visualization_pil,
+         inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
+         outputs=[image],
+         show_progress='hidden'
      )

  if __name__ == "__main__":
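
The module-level globals are gone because each callback now round-trips explicit `gr.State`: the classifier deposits its scores (and the original image) into per-session state, and the sliders re-filter or re-render from that state without re-running the model. A minimal sketch of the pattern, with toy functions in place of the app's:

```python
import gradio as gr

def classify(text):
    scores = {"fox": 0.9, "sky": 0.3}            # stand-in for model output
    return ", ".join(scores), scores             # visible result + state value

def refilter(threshold, scores):
    kept = {k: v for k, v in scores.items() if v > threshold}
    return ", ".join(kept)

with gr.Blocks() as demo:
    scores_state = gr.State(value={})            # per-session, no module globals
    inp = gr.Textbox(label="Input")
    thr = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.5, label="Threshold")
    out = gr.Textbox(label="Tags")
    inp.submit(fn=classify, inputs=[inp], outputs=[out, scores_state])
    thr.input(fn=refilter, inputs=[thr, scores_state], outputs=[out], show_progress='hidden')

if __name__ == "__main__":
    demo.launch()
```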