Update app.py
app.py CHANGED
```diff
@@ -269,25 +269,24 @@ custom_css = """
     ) !important;
     background-size: 100% 100% !important;
 }
+#image_container-image {
+    width: 100%;
+    aspect-ratio: 1 / 1;
+    max-height: 100%;
+}
+#image_container img {
+    object-fit: contain !important;
+}
 """
 
 with gr.Blocks(css=custom_css) as demo:
-    gr.Markdown("""
-    ## Joint Tagger Project: JTP-PILOT² Demo **BETA**
-    This tagger is designed for use on furry images (though may very well work on out-of-distribution images, potentially with funny results). A threshold of 0.2 is recommended. Lower thresholds often turn up more valid tags, but can also result in some amount of hallucinated tags.
-
-    This tagger is the result of joint efforts between members of the RedRocket team, with distinctions given to Thessalo for creating the foundation for this project with his efforts, RedHotTensors for redesigning the process into a second-order method that models information expectation, and drhead for dataset prep, creation of training code and supervision of training runs.
-
-    Thanks to metal63 for providing initial code for attention visualization (click a tag in the tag list to try it out!)
-
-    Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
-    """)
+    gr.Markdown("## Joint Tagger Project: JTP-PILOT² Demo **BETA**")
     original_image_state = gr.State() # stash a copy of the input image
     sorted_tag_score_state = gr.State(value={}) # stash a copy of the sorted tag/score dict
     cam_state = gr.State()
     with gr.Row():
         with gr.Column():
-
+            image = gr.Image(label="Source", sources=['upload', 'clipboard'], type='pil', show_label=False, elem_id="image_container")
             threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Tag Threshold")
             cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.40, label="CAM Threshold", elem_classes="inferno-slider")
             alpha_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.60, label="CAM Alpha")
@@ -295,14 +294,24 @@
             tag_string = gr.Textbox(label="Tag String")
             label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)
 
-
+    gr.Markdown("""
+    This tagger is designed for use on furry images (though may very well work on out-of-distribution images, potentially with funny results). A threshold of 0.2 is recommended. Lower thresholds often turn up more valid tags, but can also result in some amount of hallucinated tags.
+
+    This tagger is the result of joint efforts between members of the RedRocket team, with distinctions given to Thessalo for creating the foundation for this project with his efforts, RedHotTensors for redesigning the process into a second-order method that models information expectation, and drhead for dataset prep, creation of training code and supervision of training runs.
+
+    Thanks to metal63 for providing initial code for attention visualization (click a tag in the tag list to try it out!)
+
+    Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
+    """)
+
+    image.upload(
         fn=run_classifier,
-        inputs=[
+        inputs=[image, threshold_slider],
         outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state],
         show_progress='minimal'
     )
 
-
+    image.clear(
         fn=clear_image,
         inputs=[],
         outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state, cam_state]
@@ -318,21 +327,21 @@
     label_box.select(
         fn=cam_inference,
         inputs=[original_image_state, cam_slider, alpha_slider],
-        outputs=[
+        outputs=[image, cam_state],
         show_progress='minimal'
     )
 
     cam_slider.input(
         fn=create_cam_visualization_pil,
         inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
-        outputs=[
+        outputs=[image],
         show_progress='hidden'
     )
 
     alpha_slider.input(
         fn=create_cam_visualization_pil,
         inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
-        outputs=[
+        outputs=[image],
         show_progress='hidden'
     )
 
```
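Taken together, the commit squares up the image widget with CSS keyed to its `elem_id`, trims the intro Markdown down to a slim header (the description now renders below the tag list), and rewires the `upload`/`clear` events against the new `image` component. Below is a minimal, self-contained sketch of that wiring pattern; the handler bodies are placeholder stubs, not the real `run_classifier`/`clear_image` from app.py:

```python
import gradio as gr

# Square image container with letterboxed contents, as in the commit's CSS.
css = """
#image_container-image { width: 100%; aspect-ratio: 1 / 1; max-height: 100%; }
#image_container img { object-fit: contain !important; }
"""

# Placeholder stand-ins for the real app.py functions.
def run_classifier(img, threshold):
    # Real code: run the model and keep tags scoring above `threshold`.
    scores = {"tag_a": 0.9, "tag_b": 0.5}
    return ", ".join(scores), scores, img, scores

def clear_image():
    # Reset every output and stashed state when the image is removed.
    return "", {}, None, {}, None

with gr.Blocks(css=css) as demo:
    original_image_state = gr.State()            # stash of the uploaded image
    sorted_tag_score_state = gr.State(value={})  # stash of the sorted tag/score dict
    cam_state = gr.State()                       # stash of the raw CAM heatmap

    image = gr.Image(label="Source", sources=['upload', 'clipboard'], type='pil',
                     show_label=False, elem_id="image_container")
    threshold_slider = gr.Slider(0.0, 1.0, step=0.01, value=0.20, label="Tag Threshold")
    tag_string = gr.Textbox(label="Tag String")
    label_box = gr.Label(label="Tag Predictions", show_label=False)

    # Classify on upload; clear everything (including the CAM) on removal.
    image.upload(fn=run_classifier, inputs=[image, threshold_slider],
                 outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state])
    image.clear(fn=clear_image, inputs=[],
                outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state, cam_state])

demo.launch()
```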
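The diff also routes both sliders into `create_cam_visualization_pil` on every `.input` event, with `show_progress='hidden'` so the overlay refreshes without a progress flash. That function's body isn't part of this commit; one plausible shape for it, assuming the stashed CAM is a 2D float array in [0, 1], is sketched here:

```python
import numpy as np
from PIL import Image
from matplotlib import colormaps

def create_cam_visualization_pil(image, cam, alpha, threshold):
    """Overlay a CAM heatmap on `image` (both pulled from gr.State).

    cam: 2D float array in [0, 1]; activations below `threshold` stay transparent.
    alpha: maximum opacity of the overlay. Returns a PIL image for gr.Image.
    """
    if image is None or cam is None:
        return image  # nothing to visualize yet

    # Colorize the heatmap with the inferno colormap (matching the slider styling).
    rgba = colormaps["inferno"](cam)  # HxWx4 floats in [0, 1]
    # Opacity ramps with activation strength, gated by the CAM threshold.
    rgba[..., 3] = np.where(cam >= threshold, cam * alpha, 0.0)

    heat = Image.fromarray((rgba * 255).astype(np.uint8), mode="RGBA")
    heat = heat.resize(image.size, Image.BILINEAR)

    out = image.convert("RGBA")
    out.alpha_composite(heat)
    return out.convert("RGB")
```

Because this only re-composites a cached heatmap rather than re-running the model, it is cheap enough to call on every slider tick, which is presumably why the commit wires it to `.input` with hidden progress.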