CyberWaifu committed on
Commit
dda048a
·
verified ·
1 Parent(s): 912174d
Files changed (1) hide show
  1. app.py +45 -19
app.py CHANGED
@@ -10,7 +10,7 @@ MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
10
  MODEL_FILE = "camie_tagger_initial.onnx"
11
  META_FILE = "metadata.json"
12
  IMAGE_SIZE = (512, 512)
13
- DEFAULT_THRESHOLD = 0.35 # Default value if slider is not used
14
 
15
  # Download model and metadata from Hugging Face Hub
16
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
@@ -35,7 +35,6 @@ def preprocess_image(pil_image: Image.Image) -> np.ndarray:
35
  def run_inference(pil_image: Image.Image) -> np.ndarray:
36
  """
37
  Preprocess the image and run the ONNX model inference.
38
-
39
  Returns the refined logits as a numpy array.
40
  """
41
  input_tensor = preprocess_image(pil_image)
@@ -44,13 +43,26 @@ def run_inference(pil_image: Image.Image) -> np.ndarray:
44
  _, refined_logits = session.run(None, {input_name: input_tensor})
45
  return refined_logits[0]
46
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: float):
48
  """
49
  Compute probabilities from logits and collect tag predictions.
50
 
51
  Returns:
52
- results_by_cat: Dictionary mapping each category to a list of (tag, probability) above its threshold.
53
- prompt_tags_by_cat: Dictionary for prompt-style output (character, general).
 
54
  all_artist_tags: All artist tags (with probabilities) regardless of threshold.
55
  """
56
  probs = 1 / (1 + np.exp(-refined_logits))
@@ -59,7 +71,7 @@ def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: floa
59
  category_thresholds = metadata.get("category_thresholds", {})
60
 
61
  results_by_cat = {}
62
- # For prompt style, only include character and general tags (artists handled separately)
63
  prompt_tags_by_cat = {"character": [], "general": []}
64
  all_artist_tags = []
65
 
@@ -78,7 +90,8 @@ def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: floa
78
  def format_prompt_tags(prompt_tags_by_cat: dict, all_artist_tags: list) -> str:
79
  """
80
  Format the tags for prompt-style output.
81
- Only the top artist tag is shown (regardless of threshold), and all character and general tags are shown.
 
82
 
83
  Returns a comma-separated string of escaped tags.
84
  """
@@ -106,13 +119,12 @@ def format_prompt_tags(prompt_tags_by_cat: dict, all_artist_tags: list) -> str:
106
  def format_detailed_output(results_by_cat: dict, all_artist_tags: list) -> str:
107
  """
108
  Format the tags for detailed output.
109
-
110
  Returns a Markdown-formatted string listing tags by category.
111
  """
112
  if not results_by_cat:
113
  return "No tags predicted for this image."
114
 
115
- # Include an artist tag even if below threshold
116
  if "artist" not in results_by_cat and all_artist_tags:
117
  best_artist_tag, best_artist_prob = max(all_artist_tags, key=lambda item: item[1])
118
  results_by_cat["artist"] = [(best_artist_tag, best_artist_prob)]
@@ -126,17 +138,24 @@ def format_detailed_output(results_by_cat: dict, all_artist_tags: list) -> str:
126
  lines.append("") # blank line between categories
127
  return "\n".join(lines)
128
 
129
- def tag_image(pil_image: Image.Image, output_format: str, threshold: float) -> str:
130
  """
131
  Run inference on the image and return formatted tags based on the chosen output format.
132
-
133
- The slider value (threshold) overrides the default threshold for tag selection.
134
  """
135
  if pil_image is None:
136
  return "Please upload an image."
137
 
138
  refined_logits = run_inference(pil_image)
139
- results_by_cat, prompt_tags_by_cat, all_artist_tags = get_tags(refined_logits, metadata, default_threshold=threshold)
 
 
 
 
 
 
 
140
 
141
  if output_format == "Prompt-style Tags":
142
  return format_prompt_tags(prompt_tags_by_cat, all_artist_tags)
