cella110n commited on
Commit
601532d
·
verified ·
1 Parent(s): 43f3f8c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -23
app.py CHANGED
@@ -268,10 +268,12 @@ def visualize_predictions(image: Image.Image, predictions: Dict, threshold: floa
268
 
269
  # --- Constants ---
270
  REPO_ID = "cella110n/cl_tagger"
271
- # Use the specified ONNX model filename
272
- ONNX_FILENAME = "cl_eva02_tagger_v1_250426/model.onnx"
273
- # Correct the tag mapping path to match the ONNX model's directory
274
- TAG_MAPPING_FILENAME = "cl_eva02_tagger_v1_250426/tag_mapping.json"
 
 
275
  CACHE_DIR = "./model_cache"
276
 
277
  # --- Global variables for paths (initialized at startup) ---
@@ -280,35 +282,76 @@ g_tag_mapping_path = None
280
  g_labels_data = None
281
  g_idx_to_tag = None
282
  g_tag_to_category = None
 
283
 
284
  # --- Initialization Function ---
285
- def initialize_onnx_paths():
286
- global g_onnx_model_path, g_tag_mapping_path, g_labels_data, g_idx_to_tag, g_tag_to_category
287
- print("Initializing ONNX paths and labels...")
 
 
 
 
 
 
 
 
 
 
288
  hf_token = os.environ.get("HF_TOKEN")
289
  try:
290
- print(f"Attempting to download ONNX model: {ONNX_FILENAME}")
291
- g_onnx_model_path = hf_hub_download(repo_id=REPO_ID, filename=ONNX_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
292
  print(f"ONNX model path: {g_onnx_model_path}")
293
 
294
- print(f"Attempting to download Tag mapping: {TAG_MAPPING_FILENAME}")
295
- g_tag_mapping_path = hf_hub_download(repo_id=REPO_ID, filename=TAG_MAPPING_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
296
  print(f"Tag mapping path: {g_tag_mapping_path}")
297
 
298
  print("Loading labels from mapping...")
299
  g_labels_data, g_idx_to_tag, g_tag_to_category = load_tag_mapping(g_tag_mapping_path)
300
  print(f"Labels loaded. Count: {len(g_labels_data.names)}")
 
 
301
 
302
  except Exception as e:
303
  print(f"Error during initialization: {e}")
304
  import traceback; traceback.print_exc()
 
 
 
 
 
 
 
305
  # Raise Gradio error to make it visible in the UI
306
  raise gr.Error(f"Initialization failed: {e}. Check logs and HF_TOKEN.")
307
 
 
 
 
 
 
 
 
 
 
 
 
308
  # --- Main Prediction Function (ONNX) ---
309
  @spaces.GPU()
310
- def predict_onnx(image_input, gen_threshold, char_threshold, output_mode):
311
- print("--- predict_onnx function started (GPU worker) ---")
 
 
 
 
 
 
 
 
 
 
312
  # --- 1. Ensure paths and labels are loaded ---
313
  if g_onnx_model_path is None or g_labels_data is None:
314
  message = "Error: Paths or labels not initialized. Check startup logs."
@@ -450,11 +493,16 @@ footer { display: none !important; }
450
  with gr.Blocks(css=css) as demo:
451
  gr.Markdown("# CL EVA02 ONNX Tagger")
452
  gr.Markdown("Upload an image or paste an image URL to predict tags using the CL EVA02 Tagger model (ONNX), fine-tuned from [SmilingWolf/wd-eva02-large-tagger-v3](https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3).")
 
453
  with gr.Row():
454
  with gr.Column(scale=1):
455
  image_input = gr.Image(type="pil", label="Input Image", elem_id="input-image")
456
- # Add back URL input capability if desired (needs JS or separate component)
457
- # gr.HTML("<div id='url-input-container'></div>")
 
 
 
 
458
  gen_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.55, label="General/Meta Tag Threshold")
459
  char_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.60, label="Character/Copyright/Artist Tag Threshold")
460
  output_mode = gr.Radio(choices=["Tags Only", "Tags + Visualization"], value="Tags + Visualization", label="Output Mode")
@@ -462,28 +510,37 @@ with gr.Blocks(css=css) as demo:
462
  with gr.Column(scale=1):
463
  output_tags = gr.Textbox(label="Predicted Tags", lines=10, interactive=False)
464
  output_visualization = gr.Image(type="pil", label="Prediction Visualization", interactive=False)
 
 
 
 
 
 
 
 
 
465
  gr.Examples(
466
  examples=[
467
- ["https://pbs.twimg.com/media/GXBXsRvbQAAg1kp.jpg", 0.55, 0.70, "Tags + Visualization"],
468
- ["https://pbs.twimg.com/media/GjlX0gibcAA4EJ4.jpg", 0.55, 0.70, "Tags Only"],
469
- ["https://pbs.twimg.com/media/Gj4nQbjbEAATeoH.jpg", 0.55, 0.70, "Tags + Visualization"],
470
- ["https://pbs.twimg.com/media/GkbtX0GaoAMlUZt.jpg", 0.55, 0.70, "Tags + Visualization"]
471
  ],
472
- inputs=[image_input, gen_threshold, char_threshold, output_mode],
473
  outputs=[output_tags, output_visualization],
474
  fn=predict_onnx, # Use the ONNX prediction function
475
  cache_examples=False # Disable caching for examples during testing
476
  )
477
  predict_button.click(
478
  fn=predict_onnx, # Use the ONNX prediction function
479
- inputs=[image_input, gen_threshold, char_threshold, output_mode],
480
  outputs=[output_tags, output_visualization]
481
  )
482
 
483
  # --- Main Block ---
484
  if __name__ == "__main__":
485
  if not os.environ.get("HF_TOKEN"): print("Warning: HF_TOKEN environment variable not set.")
486
- # Initialize paths and labels at startup
487
- initialize_onnx_paths()
488
  # Launch Gradio app
489
  demo.launch(share=True)
 
268
 
269
  # --- Constants ---
270
  REPO_ID = "cella110n/cl_tagger"
271
+ # Model options
272
+ MODEL_OPTIONS = {
273
+ "cl_eva02_tagger_v1_250426": "cl_eva02_tagger_v1_250426/model.onnx",
274
+ "cl_eva02_tagger_v1_250427": "cl_eva02_tagger_v1_250427/model.onnx"
275
+ }
276
+ DEFAULT_MODEL = "cl_eva02_tagger_v1_250426"
277
  CACHE_DIR = "./model_cache"
278
 
279
  # --- Global variables for paths (initialized at startup) ---
 
282
  g_labels_data = None
283
  g_idx_to_tag = None
284
  g_tag_to_category = None
285
+ g_current_model = None
286
 
287
  # --- Initialization Function ---
288
+ def initialize_onnx_paths(model_choice=DEFAULT_MODEL):
289
+ global g_onnx_model_path, g_tag_mapping_path, g_labels_data, g_idx_to_tag, g_tag_to_category, g_current_model
290
+
291
+ if not model_choice in MODEL_OPTIONS:
292
+ print(f"Invalid model choice: {model_choice}, falling back to default: {DEFAULT_MODEL}")
293
+ model_choice = DEFAULT_MODEL
294
+
295
+ g_current_model = model_choice
296
+ model_dir = model_choice
297
+ onnx_filename = MODEL_OPTIONS[model_choice]
298
+ tag_mapping_filename = f"{model_dir}/tag_mapping.json"
299
+
300
+ print(f"Initializing ONNX paths and labels for model: {model_choice}...")
301
  hf_token = os.environ.get("HF_TOKEN")
302
  try:
303
+ print(f"Attempting to download ONNX model: {onnx_filename}")
304
+ g_onnx_model_path = hf_hub_download(repo_id=REPO_ID, filename=onnx_filename, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
305
  print(f"ONNX model path: {g_onnx_model_path}")
306
 
307
+ print(f"Attempting to download Tag mapping: {tag_mapping_filename}")
308
+ g_tag_mapping_path = hf_hub_download(repo_id=REPO_ID, filename=tag_mapping_filename, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
309
  print(f"Tag mapping path: {g_tag_mapping_path}")
310
 
311
  print("Loading labels from mapping...")
312
  g_labels_data, g_idx_to_tag, g_tag_to_category = load_tag_mapping(g_tag_mapping_path)
313
  print(f"Labels loaded. Count: {len(g_labels_data.names)}")
314
+
315
+ return True
316
 
317
  except Exception as e:
318
  print(f"Error during initialization: {e}")
319
  import traceback; traceback.print_exc()
320
+ # Reset globals to force reinitialization
321
+ g_onnx_model_path = None
322
+ g_tag_mapping_path = None
323
+ g_labels_data = None
324
+ g_idx_to_tag = None
325
+ g_tag_to_category = None
326
+ g_current_model = None
327
  # Raise Gradio error to make it visible in the UI
328
  raise gr.Error(f"Initialization failed: {e}. Check logs and HF_TOKEN.")
329
 
330
+ # Function to handle model change
331
+ def change_model(model_choice):
332
+ try:
333
+ success = initialize_onnx_paths(model_choice)
334
+ if success:
335
+ return f"Model changed to: {model_choice}"
336
+ else:
337
+ return "Failed to change model. See logs for details."
338
+ except Exception as e:
339
+ return f"Error changing model: {str(e)}"
340
+
341
  # --- Main Prediction Function (ONNX) ---
342
  @spaces.GPU()
343
+ def predict_onnx(image_input, model_choice, gen_threshold, char_threshold, output_mode):
344
+ print(f"--- predict_onnx function started (GPU worker) with model {model_choice} ---")
345
+
346
+ # Ensure current model matches selected model
347
+ global g_current_model
348
+ if g_current_model != model_choice:
349
+ print(f"Model mismatch! Current: {g_current_model}, Selected: {model_choice}. Reinitializing...")
350
+ try:
351
+ initialize_onnx_paths(model_choice)
352
+ except Exception as e:
353
+ return f"Error initializing model '{model_choice}': {str(e)}", None
354
+
355
  # --- 1. Ensure paths and labels are loaded ---
356
  if g_onnx_model_path is None or g_labels_data is None:
357
  message = "Error: Paths or labels not initialized. Check startup logs."
 
493
  with gr.Blocks(css=css) as demo:
494
  gr.Markdown("# CL EVA02 ONNX Tagger")
495
  gr.Markdown("Upload an image or paste an image URL to predict tags using the CL EVA02 Tagger model (ONNX), fine-tuned from [SmilingWolf/wd-eva02-large-tagger-v3](https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3).")
496
+
497
  with gr.Row():
498
  with gr.Column(scale=1):
499
  image_input = gr.Image(type="pil", label="Input Image", elem_id="input-image")
500
+ model_choice = gr.Dropdown(
501
+ choices=list(MODEL_OPTIONS.keys()),
502
+ value=DEFAULT_MODEL,
503
+ label="Model Version",
504
+ interactive=True
505
+ )
506
  gen_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.55, label="General/Meta Tag Threshold")
507
  char_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.60, label="Character/Copyright/Artist Tag Threshold")
508
  output_mode = gr.Radio(choices=["Tags Only", "Tags + Visualization"], value="Tags + Visualization", label="Output Mode")
 
510
  with gr.Column(scale=1):
511
  output_tags = gr.Textbox(label="Predicted Tags", lines=10, interactive=False)
512
  output_visualization = gr.Image(type="pil", label="Prediction Visualization", interactive=False)
513
+
514
+ # Handle model change
515
+ model_status = gr.Textbox(label="Model Status", interactive=False, visible=False)
516
+ model_choice.change(
517
+ fn=change_model,
518
+ inputs=[model_choice],
519
+ outputs=[model_status]
520
+ )
521
+
522
  gr.Examples(
523
  examples=[
524
+ ["https://pbs.twimg.com/media/GXBXsRvbQAAg1kp.jpg", DEFAULT_MODEL, 0.55, 0.70, "Tags + Visualization"],
525
+ ["https://pbs.twimg.com/media/GjlX0gibcAA4EJ4.jpg", DEFAULT_MODEL, 0.55, 0.70, "Tags Only"],
526
+ ["https://pbs.twimg.com/media/Gj4nQbjbEAATeoH.jpg", DEFAULT_MODEL, 0.55, 0.70, "Tags + Visualization"],
527
+ ["https://pbs.twimg.com/media/GkbtX0GaoAMlUZt.jpg", DEFAULT_MODEL, 0.55, 0.70, "Tags + Visualization"]
528
  ],
529
+ inputs=[image_input, model_choice, gen_threshold, char_threshold, output_mode],
530
  outputs=[output_tags, output_visualization],
531
  fn=predict_onnx, # Use the ONNX prediction function
532
  cache_examples=False # Disable caching for examples during testing
533
  )
534
  predict_button.click(
535
  fn=predict_onnx, # Use the ONNX prediction function
536
+ inputs=[image_input, model_choice, gen_threshold, char_threshold, output_mode],
537
  outputs=[output_tags, output_visualization]
538
  )
539
 
540
  # --- Main Block ---
541
  if __name__ == "__main__":
542
  if not os.environ.get("HF_TOKEN"): print("Warning: HF_TOKEN environment variable not set.")
543
+ # Initialize paths and labels at startup (with default model)
544
+ initialize_onnx_paths(DEFAULT_MODEL)
545
  # Launch Gradio app
546
  demo.launch(share=True)