Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -50,7 +50,7 @@ def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: floa
|
|
50 |
|
51 |
Returns:
|
52 |
results_by_cat: Dictionary mapping each category to a list of (tag, probability) above its threshold.
|
53 |
-
prompt_tags_by_cat: Dictionary for prompt-style output (
|
54 |
all_artist_tags: All artist tags (with probabilities) regardless of threshold.
|
55 |
"""
|
56 |
probs = 1 / (1 + np.exp(-refined_logits))
|
@@ -59,7 +59,8 @@ def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: floa
|
|
59 |
category_thresholds = metadata.get("category_thresholds", {})
|
60 |
|
61 |
results_by_cat = {}
|
62 |
-
|
|
|
63 |
all_artist_tags = []
|
64 |
|
65 |
for idx, prob in enumerate(probs):
|
@@ -77,22 +78,29 @@ def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: floa
|
|
77 |
def format_prompt_tags(prompt_tags_by_cat: dict, all_artist_tags: list) -> str:
|
78 |
"""
|
79 |
Format the tags for prompt-style output.
|
|
|
80 |
|
81 |
Returns a comma-separated string of escaped tags.
|
82 |
"""
|
83 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
for cat in prompt_tags_by_cat:
|
85 |
prompt_tags_by_cat[cat].sort(key=lambda x: x[1], reverse=True)
|
86 |
|
87 |
-
artist_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("artist", [])]
|
88 |
character_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("character", [])]
|
89 |
general_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("general", [])]
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
|
|
96 |
return ", ".join(prompt_tags) if prompt_tags else "No tags predicted."
|
97 |
|
98 |
def format_detailed_output(results_by_cat: dict, all_artist_tags: list) -> str:
|
@@ -145,8 +153,7 @@ with demo:
|
|
145 |
"Upload an image, adjust the threshold, and click **Tag Image** to see predictions."
|
146 |
)
|
147 |
gr.Markdown(
|
148 |
-
"*(Note:
|
149 |
-
"You can choose a concise prompt-style output or a detailed category-wise breakdown.)*"
|
150 |
)
|
151 |
with gr.Row():
|
152 |
with gr.Column():
|
@@ -162,7 +169,7 @@ with demo:
|
|
162 |
maximum=1.0,
|
163 |
step=0.05,
|
164 |
value=DEFAULT_THRESHOLD,
|
165 |
-
label="Threshold"
|
166 |
)
|
167 |
tag_button = gr.Button("🔍 Tag Image")
|
168 |
with gr.Column():
|
|
|
50 |
|
51 |
Returns:
|
52 |
results_by_cat: Dictionary mapping each category to a list of (tag, probability) above its threshold.
|
53 |
+
prompt_tags_by_cat: Dictionary for prompt-style output (character, general).
|
54 |
all_artist_tags: All artist tags (with probabilities) regardless of threshold.
|
55 |
"""
|
56 |
probs = 1 / (1 + np.exp(-refined_logits))
|
|
|
59 |
category_thresholds = metadata.get("category_thresholds", {})
|
60 |
|
61 |
results_by_cat = {}
|
62 |
+
# For prompt style, only include character and general tags (artists handled separately)
|
63 |
+
prompt_tags_by_cat = {"character": [], "general": []}
|
64 |
all_artist_tags = []
|
65 |
|
66 |
for idx, prob in enumerate(probs):
|
|
|
78 |
def format_prompt_tags(prompt_tags_by_cat: dict, all_artist_tags: list) -> str:
|
79 |
"""
|
80 |
Format the tags for prompt-style output.
|
81 |
+
Only the top artist tag is shown (regardless of threshold), and all character and general tags are shown.
|
82 |
|
83 |
Returns a comma-separated string of escaped tags.
|
84 |
"""
|
85 |
+
# Always select the best artist tag from all_artist_tags, regardless of threshold.
|
86 |
+
best_artist_tag = None
|
87 |
+
if all_artist_tags:
|
88 |
+
best_artist = max(all_artist_tags, key=lambda item: item[1])
|
89 |
+
best_artist_tag = escape_tag(best_artist[0])
|
90 |
+
|
91 |
+
# Sort character and general tags by probability (descending)
|
92 |
for cat in prompt_tags_by_cat:
|
93 |
prompt_tags_by_cat[cat].sort(key=lambda x: x[1], reverse=True)
|
94 |
|
|
|
95 |
character_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("character", [])]
|
96 |
general_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("general", [])]
|
97 |
+
|
98 |
+
prompt_tags = []
|
99 |
+
if best_artist_tag:
|
100 |
+
prompt_tags.append(best_artist_tag)
|
101 |
+
prompt_tags.extend(character_tags)
|
102 |
+
prompt_tags.extend(general_tags)
|
103 |
+
|
104 |
return ", ".join(prompt_tags) if prompt_tags else "No tags predicted."
|
105 |
|
106 |
def format_detailed_output(results_by_cat: dict, all_artist_tags: list) -> str:
|
|
|
153 |
"Upload an image, adjust the threshold, and click **Tag Image** to see predictions."
|
154 |
)
|
155 |
gr.Markdown(
|
156 |
+
"*(Note: In prompt-style output, only the top artist tag is displayed along with all character and general tags.)*"
|
|
|
157 |
)
|
158 |
with gr.Row():
|
159 |
with gr.Column():
|
|
|
169 |
maximum=1.0,
|
170 |
step=0.05,
|
171 |
value=DEFAULT_THRESHOLD,
|
172 |
+
label="Default Threshold"
|
173 |
)
|
174 |
tag_button = gr.Button("🔍 Tag Image")
|
175 |
with gr.Column():
|