CyberWaifu commited on
Commit
11cfce1
·
verified ·
1 Parent(s): dcd003b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -23
app.py CHANGED
@@ -10,7 +10,7 @@ MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
10
  MODEL_FILE = "camie_tagger_initial.onnx"
11
  META_FILE = "metadata.json"
12
  IMAGE_SIZE = (512, 512)
13
- DEFAULT_THRESHOLD = 0.35
14
 
15
  # Download model and metadata from Hugging Face Hub
16
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
@@ -44,15 +44,13 @@ def run_inference(pil_image: Image.Image) -> np.ndarray:
44
  _, refined_logits = session.run(None, {input_name: input_tensor})
45
  return refined_logits[0]
46
 
47
- def get_tags(refined_logits: np.ndarray, metadata: dict, custom_threshold: float = None):
48
  """
49
  Compute probabilities from logits and collect tag predictions.
50
 
51
- If custom_threshold is provided, it overrides category-specific thresholds.
52
-
53
  Returns:
54
  results_by_cat: Dictionary mapping each category to a list of (tag, probability) above its threshold.
55
- prompt_tags_by_cat: Dictionary for prompt-style output with keys: artist, character, general.
56
  all_artist_tags: All artist tags (with probabilities) regardless of threshold.
57
  """
58
  probs = 1 / (1 + np.exp(-refined_logits))
@@ -67,8 +65,7 @@ def get_tags(refined_logits: np.ndarray, metadata: dict, custom_threshold: float
67
  for idx, prob in enumerate(probs):
68
  tag = idx_to_tag[str(idx)]
69
  cat = tag_to_category.get(tag, "unknown")
70
- # Use custom threshold if provided; otherwise, use metadata threshold or default.
71
- thresh = custom_threshold if custom_threshold is not None else category_thresholds.get(cat, DEFAULT_THRESHOLD)
72
  if cat == "artist":
73
  all_artist_tags.append((tag, float(prob)))
74
  if float(prob) >= thresh:
@@ -83,6 +80,7 @@ def format_prompt_tags(prompt_tags_by_cat: dict, all_artist_tags: list) -> str:
83
 
84
  Returns a comma-separated string of escaped tags.
85
  """
 
86
  for cat in prompt_tags_by_cat:
87
  prompt_tags_by_cat[cat].sort(key=lambda x: x[1], reverse=True)
88
 
@@ -91,7 +89,7 @@ def format_prompt_tags(prompt_tags_by_cat: dict, all_artist_tags: list) -> str:
91
  general_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("general", [])]
92
  prompt_tags = artist_tags + character_tags + general_tags
93
 
94
- # Ensure at least one artist tag appears even if none pass the threshold
95
  if not artist_tags and all_artist_tags:
96
  best_artist_tag, _ = max(all_artist_tags, key=lambda item: item[1])
97
  prompt_tags.insert(0, escape_tag(best_artist_tag))
@@ -117,20 +115,20 @@ def format_detailed_output(results_by_cat: dict, all_artist_tags: list) -> str:
117
  lines.append(f"**Category: {cat}** – {len(tag_list)} tags")
118
  for tag, prob in tag_list:
119
  lines.append(f"- {escape_tag(tag)} (Prob: {prob:.3f})")
120
- lines.append("")
121
  return "\n".join(lines)
122
 
123
  def tag_image(pil_image: Image.Image, output_format: str, threshold: float) -> str:
124
  """
125
  Run inference on the image and return formatted tags based on the chosen output format.
126
 
127
- The threshold slider value overrides category-specific thresholds if provided.
128
  """
129
  if pil_image is None:
130
  return "Please upload an image."
131
 
132
  refined_logits = run_inference(pil_image)
133
- results_by_cat, prompt_tags_by_cat, all_artist_tags = get_tags(refined_logits, metadata, custom_threshold=threshold)
134
 
135
  if output_format == "Prompt-style Tags":
136
  return format_prompt_tags(prompt_tags_by_cat, all_artist_tags)
@@ -144,10 +142,10 @@ with demo:
144
  gr.Markdown(
145
  "# 🏷️ Camie Tagger – Anime Image Tagging\n"
146
  "This demo uses an ONNX model of Camie Tagger to label anime illustrations with tags. "
147
- "Upload an image and click **Tag Image** to see predictions."
148
  )
149
  gr.Markdown(
150
- "*(Note: The model will predict a large number of tags across categories like character, general, artist, etc. "
151
  "You can choose a concise prompt-style output or a detailed category-wise breakdown.)*"
152
  )
153
  with gr.Row():
@@ -158,23 +156,20 @@ with demo:
158
  value="Prompt-style Tags",
159
  label="Output Format"
160
  )
161
- # Slider to modify the global threshold value
162
  threshold_slider = gr.Slider(
163
- minimum=0,
164
- maximum=1,
165
  step=0.05,
166
  value=DEFAULT_THRESHOLD,
167
- label="Global Threshold"
168
  )
169
  tag_button = gr.Button("🔍 Tag Image")
170
  with gr.Column():
171
- output_box = gr.Markdown("")
172
 
173
- tag_button.click(
174
- fn=tag_image,
175
- inputs=[image_in, format_choice, threshold_slider],
176
- outputs=output_box
177
- )
178
 
179
  gr.Markdown(
180
  "----\n"
 
10
  MODEL_FILE = "camie_tagger_initial.onnx"
11
  META_FILE = "metadata.json"
12
  IMAGE_SIZE = (512, 512)
13
+ DEFAULT_THRESHOLD = 0.35 # Default value if slider is not used
14
 
15
  # Download model and metadata from Hugging Face Hub
16
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
 
44
  _, refined_logits = session.run(None, {input_name: input_tensor})
45
  return refined_logits[0]
46
 
47
+ def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: float):
48
  """
49
  Compute probabilities from logits and collect tag predictions.
50
 
 
 
51
  Returns:
52
  results_by_cat: Dictionary mapping each category to a list of (tag, probability) above its threshold.
53
+ prompt_tags_by_cat: Dictionary for prompt-style output (artist, character, general).
54
  all_artist_tags: All artist tags (with probabilities) regardless of threshold.
55
  """
56
  probs = 1 / (1 + np.exp(-refined_logits))
 
65
  for idx, prob in enumerate(probs):
66
  tag = idx_to_tag[str(idx)]
67
  cat = tag_to_category.get(tag, "unknown")
68
+ thresh = category_thresholds.get(cat, default_threshold)
 
69
  if cat == "artist":
70
  all_artist_tags.append((tag, float(prob)))
71
  if float(prob) >= thresh:
 
80
 
81
  Returns a comma-separated string of escaped tags.
82
  """
83
+ # Sort tags within each category by probability (descending)
84
  for cat in prompt_tags_by_cat:
85
  prompt_tags_by_cat[cat].sort(key=lambda x: x[1], reverse=True)
86
 
 
89
  general_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("general", [])]
