MCut

app.py CHANGED
@@ -10,7 +10,7 @@ MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
 MODEL_FILE = "camie_tagger_initial.onnx"
 META_FILE = "metadata.json"
 IMAGE_SIZE = (512, 512)
-DEFAULT_THRESHOLD = 0.35  # Default
+DEFAULT_THRESHOLD = 0.35  # Default threshold if slider is used
 
 # Download model and metadata from Hugging Face Hub
 model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
@@ -35,7 +35,6 @@ def preprocess_image(pil_image: Image.Image) -> np.ndarray:
 def run_inference(pil_image: Image.Image) -> np.ndarray:
     """
     Preprocess the image and run the ONNX model inference.
-
     Returns the refined logits as a numpy array.
     """
     input_tensor = preprocess_image(pil_image)
@@ -44,13 +43,26 @@ def run_inference(pil_image: Image.Image) -> np.ndarray:
     _, refined_logits = session.run(None, {input_name: input_tensor})
     return refined_logits[0]
 
+def mcut_threshold(probs: np.ndarray) -> float:
+    """
+    Compute the MCut threshold from the given probabilities.
+    Uses the MCut method described in:
+    Largeron, C., Moulin, C., & Gery, M. (2012).
+    """
+    sorted_probs = probs[probs.argsort()[::-1]]
+    diffs = sorted_probs[:-1] - sorted_probs[1:]
+    t = diffs.argmax()
+    thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
+    return thresh
+
 def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: float):
     """
     Compute probabilities from logits and collect tag predictions.
 
     Returns:
-        results_by_cat: Dictionary mapping each category to a list of (tag, probability)
-
+        results_by_cat: Dictionary mapping each category to a list of (tag, probability)
+            above its threshold.
+        prompt_tags_by_cat: Dictionary for prompt-style output (character and general tags).
         all_artist_tags: All artist tags (with probabilities) regardless of threshold.
     """
     probs = 1 / (1 + np.exp(-refined_logits))
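Aside (not part of the diff): a quick sketch of what the new mcut_threshold does. MCut places the cutoff in the middle of the largest gap between consecutive sorted probabilities; the values below are invented for illustration.

import numpy as np

probs = np.array([0.92, 0.85, 0.40, 0.12, 0.05])      # invented tag probabilities
sorted_probs = probs[probs.argsort()[::-1]]           # descending: 0.92, 0.85, 0.40, 0.12, 0.05
diffs = sorted_probs[:-1] - sorted_probs[1:]          # gaps: 0.07, 0.45, 0.28, 0.07
t = diffs.argmax()                                    # largest gap lies between 0.85 and 0.40
thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2  # midpoint of that gap
print(thresh)                                         # 0.625, so only the top two tags pass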
@@ -59,7 +71,7 @@ def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: floa
     category_thresholds = metadata.get("category_thresholds", {})
 
     results_by_cat = {}
-    # For prompt
+    # For prompt-style output, only include character and general tags (artists handled separately)
     prompt_tags_by_cat = {"character": [], "general": []}
     all_artist_tags = []
 
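Aside (not part of the diff): the body of get_tags that applies these thresholds is unchanged and therefore not shown in the hunks. A hedged sketch of the usual lookup pattern, with invented example values; each category falls back to default_threshold when metadata carries no per-category entry.

# Hypothetical illustration only; the real selection code lives in unchanged lines of app.py.
category_thresholds = {"character": 0.30, "general": 0.40}  # invented values
default_threshold = 0.35
for category in ("character", "general", "artist"):
    threshold = category_thresholds.get(category, default_threshold)
    print(category, threshold)  # "artist" has no entry and falls back to 0.35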
@@ -78,7 +90,8 @@ def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: floa
 def format_prompt_tags(prompt_tags_by_cat: dict, all_artist_tags: list) -> str:
     """
     Format the tags for prompt-style output.
-    Only the top artist tag is shown (regardless of threshold),
+    Only the top artist tag is shown (regardless of threshold),
+    and all character and general tags are shown.
 
     Returns a comma-separated string of escaped tags.
     """
@@ -106,13 +119,12 @@ def format_prompt_tags(prompt_tags_by_cat: dict, all_artist_tags: list) -> str:
 def format_detailed_output(results_by_cat: dict, all_artist_tags: list) -> str:
     """
     Format the tags for detailed output.
-
     Returns a Markdown-formatted string listing tags by category.
     """
     if not results_by_cat:
         return "No tags predicted for this image."
 
-    # Include an artist tag even if below threshold
+    # Include an artist tag even if below threshold.
     if "artist" not in results_by_cat and all_artist_tags:
         best_artist_tag, best_artist_prob = max(all_artist_tags, key=lambda item: item[1])
         results_by_cat["artist"] = [(best_artist_tag, best_artist_prob)]
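Aside (not part of the diff): the fallback above keeps the single highest-probability artist tag even when none passed the threshold, for example:

all_artist_tags = [("artist_a", 0.21), ("artist_b", 0.34)]  # invented sub-threshold scores
best_artist_tag, best_artist_prob = max(all_artist_tags, key=lambda item: item[1])
print(best_artist_tag, best_artist_prob)  # artist_b 0.34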
@@ -126,17 +138,24 @@ def format_detailed_output(results_by_cat: dict, all_artist_tags: list) -> str:
         lines.append("")  # blank line between categories
     return "\n".join(lines)
 
-def tag_image(pil_image: Image.Image, output_format: str, threshold: float) -> str:
+def tag_image(pil_image: Image.Image, output_format: str, threshold: float, mcut_enabled: bool) -> str:
     """
     Run inference on the image and return formatted tags based on the chosen output format.
-
-
+    The slider value (threshold) normally overrides the default threshold for tag selection.
+    If mcut_enabled is True, compute a new threshold using MCut from all probabilities.
     """
     if pil_image is None:
         return "Please upload an image."
 
     refined_logits = run_inference(pil_image)
-    results_by_cat, prompt_tags_by_cat, all_artist_tags = get_tags(refined_logits, metadata, default_threshold=threshold)
+    # Compute probabilities from logits
+    probs = 1 / (1 + np.exp(-refined_logits))
+    # If MCut is enabled, override the threshold using the MCut method.
+    computed_threshold = mcut_threshold(probs) if mcut_enabled else threshold
+
+    results_by_cat, prompt_tags_by_cat, all_artist_tags = get_tags(
+        refined_logits, metadata, default_threshold=computed_threshold
+    )
 
     if output_format == "Prompt-style Tags":
         return format_prompt_tags(prompt_tags_by_cat, all_artist_tags)
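Aside (not part of the diff): with the new signature, the checkbox only decides which threshold reaches get_tags; when MCut is enabled the slider value is ignored. A self-contained check using the same invented probabilities as above (mcut_threshold is repeated here so the snippet runs on its own):

import numpy as np

def mcut_threshold(probs: np.ndarray) -> float:
    # Same rule as the function added in the diff.
    sorted_probs = probs[probs.argsort()[::-1]]
    diffs = sorted_probs[:-1] - sorted_probs[1:]
    t = diffs.argmax()
    return (sorted_probs[t] + sorted_probs[t + 1]) / 2

probs = np.array([0.92, 0.85, 0.40, 0.12, 0.05])  # invented probabilities
slider_threshold = 0.35                           # value taken from the slider
for mcut_enabled in (False, True):
    computed_threshold = mcut_threshold(probs) if mcut_enabled else slider_threshold
    kept = int((probs >= computed_threshold).sum())
    print(mcut_enabled, float(computed_threshold), kept)  # False 0.35 3, then True 0.625 2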
@@ -153,7 +172,8 @@ with demo:
         "Upload an image, adjust the threshold, and click **Tag Image** to see predictions."
     )
     gr.Markdown(
-        "*(Note: In prompt-style output, only the top artist tag is displayed along with all character and general tags.)*"
+        "*(Note: In prompt-style output, only the top artist tag is displayed along with all character and general tags. "
+        "If MCut is enabled, its computed threshold overrides the default slider value.)*"
     )
     with gr.Row():
         with gr.Column():
@@ -163,26 +183,32 @@
                 value="Prompt-style Tags",
                 label="Output Format"
             )
-            # Slider to modify the default threshold value used in inference.
             threshold_slider = gr.Slider(
                 minimum=0.0,
                 maximum=1.0,
                 step=0.05,
                 value=DEFAULT_THRESHOLD,
-                label="Threshold"
+                label="Default Threshold"
+            )
+            mcut_checkbox = gr.Checkbox(
+                value=False,
+                label="Use MCut threshold"
             )
             tag_button = gr.Button("🔍 Tag Image")
         with gr.Column():
             output_box = gr.Markdown("")  # Markdown output for formatted results
 
-    # Pass the threshold_slider
-    tag_button.click(
+    # Pass the threshold_slider and mcut_checkbox values into the tag_image function
+    tag_button.click(
+        fn=tag_image,
+        inputs=[image_in, format_choice, threshold_slider, mcut_checkbox],
+        outputs=output_box
+    )
 
     gr.Markdown(
         "----\n"
         "**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime) • "
-        "**Base Model:** Camais03/camie-tagger (61% F1 on 70k tags) • "
-        "**ONNX Runtime:** for efficient CPU inference • "
+        "**Base Model:** Camais03/camie-tagger (61% F1 on 70k tags) • **ONNX Runtime:** for efficient CPU inference • "
         "*Demo built with Gradio Blocks.*"
     )
 
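Aside (not part of the diff): in gr.Blocks, the inputs list handed to .click() maps positionally onto the handler's parameters, so mcut_checkbox arrives in tag_image as the mcut_enabled bool. A minimal standalone sketch of the same wiring pattern; the handler is a stand-in for tag_image and the Radio choices are placeholders, not the app's exact values.

import gradio as gr

def handler(image, output_format, threshold, use_mcut):
    # The fourth input, the checkbox, arrives as a plain bool.
    return f"format={output_format}, threshold={threshold}, use_mcut={use_mcut}"

with gr.Blocks() as demo:
    image_in = gr.Image(type="pil")
    format_choice = gr.Radio(["Prompt-style Tags", "Detailed Output"], value="Prompt-style Tags")
    threshold_slider = gr.Slider(0.0, 1.0, value=0.35, step=0.05, label="Default Threshold")
    mcut_checkbox = gr.Checkbox(value=False, label="Use MCut threshold")
    output_box = gr.Markdown("")
    tag_button = gr.Button("Tag Image")
    tag_button.click(fn=handler, inputs=[image_in, format_choice, threshold_slider, mcut_checkbox], outputs=output_box)

# demo.launch()  # uncomment to run locally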