CyberWaifu commited on
Commit
dcd003b
·
verified ·
1 Parent(s): b317da6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -11
app.py CHANGED
@@ -44,13 +44,15 @@ def run_inference(pil_image: Image.Image) -> np.ndarray:
44
  _, refined_logits = session.run(None, {input_name: input_tensor})
45
  return refined_logits[0]
46
 
47
- def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: float = DEFAULT_THRESHOLD):
48
  """
49
  Compute probabilities from logits and collect tag predictions.
50
 
 
 
51
  Returns:
52
  results_by_cat: Dictionary mapping each category to a list of (tag, probability) above its threshold.
53
- prompt_tags_by_cat: Similar dictionary but only for prompt-style categories (artist, character, general).
54
  all_artist_tags: All artist tags (with probabilities) regardless of threshold.
55
  """
56
  probs = 1 / (1 + np.exp(-refined_logits))
@@ -65,7 +67,8 @@ def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: floa
65
  for idx, prob in enumerate(probs):
66
  tag = idx_to_tag[str(idx)]
67
  cat = tag_to_category.get(tag, "unknown")
68
- thresh = category_thresholds.get(cat, default_threshold)
 
69
  if cat == "artist":
70
  all_artist_tags.append((tag, float(prob)))
71
  if float(prob) >= thresh:
@@ -80,7 +83,6 @@ def format_prompt_tags(prompt_tags_by_cat: dict, all_artist_tags: list) -> str:
80
 
81
  Returns a comma-separated string of escaped tags.
82
  """
83
- # Sort tags within each category by probability (descending)
84
  for cat in prompt_tags_by_cat:
85
  prompt_tags_by_cat[cat].sort(key=lambda x: x[1], reverse=True)
86
 
@@ -89,7 +91,7 @@ def format_prompt_tags(prompt_tags_by_cat: dict, all_artist_tags: list) -> str:
89
  general_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("general", [])]
90
  prompt_tags = artist_tags + character_tags + general_tags
91
 
92
- # Ensure at least one artist tag appears if available, even if below threshold
93
  if not artist_tags and all_artist_tags:
94
  best_artist_tag, _ = max(all_artist_tags, key=lambda item: item[1])
95
  prompt_tags.insert(0, escape_tag(best_artist_tag))
@@ -115,16 +117,20 @@ def format_detailed_output(results_by_cat: dict, all_artist_tags: list) -> str:
115
  lines.append(f"**Category: {cat}** – {len(tag_list)} tags")
116
  for tag, prob in tag_list:
117
  lines.append(f"- {escape_tag(tag)} (Prob: {prob:.3f})")
118
- lines.append("") # blank line between categories
119
  return "\n".join(lines)
120
 
121
- def tag_image(pil_image: Image.Image, output_format: str) -> str:
122
- """Run inference on the image and return formatted tags based on the chosen output format."""
 
 
 
 
123
  if pil_image is None:
124
  return "Please upload an image."
125
 
126
  refined_logits = run_inference(pil_image)
127
- results_by_cat, prompt_tags_by_cat, all_artist_tags = get_tags(refined_logits, metadata)
128
 
129
  if output_format == "Prompt-style Tags":
130
  return format_prompt_tags(prompt_tags_by_cat, all_artist_tags)
@@ -152,11 +158,23 @@ with demo:
152
  value="Prompt-style Tags",
153
  label="Output Format"
154
  )
 
 
 
 
 
 
 
 
155
  tag_button = gr.Button("🔍 Tag Image")
156
  with gr.Column():
157
- output_box = gr.Markdown("") # Markdown output for formatted results
158
 
159
- tag_button.click(fn=tag_image, inputs=[image_in, format_choice], outputs=output_box)
 
 
 
 
160
 
161
  gr.Markdown(
162
  "----\n"
 
44
  _, refined_logits = session.run(None, {input_name: input_tensor})
45
  return refined_logits[0]
46
 
47
+ def get_tags(refined_logits: np.ndarray, metadata: dict, custom_threshold: float = None):
48
  """
49
  Compute probabilities from logits and collect tag predictions.
50
 
51
+ If custom_threshold is provided, it overrides category-specific thresholds.
52
+
53
  Returns:
54
  results_by_cat: Dictionary mapping each category to a list of (tag, probability) above its threshold.
55
+ prompt_tags_by_cat: Dictionary for prompt-style output with keys: artist, character, general.
56
  all_artist_tags: All artist tags (with probabilities) regardless of threshold.
57
  """
58
  probs = 1 / (1 + np.exp(-refined_logits))
 
67
  for idx, prob in enumerate(probs):
68
  tag = idx_to_tag[str(idx)]
69
  cat = tag_to_category.get(tag, "unknown")
70
+ # Use custom threshold if provided; otherwise, use metadata threshold or default.
71
+ thresh = custom_threshold if custom_threshold is not None else category_thresholds.get(cat, DEFAULT_THRESHOLD)
72
  if cat == "artist":
73
  all_artist_tags.append((tag, float(prob)))
74
  if float(prob) >= thresh:
 
83
 
84
  Returns a comma-separated string of escaped tags.
85
  """
 
86
  for cat in prompt_tags_by_cat:
87
  prompt_tags_by_cat[cat].sort(key=lambda x: x[1], reverse=True)
88
 
 
91
  general_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("general", [])]
92
  prompt_tags = artist_tags + character_tags + general_tags
93
 
94
+ # Ensure at least one artist tag appears even if none pass the threshold
95
  if not artist_tags and all_artist_tags:
96
  best_artist_tag, _ = max(all_artist_tags, key=lambda item: item[1])
97
  prompt_tags.insert(0, escape_tag(best_artist_tag))
 
117
  lines.append(f"**Category: {cat}** – {len(tag_list)} tags")
118
  for tag, prob in tag_list:
119
  lines.append(f"- {escape_tag(tag)} (Prob: {prob:.3f})")
120
+ lines.append("")
121
  return "\n".join(lines)
122
 
123
+ def tag_image(pil_image: Image.Image, output_format: str, threshold: float) -> str:
124
+ """
125
+ Run inference on the image and return formatted tags based on the chosen output format.
126
+
127
+ The threshold slider value overrides category-specific thresholds if provided.
128
+ """
129
  if pil_image is None:
130
  return "Please upload an image."
131
 
132
  refined_logits = run_inference(pil_image)
133
+ results_by_cat, prompt_tags_by_cat, all_artist_tags = get_tags(refined_logits, metadata, custom_threshold=threshold)
134
 
135
  if output_format == "Prompt-style Tags":
136
  return format_prompt_tags(prompt_tags_by_cat, all_artist_tags)
 
158
  value="Prompt-style Tags",
159
  label="Output Format"
160
  )
161
+ # Slider to modify the global threshold value
162
+ threshold_slider = gr.Slider(
163
+ minimum=0,
164
+ maximum=1,
165
+ step=0.05,
166
+ value=DEFAULT_THRESHOLD,
167
+ label="Global Threshold"
168
+ )
169
  tag_button = gr.Button("🔍 Tag Image")
170
  with gr.Column():
171
+ output_box = gr.Markdown("")
172
 
173
+ tag_button.click(
174
+ fn=tag_image,
175
+ inputs=[image_in, format_choice, threshold_slider],
176
+ outputs=output_box
177
+ )
178
 
179
  gr.Markdown(
180
  "----\n"