@@ -153,7 +172,8 @@ with demo:
153
  "Upload an image, adjust the threshold, and click **Tag Image** to see predictions."
154
  )
155
  gr.Markdown(
156
- "*(Note: In prompt-style output, only the top artist tag is displayed along with all character and general tags.)*"
 
157
  )
158
  with gr.Row():
159
  with gr.Column():
@@ -163,26 +183,32 @@ with demo:
163
  value="Prompt-style Tags",
164
  label="Output Format"
165
  )
166
- # Slider to modify the default threshold value used in inference.
167
  threshold_slider = gr.Slider(
168
  minimum=0.0,
169
  maximum=1.0,
170
  step=0.05,
171
  value=DEFAULT_THRESHOLD,
172
- label="Threshold"
 
 
 
 
173
  )
174
  tag_button = gr.Button("🔍 Tag Image")
175
  with gr.Column():
176
  output_box = gr.Markdown("") # Markdown output for formatted results
177
 
178
- # Pass the threshold_slider value into the tag_image function
179
- tag_button.click(fn=tag_image, inputs=[image_in, format_choice, threshold_slider], outputs=output_box)
 
 
 
 
180
 
181
  gr.Markdown(
182
  "----\n"
183
  "**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime) • "
184
- "**Base Model:** Camais03/camie-tagger (61% F1 on 70k tags) • "
185
- "**ONNX Runtime:** for efficient CPU inference • "
186
  "*Demo built with Gradio Blocks.*"
187
  )
188
 
 
10
  MODEL_FILE = "camie_tagger_initial.onnx"
11
  META_FILE = "metadata.json"
12
  IMAGE_SIZE = (512, 512)
13
+ DEFAULT_THRESHOLD = 0.35 # Initial value of the threshold slider (not used when MCut is enabled)
14
 
15
  # Download model and metadata from Hugging Face Hub
16
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
 
35
  def run_inference(pil_image: Image.Image) -> np.ndarray:
36
  """
37
  Preprocess the image and run the ONNX model inference.
 
38
  Returns the refined logits as a numpy array.
39
  """
40
  input_tensor = preprocess_image(pil_image)
 
43
  _, refined_logits = session.run(None, {input_name: input_tensor})
44
  return refined_logits[0]
45
 
46
def mcut_threshold(probs: np.ndarray) -> float:
    """
    Compute the MCut (maximum cut) threshold for a set of probabilities.

    The probabilities are sorted in descending order and the largest gap
    between two consecutive values is located; the threshold is the midpoint
    of that gap. Method described in:
        Largeron, C., Moulin, C., & Gery, M. (2012).
        "MCut: A Thresholding Strategy for Multi-label Classification."

    Args:
        probs: Array-like of probabilities. Any shape is accepted; the values
            are flattened before sorting.

    Returns:
        The threshold as a plain Python float. For a single-element input
        there is no gap to cut, so that single value is returned.

    Raises:
        ValueError: If ``probs`` is empty.
    """
    flat = np.asarray(probs).ravel()
    if flat.size == 0:
        # Previously surfaced as an opaque "argmax of an empty sequence" error.
        raise ValueError("mcut_threshold requires at least one probability")
    if flat.size == 1:
        # No consecutive pair exists; fall back to the only value available.
        return float(flat[0])

    sorted_probs = np.sort(flat)[::-1]  # descending order
    gaps = sorted_probs[:-1] - sorted_probs[1:]
    cut = int(gaps.argmax())  # first occurrence of the largest gap
    # Midpoint between the two values on either side of the largest gap.
    return float((sorted_probs[cut] + sorted_probs[cut + 1]) / 2)
57
+
58
  def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: float):
59
  """
60
  Compute probabilities from logits and collect tag predictions.
61
 
62
  Returns:
63
+ results_by_cat: Dictionary mapping each category to a list of (tag, probability)
64
+ above its threshold.
65
+ prompt_tags_by_cat: Dictionary for prompt-style output (character and general tags).
66
  all_artist_tags: All artist tags (with probabilities) regardless of threshold.
67
  """
68
  probs = 1 / (1 + np.exp(-refined_logits))
 