90
  prompt_tags = artist_tags + character_tags + general_tags
91
 
92
+ # Ensure at least one artist tag appears if available, even if below threshold
93
  if not artist_tags and all_artist_tags:
94
  best_artist_tag, _ = max(all_artist_tags, key=lambda item: item[1])
95
  prompt_tags.insert(0, escape_tag(best_artist_tag))
 
115
  lines.append(f"**Category: {cat}** – {len(tag_list)} tags")
116
  for tag, prob in tag_list:
117
  lines.append(f"- {escape_tag(tag)} (Prob: {prob:.3f})")
118
+ lines.append("") # blank line between categories
119
  return "\n".join(lines)
120
 
121
  def tag_image(pil_image: Image.Image, output_format: str, threshold: float) -> str:
122
  """
123
  Run inference on the image and return formatted tags based on the chosen output format.
124
 
125
+ The slider value (threshold) overrides the default threshold for tag selection.
126
  """
127
  if pil_image is None:
128
  return "Please upload an image."
129
 
130
  refined_logits = run_inference(pil_image)
131
+ results_by_cat, prompt_tags_by_cat, all_artist_tags = get_tags(refined_logits, metadata, default_threshold=threshold)
132
 
133
  if output_format == "Prompt-style Tags":
134
  return format_prompt_tags(prompt_tags_by_cat, all_artist_tags)
 
142
  gr.Markdown(
143
  "# 🏷️ Camie Tagger – Anime Image Tagging\n"
144
  "This demo uses an ONNX model of Camie Tagger to label anime illustrations with tags. "
145
+ "Upload an image, adjust the threshold, and click **Tag Image** to see predictions."
146
  )
147
  gr.Markdown(
148
+ "*(Note: The model predicts a large number of tags across categories like character, general, artist, etc. "
149
  "You can choose a concise prompt-style output or a detailed category-wise breakdown.)*"
150
  )
151
  with gr.Row():
 
156
  value="Prompt-style Tags",
157
  label="Output Format"
158
  )
159
+ # Slider to modify the default threshold value used in inference.
160
  threshold_slider = gr.Slider(
161
+ minimum=0.0,
162
+ maximum=1.0,
163
  step=0.05,
164
  value=DEFAULT_THRESHOLD,
165
+ label="Threshold"
166
  )
167
  tag_button = gr.Button("🔍 Tag Image")
168
  with gr.Column():
169
+ output_box = gr.Markdown("") # Markdown output for formatted results
170
 
171
+ # Pass the threshold_slider value into the tag_image function
172
+ tag_button.click(fn=tag_image, inputs=[image_in, format_choice, threshold_slider], outputs=output_box)
 
 
 
173
 
174
  gr.Markdown(
175
  "----\n"