Upload app.py
Browse files
app.py
CHANGED
@@ -268,10 +268,12 @@ def visualize_predictions(image: Image.Image, predictions: Dict, threshold: floa
|
|
268 |
|
269 |
# --- Constants ---
|
270 |
REPO_ID = "cella110n/cl_tagger"
|
271 |
-
#
|
272 |
-
|
273 |
-
|
274 |
-
|
|
|
|
|
275 |
CACHE_DIR = "./model_cache"
|
276 |
|
277 |
# --- Global variables for paths (initialized at startup) ---
|
@@ -280,35 +282,76 @@ g_tag_mapping_path = None
|
|
280 |
g_labels_data = None
|
281 |
g_idx_to_tag = None
|
282 |
g_tag_to_category = None
|
|
|
283 |
|
284 |
# --- Initialization Function ---
|
285 |
-
def initialize_onnx_paths():
|
286 |
-
global g_onnx_model_path, g_tag_mapping_path, g_labels_data, g_idx_to_tag, g_tag_to_category
|
287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
hf_token = os.environ.get("HF_TOKEN")
|
289 |
try:
|
290 |
-
print(f"Attempting to download ONNX model: {
|
291 |
-
g_onnx_model_path = hf_hub_download(repo_id=REPO_ID, filename=
|
292 |
print(f"ONNX model path: {g_onnx_model_path}")
|
293 |
|
294 |
-
print(f"Attempting to download Tag mapping: {
|
295 |
-
g_tag_mapping_path = hf_hub_download(repo_id=REPO_ID, filename=
|
296 |
print(f"Tag mapping path: {g_tag_mapping_path}")
|
297 |
|
298 |
print("Loading labels from mapping...")
|
299 |
g_labels_data, g_idx_to_tag, g_tag_to_category = load_tag_mapping(g_tag_mapping_path)
|
300 |
print(f"Labels loaded. Count: {len(g_labels_data.names)}")
|
|
|
|
|
301 |
|
302 |
except Exception as e:
|
303 |
print(f"Error during initialization: {e}")
|
304 |
import traceback; traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
305 |
# Raise Gradio error to make it visible in the UI
|
306 |
raise gr.Error(f"Initialization failed: {e}. Check logs and HF_TOKEN.")
|
307 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
308 |
# --- Main Prediction Function (ONNX) ---
|
309 |
@spaces.GPU()
|
310 |
-
def predict_onnx(image_input, gen_threshold, char_threshold, output_mode):
|
311 |
-
print("--- predict_onnx function started (GPU worker) ---")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
312 |
# --- 1. Ensure paths and labels are loaded ---
|
313 |
if g_onnx_model_path is None or g_labels_data is None:
|
314 |
message = "Error: Paths or labels not initialized. Check startup logs."
|
@@ -450,11 +493,16 @@ footer { display: none !important; }
|
|
450 |
with gr.Blocks(css=css) as demo:
|
451 |
gr.Markdown("# CL EVA02 ONNX Tagger")
|
452 |
gr.Markdown("Upload an image or paste an image URL to predict tags using the CL EVA02 Tagger model (ONNX), fine-tuned from [SmilingWolf/wd-eva02-large-tagger-v3](https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3).")
|
|
|
453 |
with gr.Row():
|
454 |
with gr.Column(scale=1):
|
455 |
image_input = gr.Image(type="pil", label="Input Image", elem_id="input-image")
|
456 |
-
|
457 |
-
|
|
|
|
|
|
|
|
|
458 |
gen_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.55, label="General/Meta Tag Threshold")
|
459 |
char_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.60, label="Character/Copyright/Artist Tag Threshold")
|
460 |
output_mode = gr.Radio(choices=["Tags Only", "Tags + Visualization"], value="Tags + Visualization", label="Output Mode")
|
@@ -462,28 +510,37 @@ with gr.Blocks(css=css) as demo:
|
|
462 |
with gr.Column(scale=1):
|
463 |
output_tags = gr.Textbox(label="Predicted Tags", lines=10, interactive=False)
|
464 |
output_visualization = gr.Image(type="pil", label="Prediction Visualization", interactive=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
465 |
gr.Examples(
|
466 |
examples=[
|
467 |
-
["https://pbs.twimg.com/media/GXBXsRvbQAAg1kp.jpg", 0.55, 0.70, "Tags + Visualization"],
|
468 |
-
["https://pbs.twimg.com/media/GjlX0gibcAA4EJ4.jpg", 0.55, 0.70, "Tags Only"],
|
469 |
-
["https://pbs.twimg.com/media/Gj4nQbjbEAATeoH.jpg", 0.55, 0.70, "Tags + Visualization"],
|
470 |
-
["https://pbs.twimg.com/media/GkbtX0GaoAMlUZt.jpg", 0.55, 0.70, "Tags + Visualization"]
|
471 |
],
|
472 |
-
inputs=[image_input, gen_threshold, char_threshold, output_mode],
|
473 |
outputs=[output_tags, output_visualization],
|
474 |
fn=predict_onnx, # Use the ONNX prediction function
|
475 |
cache_examples=False # Disable caching for examples during testing
|
476 |
)
|
477 |
predict_button.click(
|
478 |
fn=predict_onnx, # Use the ONNX prediction function
|
479 |
-
inputs=[image_input, gen_threshold, char_threshold, output_mode],
|
480 |
outputs=[output_tags, output_visualization]
|
481 |
)
|
482 |
|
483 |
# --- Main Block ---
|
484 |
if __name__ == "__main__":
|
485 |
if not os.environ.get("HF_TOKEN"): print("Warning: HF_TOKEN environment variable not set.")
|
486 |
-
# Initialize paths and labels at startup
|
487 |
-
initialize_onnx_paths()
|
488 |
# Launch Gradio app
|
489 |
demo.launch(share=True)
|
|
|
268 |
|
269 |
# --- Constants ---
|
270 |
REPO_ID = "cella110n/cl_tagger"
|
271 |
+
# Model options
|
272 |
+
MODEL_OPTIONS = {
|
273 |
+
"cl_eva02_tagger_v1_250426": "cl_eva02_tagger_v1_250426/model.onnx",
|
274 |
+
"cl_eva02_tagger_v1_250427": "cl_eva02_tagger_v1_250427/model.onnx"
|
275 |
+
}
|
276 |
+
DEFAULT_MODEL = "cl_eva02_tagger_v1_250426"
|
277 |
CACHE_DIR = "./model_cache"
|
278 |
|
279 |
# --- Global variables for paths (initialized at startup) ---
|
|
|
282 |
g_labels_data = None
|
283 |
g_idx_to_tag = None
|
284 |
g_tag_to_category = None
|
285 |
+
g_current_model = None
|
286 |
|
287 |
# --- Initialization Function ---
|
288 |
+
def initialize_onnx_paths(model_choice=DEFAULT_MODEL):
|
289 |
+
global g_onnx_model_path, g_tag_mapping_path, g_labels_data, g_idx_to_tag, g_tag_to_category, g_current_model
|
290 |
+
|
291 |
+
if not model_choice in MODEL_OPTIONS:
|
292 |
+
print(f"Invalid model choice: {model_choice}, falling back to default: {DEFAULT_MODEL}")
|
293 |
+
model_choice = DEFAULT_MODEL
|
294 |
+
|
295 |
+
g_current_model = model_choice
|
296 |
+
model_dir = model_choice
|
297 |
+
onnx_filename = MODEL_OPTIONS[model_choice]
|
298 |
+
tag_mapping_filename = f"{model_dir}/tag_mapping.json"
|
299 |
+
|
300 |
+
print(f"Initializing ONNX paths and labels for model: {model_choice}...")
|
301 |
hf_token = os.environ.get("HF_TOKEN")
|
302 |
try:
|
303 |
+
print(f"Attempting to download ONNX model: {onnx_filename}")
|
304 |
+
g_onnx_model_path = hf_hub_download(repo_id=REPO_ID, filename=onnx_filename, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
|
305 |
print(f"ONNX model path: {g_onnx_model_path}")
|
306 |
|
307 |
+
print(f"Attempting to download Tag mapping: {tag_mapping_filename}")
|
308 |
+
g_tag_mapping_path = hf_hub_download(repo_id=REPO_ID, filename=tag_mapping_filename, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
|
309 |
print(f"Tag mapping path: {g_tag_mapping_path}")
|
310 |
|
311 |
print("Loading labels from mapping...")
|
312 |
g_labels_data, g_idx_to_tag, g_tag_to_category = load_tag_mapping(g_tag_mapping_path)
|
313 |
print(f"Labels loaded. Count: {len(g_labels_data.names)}")
|
314 |
+
|
315 |
+
return True
|
316 |
|
317 |
except Exception as e:
|
318 |
print(f"Error during initialization: {e}")
|
319 |
import traceback; traceback.print_exc()
|
320 |
+
# Reset globals to force reinitialization
|
321 |
+
g_onnx_model_path = None
|
322 |
+
g_tag_mapping_path = None
|
323 |
+
g_labels_data = None
|
324 |
+
g_idx_to_tag = None
|
325 |
+
g_tag_to_category = None
|
326 |
+
g_current_model = None
|
327 |
# Raise Gradio error to make it visible in the UI
|
328 |
raise gr.Error(f"Initialization failed: {e}. Check logs and HF_TOKEN.")
|
329 |
|
330 |
+
# Function to handle model change
|
331 |
+
def change_model(model_choice):
|
332 |
+
try:
|
333 |
+
success = initialize_onnx_paths(model_choice)
|
334 |
+
if success:
|
335 |
+
return f"Model changed to: {model_choice}"
|
336 |
+
else:
|
337 |
+
return "Failed to change model. See logs for details."
|
338 |
+
except Exception as e:
|
339 |
+
return f"Error changing model: {str(e)}"
|
340 |
+
|
341 |
# --- Main Prediction Function (ONNX) ---
|
342 |
@spaces.GPU()
|
343 |
+
def predict_onnx(image_input, model_choice, gen_threshold, char_threshold, output_mode):
|
344 |
+
print(f"--- predict_onnx function started (GPU worker) with model {model_choice} ---")
|
345 |
+
|
346 |
+
# Ensure current model matches selected model
|
347 |
+
global g_current_model
|
348 |
+
if g_current_model != model_choice:
|
349 |
+
print(f"Model mismatch! Current: {g_current_model}, Selected: {model_choice}. Reinitializing...")
|
350 |
+
try:
|
351 |
+
initialize_onnx_paths(model_choice)
|
352 |
+
except Exception as e:
|
353 |
+
return f"Error initializing model '{model_choice}': {str(e)}", None
|
354 |
+
|
355 |
# --- 1. Ensure paths and labels are loaded ---
|
356 |
if g_onnx_model_path is None or g_labels_data is None:
|
357 |
message = "Error: Paths or labels not initialized. Check startup logs."
|
|
|
493 |
with gr.Blocks(css=css) as demo:
|
494 |
gr.Markdown("# CL EVA02 ONNX Tagger")
|
495 |
gr.Markdown("Upload an image or paste an image URL to predict tags using the CL EVA02 Tagger model (ONNX), fine-tuned from [SmilingWolf/wd-eva02-large-tagger-v3](https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3).")
|
496 |
+
|
497 |
with gr.Row():
|
498 |
with gr.Column(scale=1):
|
499 |
image_input = gr.Image(type="pil", label="Input Image", elem_id="input-image")
|
500 |
+
model_choice = gr.Dropdown(
|
501 |
+
choices=list(MODEL_OPTIONS.keys()),
|
502 |
+
value=DEFAULT_MODEL,
|
503 |
+
label="Model Version",
|
504 |
+
interactive=True
|
505 |
+
)
|
506 |
gen_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.55, label="General/Meta Tag Threshold")
|
507 |
char_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.60, label="Character/Copyright/Artist Tag Threshold")
|
508 |
output_mode = gr.Radio(choices=["Tags Only", "Tags + Visualization"], value="Tags + Visualization", label="Output Mode")
|
|
|
510 |
with gr.Column(scale=1):
|
511 |
output_tags = gr.Textbox(label="Predicted Tags", lines=10, interactive=False)
|
512 |
output_visualization = gr.Image(type="pil", label="Prediction Visualization", interactive=False)
|
513 |
+
|
514 |
+
# Handle model change
|
515 |
+
model_status = gr.Textbox(label="Model Status", interactive=False, visible=False)
|
516 |
+
model_choice.change(
|
517 |
+
fn=change_model,
|
518 |
+
inputs=[model_choice],
|
519 |
+
outputs=[model_status]
|
520 |
+
)
|
521 |
+
|
522 |
gr.Examples(
|
523 |
examples=[
|
524 |
+
["https://pbs.twimg.com/media/GXBXsRvbQAAg1kp.jpg", DEFAULT_MODEL, 0.55, 0.70, "Tags + Visualization"],
|
525 |
+
["https://pbs.twimg.com/media/GjlX0gibcAA4EJ4.jpg", DEFAULT_MODEL, 0.55, 0.70, "Tags Only"],
|
526 |
+
["https://pbs.twimg.com/media/Gj4nQbjbEAATeoH.jpg", DEFAULT_MODEL, 0.55, 0.70, "Tags + Visualization"],
|
527 |
+
["https://pbs.twimg.com/media/GkbtX0GaoAMlUZt.jpg", DEFAULT_MODEL, 0.55, 0.70, "Tags + Visualization"]
|
528 |
],
|
529 |
+
inputs=[image_input, model_choice, gen_threshold, char_threshold, output_mode],
|
530 |
outputs=[output_tags, output_visualization],
|
531 |
fn=predict_onnx, # Use the ONNX prediction function
|
532 |
cache_examples=False # Disable caching for examples during testing
|
533 |
)
|
534 |
predict_button.click(
|
535 |
fn=predict_onnx, # Use the ONNX prediction function
|
536 |
+
inputs=[image_input, model_choice, gen_threshold, char_threshold, output_mode],
|
537 |
outputs=[output_tags, output_visualization]
|
538 |
)
|
539 |
|
540 |
# --- Main Block ---
|
541 |
if __name__ == "__main__":
|
542 |
if not os.environ.get("HF_TOKEN"): print("Warning: HF_TOKEN environment variable not set.")
|
543 |
+
# Initialize paths and labels at startup (with default model)
|
544 |
+
initialize_onnx_paths(DEFAULT_MODEL)
|
545 |
# Launch Gradio app
|
546 |
demo.launch(share=True)
|