drhead committed on
Commit
ccfb9cb
·
verified ·
1 Parent(s): c5429ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -17
app.py CHANGED
@@ -269,25 +269,24 @@ custom_css = """
269
  ) !important;
270
  background-size: 100% 100% !important;
271
  }
 
 
 
 
 
 
 
 
272
  """
273
 
274
  with gr.Blocks(css=custom_css) as demo:
275
- gr.Markdown("""
276
- ## Joint Tagger Project: JTP-PILOT² Demo **BETA**
277
- This tagger is designed for use on furry images (though may very well work on out-of-distribution images, potentially with funny results). A threshold of 0.2 is recommended. Lower thresholds often turn up more valid tags, but can also result in some amount of hallucinated tags.
278
-
279
- This tagger is the result of joint efforts between members of the RedRocket team, with distinctions given to Thessalo for creating the foundation for this project with his efforts, RedHotTensors for redesigning the process into a second-order method that models information expectation, and drhead for dataset prep, creation of training code and supervision of training runs.
280
-
281
- Thanks to metal63 for providing initial code for attention visualization (click a tag in the tag list to try it out!)
282
-
283
- Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
284
- """)
285
  original_image_state = gr.State() # stash a copy of the input image
286
  sorted_tag_score_state = gr.State(value={}) # stash a copy of the input image
287
  cam_state = gr.State()
288
  with gr.Row():
289
  with gr.Column():
290
- image_input = gr.Image(label="Source", sources=['upload', 'clipboard'], type='pil', height=512, show_label=False)
291
  threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Tag Threshold")
292
  cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.40, label="CAM Threshold", elem_classes="inferno-slider")
293
  alpha_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.60, label="CAM Alpha")
@@ -295,14 +294,24 @@ with gr.Blocks(css=custom_css) as demo:
295
  tag_string = gr.Textbox(label="Tag String")
296
  label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)
297
 
298
- image_input.upload(
 
 
 
 
 
 
 
 
 
 
299
  fn=run_classifier,
300
- inputs=[image_input, threshold_slider],
301
  outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state],
302
  show_progress='minimal'
303
  )
304
 
305
- image_input.clear(
306
  fn=clear_image,
307
  inputs=[],
308
  outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state, cam_state]
@@ -318,21 +327,21 @@ with gr.Blocks(css=custom_css) as demo:
318
  label_box.select(
319
  fn=cam_inference,
320
  inputs=[original_image_state, cam_slider, alpha_slider],
321
- outputs=[image_input, cam_state],
322
  show_progress='minimal'
323
  )
324
 
325
  cam_slider.input(
326
  fn=create_cam_visualization_pil,
327
  inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
328
- outputs=[image_input],
329
  show_progress='hidden'
330
  )
331
 
332
  alpha_slider.input(
333
  fn=create_cam_visualization_pil,
334
  inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
335
- outputs=[image_input],
336
  show_progress='hidden'
337
  )
338
 
 
269
  ) !important;
270
  background-size: 100% 100% !important;
271
  }
272
+ #image_container-image {
273
+ width: 100%;
274
+ aspect-ratio: 1 / 1;
275
+ max-height: 100%;
276
+ }
277
+ #image_container img {
278
+ object-fit: contain !important;
279
+ }
280
  """
281
 
282
  with gr.Blocks(css=custom_css) as demo:
283
+ gr.Markdown("## Joint Tagger Project: JTP-PILOT² Demo **BETA**")
 
 
 
 
 
 
 
 
 
284
  original_image_state = gr.State() # stash a copy of the input image
285
  sorted_tag_score_state = gr.State(value={}) # stash a copy of the input image
286
  cam_state = gr.State()
287
  with gr.Row():
288
  with gr.Column():
289
+ image = gr.Image(label="Source", sources=['upload', 'clipboard'], type='pil', show_label=False, elem_id="image_container")
290
  threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Tag Threshold")
291
  cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.40, label="CAM Threshold", elem_classes="inferno-slider")
292
  alpha_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.60, label="CAM Alpha")
 
294
  tag_string = gr.Textbox(label="Tag String")
295
  label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)
296
 
297
+ gr.Markdown("""
298
+ This tagger is designed for use on furry images (though may very well work on out-of-distribution images, potentially with funny results). A threshold of 0.2 is recommended. Lower thresholds often turn up more valid tags, but can also result in some amount of hallucinated tags.
299
+
300
+ This tagger is the result of joint efforts between members of the RedRocket team, with distinctions given to Thessalo for creating the foundation for this project with his efforts, RedHotTensors for redesigning the process into a second-order method that models information expectation, and drhead for dataset prep, creation of training code and supervision of training runs.
301
+
302
+ Thanks to metal63 for providing initial code for attention visualization (click a tag in the tag list to try it out!)
303
+
304
+ Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
305
+ """)
306
+
307
+ image.upload(
308
  fn=run_classifier,
309
+ inputs=[image, threshold_slider],
310
  outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state],
311
  show_progress='minimal'
312
  )
313
 
314
+ image.clear(
315
  fn=clear_image,
316
  inputs=[],
317
  outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state, cam_state]
 
327
  label_box.select(
328
  fn=cam_inference,
329
  inputs=[original_image_state, cam_slider, alpha_slider],
330
+ outputs=[image, cam_state],
331
  show_progress='minimal'
332
  )
333
 
334
  cam_slider.input(
335
  fn=create_cam_visualization_pil,
336
  inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
337
+ outputs=[image],
338
  show_progress='hidden'
339
  )
340
 
341
  alpha_slider.input(
342
  fn=create_cam_visualization_pil,
343
  inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
344
+ outputs=[image],
345
  show_progress='hidden'
346
  )
347