CyberWaifu commited on
Commit
5daec4f
·
verified ·
1 Parent(s): 11cfce1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -13
app.py CHANGED
@@ -50,7 +50,7 @@ def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: floa
50
 
51
  Returns:
52
  results_by_cat: Dictionary mapping each category to a list of (tag, probability) above its threshold.
53
- prompt_tags_by_cat: Dictionary for prompt-style output (artist, character, general).
54
  all_artist_tags: All artist tags (with probabilities) regardless of threshold.
55
  """
56
  probs = 1 / (1 + np.exp(-refined_logits))
@@ -59,7 +59,8 @@ def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: floa
59
  category_thresholds = metadata.get("category_thresholds", {})
60
 
61
  results_by_cat = {}
62
- prompt_tags_by_cat = {"artist": [], "character": [], "general": []}
 
63
  all_artist_tags = []
64
 
65
  for idx, prob in enumerate(probs):
@@ -77,22 +78,29 @@ def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: floa
77
  def format_prompt_tags(prompt_tags_by_cat: dict, all_artist_tags: list) -> str:
78
  """
79
  Format the tags for prompt-style output.
 
80
 
81
  Returns a comma-separated string of escaped tags.
82
  """
83
- # Sort tags within each category by probability (descending)
 
 
 
 
 
 
84
  for cat in prompt_tags_by_cat:
85
  prompt_tags_by_cat[cat].sort(key=lambda x: x[1], reverse=True)
86
 
87
- artist_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("artist", [])]
88
  character_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("character", [])]
89
  general_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("general", [])]
90
- prompt_tags = artist_tags + character_tags + general_tags
91
-
92
- # Ensure at least one artist tag appears if available, even if below threshold
93
- if not artist_tags and all_artist_tags:
94
- best_artist_tag, _ = max(all_artist_tags, key=lambda item: item[1])
95
- prompt_tags.insert(0, escape_tag(best_artist_tag))
 
96
  return ", ".join(prompt_tags) if prompt_tags else "No tags predicted."
97
 
98
  def format_detailed_output(results_by_cat: dict, all_artist_tags: list) -> str:
@@ -145,8 +153,7 @@ with demo:
145
  "Upload an image, adjust the threshold, and click **Tag Image** to see predictions."
146
  )
147
  gr.Markdown(
148
- "*(Note: The model predicts a large number of tags across categories like character, general, artist, etc. "
149
- "You can choose a concise prompt-style output or a detailed category-wise breakdown.)*"
150
  )
151
  with gr.Row():
152
  with gr.Column():
@@ -162,7 +169,7 @@ with demo:
162
  maximum=1.0,
163
  step=0.05,
164
  value=DEFAULT_THRESHOLD,
165
- label="Threshold"
166
  )
167
  tag_button = gr.Button("🔍 Tag Image")
168
  with gr.Column():
 
50
 
51
  Returns:
52
  results_by_cat: Dictionary mapping each category to a list of (tag, probability) above its threshold.
53
+ prompt_tags_by_cat: Dictionary for prompt-style output (character, general).
54
  all_artist_tags: All artist tags (with probabilities) regardless of threshold.
55
  """
56
  probs = 1 / (1 + np.exp(-refined_logits))
 
59
  category_thresholds = metadata.get("category_thresholds", {})
60
 
61
  results_by_cat = {}
62
+ # For prompt style, only include character and general tags (artists handled separately)
63
+ prompt_tags_by_cat = {"character": [], "general": []}
64
  all_artist_tags = []
65
 
66
  for idx, prob in enumerate(probs):
 
78
  def format_prompt_tags(prompt_tags_by_cat: dict, all_artist_tags: list) -> str:
79
  """
80
  Format the tags for prompt-style output.
81
+ Only the top artist tag is shown (regardless of threshold), and all character and general tags are shown.
82
 
83
  Returns a comma-separated string of escaped tags.
84
  """
85
+ # Always select the best artist tag from all_artist_tags, regardless of threshold.
86
+ best_artist_tag = None
87
+ if all_artist_tags:
88
+ best_artist = max(all_artist_tags, key=lambda item: item[1])
89
+ best_artist_tag = escape_tag(best_artist[0])
90
+
91
+ # Sort character and general tags by probability (descending)
92
  for cat in prompt_tags_by_cat:
93
  prompt_tags_by_cat[cat].sort(key=lambda x: x[1], reverse=True)
94
 
 
95
  character_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("character", [])]
96
  general_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("general", [])]
97
+
98
+ prompt_tags = []
99
+ if best_artist_tag:
100
+ prompt_tags.append(best_artist_tag)
101
+ prompt_tags.extend(character_tags)
102
+ prompt_tags.extend(general_tags)
103
+
104
  return ", ".join(prompt_tags) if prompt_tags else "No tags predicted."
105
 
106
  def format_detailed_output(results_by_cat: dict, all_artist_tags: list) -> str:
 
153
  "Upload an image, adjust the threshold, and click **Tag Image** to see predictions."
154
  )
155
  gr.Markdown(
156
+ "*(Note: In prompt-style output, only the top artist tag is displayed along with all character and general tags.)*"
 
157
  )
158
  with gr.Row():
159
  with gr.Column():
 
169
  maximum=1.0,
170
  step=0.05,
171
  value=DEFAULT_THRESHOLD,
172
+ label="Default Threshold"
173
  )
174
  tag_button = gr.Button("🔍 Tag Image")
175
  with gr.Column():