drhead committed on
Commit
b545f1a
·
verified ·
1 Parent(s): 9969656

add sliders for cam alpha/thresholding

Browse files
Files changed (1) hide show
  1. app.py +22 -7
app.py CHANGED
@@ -12,7 +12,7 @@ from torchvision.transforms import InterpolationMode
12
  import torchvision.transforms.functional as TF
13
  from huggingface_hub import hf_hub_download
14
  import numpy as np
15
- import matplotlib.cm as cm
16
 
17
  class Fit(torch.nn.Module):
18
  def __init__(
@@ -178,9 +178,9 @@ def create_tags(threshold, sorted_tag_score: dict):
178
  return text_no_impl, filtered_tag_score
179
 
180
  def clear_image():
181
- return "", {}, None, {}
182
 
183
- def cam_inference(img, threshold, evt: gr.SelectData):
184
  target_tag = evt.value
185
  tensor = transform(img).unsqueeze(0)
186
 
@@ -224,7 +224,7 @@ def cam_inference(img, threshold, evt: gr.SelectData):
224
  gradients = {}
225
  activations = {}
226
 
227
- return create_cam_visualization_pil(img, cam, vis_threshold=threshold)
228
 
229
  def create_cam_visualization_pil(image_pil, cam, alpha=0.6, vis_threshold=0.2):
230
  """
@@ -275,10 +275,13 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
275
  """)
276
  original_image_state = gr.State() # stash a copy of the input image
277
  sorted_tag_score_state = gr.State(value={}) # stash a copy of the input image
 
278
  with gr.Row():
279
  with gr.Column():
280
  image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
281
- threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Threshold")
 
 
282
  with gr.Column():
283
  tag_string = gr.Textbox(label="Tag String")
284
  label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)
@@ -292,7 +295,7 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
292
  image_input.clear(
293
  fn=clear_image,
294
  inputs=[],
295
- outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state]
296
  )
297
 
298
  threshold_slider.input(
@@ -303,7 +306,19 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
303
 
304
  label_box.select(
305
  fn=cam_inference,
306
- inputs=[original_image_state, threshold_slider],
 
 
 
 
 
 
 
 
 
 
 
 
307
  outputs=[image_input]
308
  )
309
 
 
12
  import torchvision.transforms.functional as TF
13
  from huggingface_hub import hf_hub_download
14
  import numpy as np
15
+ import matplotlib.colormaps as cm
16
 
17
  class Fit(torch.nn.Module):
18
  def __init__(
 
178
  return text_no_impl, filtered_tag_score
179
 
180
  def clear_image():
181
+ return "", {}, None, {}, None
182
 
183
+ def cam_inference(img, threshold, alpha, evt: gr.SelectData):
184
  target_tag = evt.value
185
  tensor = transform(img).unsqueeze(0)
186
 
 
224
  gradients = {}
225
  activations = {}
226
 
227
+ return create_cam_visualization_pil(img, cam, alpha=alpha, vis_threshold=threshold), cam
228
 
229
  def create_cam_visualization_pil(image_pil, cam, alpha=0.6, vis_threshold=0.2):
230
  """
 
275
  """)
276
  original_image_state = gr.State() # stash a copy of the input image
277
  sorted_tag_score_state = gr.State(value={}) # stash a copy of the input image
278
+ cam_state = gr.State()
279
  with gr.Row():
280
  with gr.Column():
281
  image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
282
+ threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Tag Threshold")
283
+ cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="CAM Threshold")
284
+ alpha_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.60, label="CAM Alpha")
285
  with gr.Column():
286
  tag_string = gr.Textbox(label="Tag String")
287
  label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)
 
295
  image_input.clear(
296
  fn=clear_image,
297
  inputs=[],
298
+ outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state, cam_state]
299
  )
300
 
301
  threshold_slider.input(
 
306
 
307
  label_box.select(
308
  fn=cam_inference,
309
+ inputs=[original_image_state, cam_slider, alpha_slider],
310
+ outputs=[image_input]
311
+ )
312
+
313
+ cam_slider.input(
314
+ fn=create_cam_visualization_pil,
315
+ inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
316
+ outputs=[image_input]
317
+ )
318
+
319
+ alpha_slider.input(
320
+ fn=create_cam_visualization_pil,
321
+ inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
322
  outputs=[image_input]
323
  )
324