71
  category_thresholds = metadata.get("category_thresholds", {})
72
 
73
  results_by_cat = {}
74
+ # For prompt-style output, only include character and general tags (artists handled separately)
75
  prompt_tags_by_cat = {"character": [], "general": []}
76
  all_artist_tags = []
77
 
 
90
  def format_prompt_tags(prompt_tags_by_cat: dict, all_artist_tags: list) -> str:
91
  """
92
  Format the tags for prompt-style output.
93
+ Only the top artist tag is shown (regardless of threshold),
94
+ and all character and general tags are shown.
95
 
96
  Returns a comma-separated string of escaped tags.
97
  """
 
119
  def format_detailed_output(results_by_cat: dict, all_artist_tags: list) -> str:
120
  """
121
  Format the tags for detailed output.
 
122
  Returns a Markdown-formatted string listing tags by category.
123
  """
124
  if not results_by_cat:
125
  return "No tags predicted for this image."
126
 
127
+ # Include an artist tag even if below threshold.
128
  if "artist" not in results_by_cat and all_artist_tags:
129
  best_artist_tag, best_artist_prob = max(all_artist_tags, key=lambda item: item[1])
130
  results_by_cat["artist"] = [(best_artist_tag, best_artist_prob)]
 
138
  lines.append("") # blank line between categories
139
  return "\n".join(lines)
140
 
141
+ def tag_image(pil_image: Image.Image, output_format: str, threshold: float, mcut_enabled: bool) -> str:
142
  """
143
  Run inference on the image and return formatted tags based on the chosen output format.
144
+ The slider value (threshold) normally overrides the default threshold for tag selection.
145
+ If mcut_enabled is True, compute a new threshold using MCut from all probabilities.
146
  """
147
  if pil_image is None:
148
  return "Please upload an image."
149
 
150
  refined_logits = run_inference(pil_image)
151
+ # Compute probabilities from logits
152
+ probs = 1 / (1 + np.exp(-refined_logits))
153
+ # If MCut is enabled, override the threshold using the MCut method.
154
+ computed_threshold = mcut_threshold(probs) if mcut_enabled else threshold
155
+
156
+ results_by_cat, prompt_tags_by_cat, all_artist_tags = get_tags(
157
+ refined_logits, metadata, default_threshold=computed_threshold
158
+ )
159
 
160
  if output_format == "Prompt-style Tags":
161
  return format_prompt_tags(prompt_tags_by_cat, all_artist_tags)
 
172
  "Upload an image, adjust the threshold, and click **Tag Image** to see predictions."
173
  )
174
  gr.Markdown(
175
+ "*(Note: In prompt-style output, only the top artist tag is displayed along with all character and general tags. "
176
+ "If MCut is enabled, its computed threshold overrides the default slider value.)*"
177
  )
178
  with gr.Row():
179
  with gr.Column():
 
183
  value="Prompt-style Tags",
184
  label="Output Format"
185
  )
 
186
  threshold_slider = gr.Slider(
187
  minimum=0.0,
188
  maximum=1.0,
189
  step=0.05,
190
  value=DEFAULT_THRESHOLD,
191
+ label="Default Threshold"
192
+ )
193
+ mcut_checkbox = gr.Checkbox(
194
+ value=False,
195
+ label="Use MCut threshold"
196
  )
197
  tag_button = gr.Button("🔍 Tag Image")
198
  with gr.Column():
199
  output_box = gr.Markdown("") # Markdown output for formatted results
200
 
201
+ # Pass the threshold_slider and mcut_checkbox values into the tag_image function
202
+ tag_button.click(
203
+ fn=tag_image,
204
+ inputs=[image_in, format_choice, threshold_slider, mcut_checkbox],
205
+ outputs=output_box
206
+ )
207
 
208
  gr.Markdown(
209
  "----\n"
210
  "**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime) • "
211
+ "**Base Model:** Camais03/camie-tagger (61% F1 on 70k tags) • **ONNX Runtime:** for efficient CPU inference • "
 
212
  "*Demo built with Gradio Blocks.*"
213
  )
214