drhead commited on
Commit
3cb1c16
·
verified ·
1 Parent(s): 02a9646

fix leaky globals

Browse files
Files changed (1) hide show
  1. app.py +41 -66
app.py CHANGED
@@ -154,19 +154,15 @@ allowed_tags = list(tags.keys())
154
  for idx, tag in enumerate(allowed_tags):
155
  allowed_tags[idx] = tag.replace("_", " ")
156
 
157
- sorted_tag_score = {}
158
- input_image = None
159
 
160
 
161
  @spaces.GPU(duration=5)
162
- def run_classifier(image, threshold):
163
- global sorted_tag_score, input_image
164
- input_image = image.convert('RGBA')
165
- img = input_image
166
  tensor = transform(img).unsqueeze(0)
167
 
168
  with torch.no_grad():
169
- probits = model(tensor)[0]
170
  values, indices = probits.topk(250)
171
 
172
  tag_score = dict()
@@ -174,37 +170,18 @@ def run_classifier(image, threshold):
174
  tag_score[allowed_tags[indices[i]]] = values[i].item()
175
  sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))
176
 
177
- return create_tags(threshold)
178
 
179
- def create_tags(threshold):
180
- global sorted_tag_score
181
  filtered_tag_score = {key: value for key, value in sorted_tag_score.items() if value > threshold}
182
  text_no_impl = ", ".join(filtered_tag_score.keys())
183
  return text_no_impl, filtered_tag_score
184
 
185
  def clear_image():
186
- global sorted_tag_score, input_image
187
- input_image = None
188
- sorted_tag_score = {}
189
- return "", {}
190
 
191
- target_tag_index = None
192
-
193
- # Store hooks and intermediate values
194
- gradients = {}
195
- activations = {}
196
-
197
- def hook_forward(module, input, output):
198
- activations['value'] = output
199
-
200
- def hook_backward(module, grad_in, grad_out):
201
- gradients['value'] = grad_out[0]
202
-
203
- def cam_inference(threshold, evt: gr.SelectData):
204
  target_tag = evt.value
205
- print(f"target_tag: {target_tag}")
206
- global input_image, sorted_tag_score, target_tag_index, gradients, activations
207
- img = input_image
208
  tensor = transform(img).unsqueeze(0)
209
 
210
  gradients = {}
@@ -212,46 +189,44 @@ def cam_inference(threshold, evt: gr.SelectData):
212
  cam = None
213
  target_tag_index = None
214
 
215
- if target_tag:
216
- if target_tag not in allowed_tags:
217
- print(f"Warning: Target tag '{target_tag}' not found in allowed tags.")
218
- target_tag = None
219
- else:
220
- target_tag_index = allowed_tags.index(target_tag)
221
- handle_forward = model.norm.register_forward_hook(hook_forward)
222
- handle_backward = model.norm.register_full_backward_hook(hook_backward)
223
 
224
- probits = model(tensor)[0].cpu()
225
-
226
-
227
- if target_tag is not None and target_tag_index is not None:
228
- model.zero_grad()
229
- target_score = probits[target_tag_index]
230
- target_score.backward(retain_graph=True)
 
 
231
 
232
- grads = gradients.get('value')
233
- acts = activations.get('value')
 
 
 
234
 
235
- if grads is not None and acts is not None:
236
- patch_grads = grads
237
- patch_acts = acts
238
 
239
- weights = torch.mean(patch_grads, dim=1).squeeze(0)
 
240
 
241
- cam_1d = torch.einsum('pe,e->p', patch_acts.squeeze(0), weights)
242
- cam_1d = torch.relu(cam_1d)
243
 
244
- cam = cam_1d.reshape(27, 27).detach().cpu().numpy()
 
245
 
 
246
 
247
- handle_forward.remove()
248
- handle_backward.remove()
249
- gradients = {}
250
- activations = {}
251
 
252
- return create_cam_visualization_pil(cam, vis_threshold=threshold)
253
 
254
- def create_cam_visualization_pil(cam, alpha=0.6, vis_threshold=0.2):
255
  """
256
  Overlays CAM on image and returns a PIL image.
257
 
@@ -265,9 +240,6 @@ def create_cam_visualization_pil(cam, alpha=0.6, vis_threshold=0.2):
265
  PIL.Image.Image with overlay
266
  """
267
 
268
- global input_image
269
- # Convert to RGB (in case RGBA or others)
270
- image_pil = input_image
271
  w, h = image_pil.size
272
 
273
  # Resize CAM to match image
@@ -297,8 +269,11 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
297
 
298
  This tagger is the result of joint efforts between members of the RedRocket team, with distinctions given to Thessalo for creating the foundation for this project with his efforts, RedHotTensors for redesigning the process into a second-order method that models information expectation, and drhead for dataset prep, creation of training code and supervision of training runs.
299
 
 
 
300
  Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
