CyberWaifu commited on
Commit
6e46e05
·
verified ·
1 Parent(s): 8da1280

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  from PIL import Image
5
  import json
6
  from huggingface_hub import hf_hub_download
 
7
 
8
  # --- Constants ---
9
  MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
@@ -72,7 +73,7 @@ def format_prompt_style_output(results_by_cat: dict, all_artist_tags_probs: list
72
 
73
  artist_prompt_tags = [tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") for tag, prob in artist_tags_with_probs]
74
  character_prompt_tags = [tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") for tag, prob in character_tags_with_probs]
75
- general_prompt_tags = [tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") for tag, prob in general_tags_with_probs]
76
 
77
  prompt_tags = artist_prompt_tags + character_prompt_tags + general_prompt_tags
78
 
@@ -142,11 +143,13 @@ def create_gradio_interface(session: ort.InferenceSession, metadata: dict) -> gr
142
  with gr.Column():
143
  output_box = gr.Markdown("")
144
 
 
 
 
145
  tag_button.click(
146
- fn=tag_image,
147
  inputs=[image_in, format_choice],
148
  outputs=output_box,
149
- extra_args=[session, metadata] # Pass session and metadata as extra arguments
150
  )
151
 
152
  gr.Markdown("----\n**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime)   •   **Base Model:** Camais03/camie-tagger (61% F1 on 70k tags)   •   **ONNX Runtime:** for efficient CPU inference   •   *Demo built with Gradio Blocks.*")
 
4
  from PIL import Image
5
  import json
6
  from huggingface_hub import hf_hub_download
7
+ from functools import partial # Import partial
8
 
9
  # --- Constants ---
10
  MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
 
73
 
74
  artist_prompt_tags = [tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") for tag, prob in artist_tags_with_probs]
75
  character_prompt_tags = [tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") for tag, prob in character_tags_with_probs]
76
+ general_prompt_tags = [tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") for tag, prob in general_tags_probs]
77
 
78
  prompt_tags = artist_prompt_tags + character_prompt_tags + general_prompt_tags
79
 
 
143
  with gr.Column():
144
  output_box = gr.Markdown("")
145
 
146
+ # Create a partial function with session and metadata pre-filled
147
+ tag_image_with_model = partial(tag_image, session=session, metadata=metadata)
148
+
149
  tag_button.click(
150
+ fn=tag_image_with_model, # Use the partially applied function
151
  inputs=[image_in, format_choice],
152
  outputs=output_box,
 
153
  )
154
 
155
  gr.Markdown("----\n**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime)   •   **Base Model:** Camais03/camie-tagger (61% F1 on 70k tags)   •   **ONNX Runtime:** for efficient CPU inference   •   *Demo built with Gradio Blocks.*")