add sliders for cam alpha/thresholding
Browse files
app.py
CHANGED
@@ -12,7 +12,7 @@ from torchvision.transforms import InterpolationMode
|
|
12 |
import torchvision.transforms.functional as TF
|
13 |
from huggingface_hub import hf_hub_download
|
14 |
import numpy as np
|
15 |
-
import matplotlib.
|
16 |
|
17 |
class Fit(torch.nn.Module):
|
18 |
def __init__(
|
@@ -178,9 +178,9 @@ def create_tags(threshold, sorted_tag_score: dict):
|
|
178 |
return text_no_impl, filtered_tag_score
|
179 |
|
180 |
def clear_image():
|
181 |
-
return "", {}, None, {}
|
182 |
|
183 |
-
def cam_inference(img, threshold, evt: gr.SelectData):
|
184 |
target_tag = evt.value
|
185 |
tensor = transform(img).unsqueeze(0)
|
186 |
|
@@ -224,7 +224,7 @@ def cam_inference(img, threshold, evt: gr.SelectData):
|
|
224 |
gradients = {}
|
225 |
activations = {}
|
226 |
|
227 |
-
return create_cam_visualization_pil(img, cam, vis_threshold=threshold)
|
228 |
|
229 |
def create_cam_visualization_pil(image_pil, cam, alpha=0.6, vis_threshold=0.2):
|
230 |
"""
|
@@ -275,10 +275,13 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
|
|
275 |
""")
|
276 |
original_image_state = gr.State() # stash a copy of the input image
|
277 |
sorted_tag_score_state = gr.State(value={})  # stash the sorted tag→score mapping (not the image)
|
|
|
278 |
with gr.Row():
|
279 |
with gr.Column():
|
280 |
image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
|
281 |
-
threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Threshold")
|
|
|
|
|
282 |
with gr.Column():
|
283 |
tag_string = gr.Textbox(label="Tag String")
|
284 |
label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)
|
@@ -292,7 +295,7 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
|
|
292 |
image_input.clear(
|
293 |
fn=clear_image,
|
294 |
inputs=[],
|
295 |
-
outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state]
|
296 |
)
|
297 |
|
298 |
threshold_slider.input(
|
@@ -303,7 +306,19 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
|
|
303 |
|
304 |
label_box.select(
|
305 |
fn=cam_inference,
|
306 |
-
inputs=[original_image_state,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
307 |
outputs=[image_input]
|
308 |
)
|
309 |
|
|
|
12 |
import torchvision.transforms.functional as TF
|
13 |
from huggingface_hub import hf_hub_download
|
14 |
import numpy as np
|
15 |
+
import matplotlib.colormaps as cm
|
16 |
|
17 |
class Fit(torch.nn.Module):
|
18 |
def __init__(
|
|
|
178 |
return text_no_impl, filtered_tag_score
|
179 |
|
180 |
def clear_image():
    """Reset every UI output after the source image is removed.

    Returns empty values for, in order: the tag string, the label box,
    the stashed original image, the stashed tag/score mapping, and the
    stashed CAM.
    """
    cleared = ("", {}, None, {}, None)
    return cleared
|
182 |
|
183 |
+
def cam_inference(img, threshold, alpha, evt: gr.SelectData):
|
184 |
target_tag = evt.value
|
185 |
tensor = transform(img).unsqueeze(0)
|
186 |
|
|
|
224 |
gradients = {}
|
225 |
activations = {}
|
226 |
|
227 |
+
return create_cam_visualization_pil(img, cam, alpha=alpha, vis_threshold=threshold), cam
|
228 |
|
229 |
def create_cam_visualization_pil(image_pil, cam, alpha=0.6, vis_threshold=0.2):
|
230 |
"""
|
|
|
275 |
""")
|
276 |
original_image_state = gr.State() # stash a copy of the input image
|
277 |
sorted_tag_score_state = gr.State(value={})  # stash the sorted tag→score mapping (not the image)
|
278 |
+
cam_state = gr.State()
|
279 |
with gr.Row():
|
280 |
with gr.Column():
|
281 |
image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
|
282 |
+
threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Tag Threshold")
|
283 |
+
cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="CAM Threshold")
|
284 |
+
alpha_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.60, label="CAM Alpha")
|
285 |
with gr.Column():
|
286 |
tag_string = gr.Textbox(label="Tag String")
|
287 |
label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)
|
|
|
295 |
image_input.clear(
|
296 |
fn=clear_image,
|
297 |
inputs=[],
|
298 |
+
outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state, cam_state]
|
299 |
)
|
300 |
|
301 |
threshold_slider.input(
|
|
|
306 |
|
307 |
label_box.select(
    fn=cam_inference,
    inputs=[original_image_state, cam_slider, alpha_slider],
    # BUG FIX: cam_inference returns TWO values, (visualization, cam).
    # The original wiring listed only image_input, so the computed CAM was
    # dropped and cam_state stayed None — the two sliders below would then
    # re-render with cam=None. Capture the CAM into cam_state as well.
    outputs=[image_input, cam_state]
)

# Re-render the stored CAM (no recomputation) when the CAM threshold moves.
# Positional mapping: (image_pil, cam, alpha, vis_threshold).
cam_slider.input(
    fn=create_cam_visualization_pil,
    inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
    outputs=[image_input]
)

# Same re-render path when the overlay alpha moves.
alpha_slider.input(
    fn=create_cam_visualization_pil,
    inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
    outputs=[image_input]
)
|
324 |
|