301
  """)
 
302
  with gr.Row():
303
  with gr.Column():
304
  image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
@@ -310,13 +285,13 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
310
  image_input.upload(
311
  fn=run_classifier,
312
  inputs=[image_input, threshold_slider],
313
- outputs=[tag_string, label_box]
314
  )
315
 
316
  image_input.clear(
317
  fn=clear_image,
318
  inputs=[],
319
- outputs=[tag_string, label_box]
320
  )
321
 
322
  threshold_slider.input(
@@ -327,7 +302,7 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
327
 
328
  label_box.select(
329
  fn=cam_inference,
330
- inputs=[threshold_slider],
331
  outputs=[image_input]
332
  )
333
 
 
154
  for idx, tag in enumerate(allowed_tags):
155
  allowed_tags[idx] = tag.replace("_", " ")
156
 
 
 
157
 
158
 
159
  @spaces.GPU(duration=5)
160
+ def run_classifier(image: Image.Image, threshold):
161
+ img = image.convert('RGBA')
 
 
162
  tensor = transform(img).unsqueeze(0)
163
 
164
  with torch.no_grad():
165
+ probits = model(tensor)[0] # type: torch.Tensor
166
  values, indices = probits.topk(250)
167
 
168
  tag_score = dict()
 
170
  tag_score[allowed_tags[indices[i]]] = values[i].item()
171
  sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))
172
 
173
+ return *create_tags(threshold, sorted_tag_score), img
174
 
175
+ def create_tags(threshold, sorted_tag_score: dict):
 
176
  filtered_tag_score = {key: value for key, value in sorted_tag_score.items() if value > threshold}
177
  text_no_impl = ", ".join(filtered_tag_score.keys())
178
  return text_no_impl, filtered_tag_score
179
 
180
  def clear_image():
181
+ return "", {}, None
 
 
 
182
 
183
+ def cam_inference(img, threshold, evt: gr.SelectData):
 
 
 
 
 
 
 
 
 
 
 
 
184
  target_tag = evt.value
 
 
 
185
  tensor = transform(img).unsqueeze(0)
186
 
187
  gradients = {}
 
189
  cam = None
190
  target_tag_index = None
191
 
 
 
 
 
 
 
 
 
192
 
193
+ def hook_forward(module, input, output):
194
+ activations['value'] = output
195
+
196
+ def hook_backward(module, grad_in, grad_out):
197
+ gradients['value'] = grad_out[0]
198
+
199
+ target_tag_index = allowed_tags.index(target_tag)
200
+ handle_forward = model.norm.register_forward_hook(hook_forward)
201
+ handle_backward = model.norm.register_full_backward_hook(hook_backward)
202
 
203
+ probits = model(tensor)[0].cpu()
204
+
205
+ model.zero_grad()
206
+ target_score = probits[target_tag_index]
207
+ target_score.backward(retain_graph=True)
208
 
209
+ grads = gradients.get('value')
210
+ acts = activations.get('value')
 
211
 
212
+ patch_grads = grads
213
+ patch_acts = acts
214
 
215
+ weights = torch.mean(patch_grads, dim=1).squeeze(0)
 
216
 
217
+ cam_1d = torch.einsum('pe,e->p', patch_acts.squeeze(0), weights)
218
+ cam_1d = torch.relu(cam_1d)
219
 
220
+ cam = cam_1d.reshape(27, 27).detach().cpu().numpy()
221
 
222
+ handle_forward.remove()
223
+ handle_backward.remove()
224
+ gradients = {}
225
+ activations = {}
226
 
227
+ return create_cam_visualization_pil(img, cam, vis_threshold=threshold)
228
 
229
+ def create_cam_visualization_pil(image_pil, cam, alpha=0.6, vis_threshold=0.2):
230
  """
231
  Overlays CAM on image and returns a PIL image.
232
 
 
240
  PIL.Image.Image with overlay
241
  """
242
 
 
 
 
243
  w, h = image_pil.size
244
 
245
  # Resize CAM to match image
 
269
 
270
  This tagger is the result of joint efforts between members of the RedRocket team, with distinctions given to Thessalo for creating the foundation for this project with his efforts, RedHotTensors for redesigning the process into a second-order method that models information expectation, and drhead for dataset prep, creation of training code and supervision of training runs.
271
 
272
+ Thanks to metal63 for providing initial code for attention visualization (click a tag in the tag list to try it out!)
273
+
274
  Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
275
  """)
276
+ original_image_state = gr.State() # stash a copy of the input image
277
  with gr.Row():
278
  with gr.Column():
279
  image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
 
285
  image_input.upload(
286
  fn=run_classifier,
287
  inputs=[image_input, threshold_slider],
288
+ outputs=[tag_string, label_box, original_image_state]
289
  )
290
 
291
  image_input.clear(
292
  fn=clear_image,
293
  inputs=[],
294
+ outputs=[tag_string, label_box, original_image_state]
295
  )
296
 
297
  threshold_slider.input(
 
302
 
303
  label_box.select(
304
  fn=cam_inference,
305
+ inputs=[original_image_state, threshold_slider],
306
  outputs=[image_input]
307
  )
308