CyberWaifu commited on
Commit
b317da6
·
verified ·
1 Parent(s): b403fe7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -82
app.py CHANGED
@@ -5,120 +5,165 @@ from PIL import Image
5
  import json
6
  from huggingface_hub import hf_hub_download
7
 
8
- # Load model and metadata at startup (same as before)
9
  MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
10
  MODEL_FILE = "camie_tagger_initial.onnx"
11
  META_FILE = "metadata.json"
 
 
 
 
12
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
13
  meta_path = hf_hub_download(repo_id=MODEL_REPO, filename=META_FILE, cache_dir=".")
 
 
14
  session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
15
- metadata = json.load(open(meta_path, "r", encoding="utf-8"))
16
- # Preprocessing function (same as before)
 
 
 
 
 
17
  def preprocess_image(pil_image: Image.Image) -> np.ndarray:
18
- img = pil_image.convert("RGB").resize((512, 512))
 
19
  arr = np.array(img).astype(np.float32) / 255.0
20
  arr = np.transpose(arr, (2, 0, 1))
21
- arr = np.expand_dims(arr, 0)
22
- return arr
23
 
24
- # Inference function with output format option
25
- def tag_image(pil_image: Image.Image, output_format: str) -> str:
26
- # Run model inference
 
 
 
27
  input_tensor = preprocess_image(pil_image)
28
  input_name = session.get_inputs()[0].name
29
- initial_logits, refined_logits = session.run(None, {input_name: input_tensor})
 
 
 
 
 
 
 
 
 
 
 
 
30
  probs = 1 / (1 + np.exp(-refined_logits))
31
- probs = probs[0]
32
  idx_to_tag = metadata["idx_to_tag"]
33
  tag_to_category = metadata.get("tag_to_category", {})
34
  category_thresholds = metadata.get("category_thresholds", {})
35
- default_threshold = 0.35
36
- results_by_cat = {} # to store tags per category (for verbose output)
37
- artist_tags_with_probs = []
38
- character_tags_with_probs = []
39
- general_tags_with_probs = []
40
- all_artist_tags_probs = [] # Store all artist tags and their probabilities
41
-
42
- # Collect tags above thresholds
43
  for idx, prob in enumerate(probs):
44
  tag = idx_to_tag[str(idx)]
45
  cat = tag_to_category.get(tag, "unknown")
46
- if cat == 'artist':
47
- all_artist_tags_probs.append((tag, float(prob))) # Store all artist tags
48
  thresh = category_thresholds.get(cat, default_threshold)
 
 
49
  if float(prob) >= thresh:
50
- # add to category dictionary
51
  results_by_cat.setdefault(cat, []).append((tag, float(prob)))
52
- if cat == 'artist':
53
- artist_tags_with_probs.append((tag, float(prob)))
54
- elif cat == 'character':
55
- character_tags_with_probs.append((tag, float(prob)))
56
- elif cat == 'general':
57
- general_tags_with_probs.append((tag, float(prob)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  if output_format == "Prompt-style Tags":
60
- artist_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
61
- character_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
62
- general_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
63
-
64
- artist_prompt_tags = [tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") for tag, prob in artist_tags_with_probs]
65
- character_prompt_tags = [tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") for tag, prob in character_tags_with_probs]
66
- general_prompt_tags = [tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") for tag, prob in general_tags_with_probs]
67
-
68
- prompt_tags = artist_prompt_tags + character_prompt_tags + general_prompt_tags
69
-
70
- # Ensure at least one artist tag if any artist tags were predicted at all, even below threshold
71
- if not artist_prompt_tags and all_artist_tags_probs:
72
- best_artist_tag, best_artist_prob = max(all_artist_tags_probs, key=lambda item: item[1])
73
- prompt_tags = [best_artist_tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)")] + prompt_tags
74
-
75
-
76
- if not prompt_tags:
77
- return "No tags predicted."
78
- return ", ".join(prompt_tags)
79
- else: # Detailed output
80
- if not results_by_cat:
81
- return "No tags predicted for this image."
82
-
83
- # Ensure artist tag in detailed output even if below threshold
84
- if 'artist' not in results_by_cat and all_artist_tags_probs:
85
- best_artist_tag, best_artist_prob = max(all_artist_tags_probs, key=lambda item: item[1])
86
- results_by_cat['artist'] = [(best_artist_tag, best_artist_prob)]
87
-
88
-
89
- lines = []
90
- lines.append("**Predicted Tags by Category:** \n") # (Markdown newline: two spaces + newline)
91
- for cat, tag_list in results_by_cat.items():
92
- # sort tags in this category by probability descending
93
- tag_list.sort(key=lambda x: x[1], reverse=True)
94
- lines.append(f"**Category: {cat}** – {len(tag_list)} tags")
95
- for tag, prob in tag_list:
96
- tag_pretty = tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") # Escape parentheses here with raw string
97
- lines.append(f"- {tag_pretty} (Prob: {prob:.3f})")
98
- lines.append("") # blank line between categories
99
- return "\n".join(lines)
100
 
101
  # Build the Gradio Blocks UI
102
- demo = gr.Blocks(theme="gradio/soft") # using a built-in theme for nicer styling
103
 
104
  with demo:
105
- # Header Section
106
- gr.Markdown("# 🏷️ Camie Tagger – Anime Image Tagging\nThis demo uses an ONNX model of Camie Tagger to label anime illustrations with tags. Upload an image and click **Tag Image** to see predictions.")
107
- gr.Markdown("*(Note: The model will predict a large number of tags across categories like character, general, artist, etc. You can choose a concise prompt-style output or a detailed category-wise breakdown.)*")
108
- # Input/Output Section
 
 
 
 
 
109
  with gr.Row():
110
- # Left column: Image input and format selection
111
  with gr.Column():
112
  image_in = gr.Image(type="pil", label="Input Image")
113
- format_choice = gr.Radio(choices=["Prompt-style Tags", "Detailed Output"], value="Prompt-style Tags", label="Output Format")
 
 
 
 
114
  tag_button = gr.Button("🔍 Tag Image")
115
- # Right column: Output display
116
  with gr.Column():
117
- output_box = gr.Markdown("") # will display the result in Markdown (supports bold, lists, etc.)
118
- # Link the button click to the function
119
  tag_button.click(fn=tag_image, inputs=[image_in, format_choice], outputs=output_box)
120
- # Footer/Info
121
- gr.Markdown("----\n**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime) • **Base Model:** Camais03/camie-tagger (61% F1 on 70k tags) • **ONNX Runtime:** for efficient CPU inference​:contentReference[oaicite:6]{index=6} • *Demo built with Gradio Blocks.*")
 
 
 
 
 
122
 
123
- # Launch the app (automatically handled in Spaces)
124
- demo.launch()
 
5
  import json
6
  from huggingface_hub import hf_hub_download
7
 
8
+ # Constants
9
  MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
10
  MODEL_FILE = "camie_tagger_initial.onnx"
11
  META_FILE = "metadata.json"
12
+ IMAGE_SIZE = (512, 512)
13
+ DEFAULT_THRESHOLD = 0.35
14
+
15
+ # Download model and metadata from Hugging Face Hub
16
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
17
  meta_path = hf_hub_download(repo_id=MODEL_REPO, filename=META_FILE, cache_dir=".")
18
+
19
+ # Initialize ONNX Runtime session and load metadata
20
  session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
21
+ with open(meta_path, "r", encoding="utf-8") as f:
22
+ metadata = json.load(f)
23
+
24
+ def escape_tag(tag: str) -> str:
25
+ """Escape underscores and parentheses for Markdown."""
26
+ return tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)")
27
+
28
  def preprocess_image(pil_image: Image.Image) -> np.ndarray:
29
+ """Convert image to RGB, resize, normalize, and rearrange dimensions."""
30
+ img = pil_image.convert("RGB").resize(IMAGE_SIZE)
31
  arr = np.array(img).astype(np.float32) / 255.0
32
  arr = np.transpose(arr, (2, 0, 1))
33
+ return np.expand_dims(arr, 0)
 
34
 
35
+ def run_inference(pil_image: Image.Image) -> np.ndarray:
36
+ """
37
+ Preprocess the image and run the ONNX model inference.
38
+
39
+ Returns the refined logits as a numpy array.
40
+ """
41
  input_tensor = preprocess_image(pil_image)
42
  input_name = session.get_inputs()[0].name
43
+ # Only refined_logits are used (initial_logits is ignored)
44
+ _, refined_logits = session.run(None, {input_name: input_tensor})
45
+ return refined_logits[0]
46
+
47
+ def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: float = DEFAULT_THRESHOLD):
48
+ """
49
+ Compute probabilities from logits and collect tag predictions.
50
+
51
+ Returns:
52
+ results_by_cat: Dictionary mapping each category to a list of (tag, probability) above its threshold.
53
+ prompt_tags_by_cat: Similar dictionary but only for prompt-style categories (artist, character, general).
54
+ all_artist_tags: All artist tags (with probabilities) regardless of threshold.
55
+ """
56
  probs = 1 / (1 + np.exp(-refined_logits))
 
57
  idx_to_tag = metadata["idx_to_tag"]
58
  tag_to_category = metadata.get("tag_to_category", {})
59
  category_thresholds = metadata.get("category_thresholds", {})
60
+
61
+ results_by_cat = {}
62
+ prompt_tags_by_cat = {"artist": [], "character": [], "general": []}
63
+ all_artist_tags = []
64
+
 
 
 
65
  for idx, prob in enumerate(probs):
66
  tag = idx_to_tag[str(idx)]
67
  cat = tag_to_category.get(tag, "unknown")
 
 
68
  thresh = category_thresholds.get(cat, default_threshold)
69
+ if cat == "artist":
70
+ all_artist_tags.append((tag, float(prob)))
71
  if float(prob) >= thresh:
 
72
  results_by_cat.setdefault(cat, []).append((tag, float(prob)))
73
+ if cat in prompt_tags_by_cat:
74
+ prompt_tags_by_cat[cat].append((tag, float(prob)))
75
+ return results_by_cat, prompt_tags_by_cat, all_artist_tags
76
+
77
+ def format_prompt_tags(prompt_tags_by_cat: dict, all_artist_tags: list) -> str:
78
+ """
79
+ Format the tags for prompt-style output.
80
+
81
+ Returns a comma-separated string of escaped tags.
82
+ """
83
+ # Sort tags within each category by probability (descending)
84
+ for cat in prompt_tags_by_cat:
85
+ prompt_tags_by_cat[cat].sort(key=lambda x: x[1], reverse=True)
86
+
87
+ artist_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("artist", [])]
88
+ character_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("character", [])]
89
+ general_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("general", [])]
90
+ prompt_tags = artist_tags + character_tags + general_tags
91
+
92
+ # Ensure at least one artist tag appears if available, even if below threshold
93
+ if not artist_tags and all_artist_tags:
94
+ best_artist_tag, _ = max(all_artist_tags, key=lambda item: item[1])
95
+ prompt_tags.insert(0, escape_tag(best_artist_tag))
96
+ return ", ".join(prompt_tags) if prompt_tags else "No tags predicted."
97
+
98
+ def format_detailed_output(results_by_cat: dict, all_artist_tags: list) -> str:
99
+ """
100
+ Format the tags for detailed output.
101
+
102
+ Returns a Markdown-formatted string listing tags by category.
103
+ """
104
+ if not results_by_cat:
105
+ return "No tags predicted for this image."
106
 
107
+ # Include an artist tag even if below threshold
108
+ if "artist" not in results_by_cat and all_artist_tags:
109
+ best_artist_tag, best_artist_prob = max(all_artist_tags, key=lambda item: item[1])
110
+ results_by_cat["artist"] = [(best_artist_tag, best_artist_prob)]
111
+
112
+ lines = ["**Predicted Tags by Category:** \n"]
113
+ for cat, tag_list in results_by_cat.items():
114
+ tag_list.sort(key=lambda x: x[1], reverse=True)
115
+ lines.append(f"**Category: {cat}** – {len(tag_list)} tags")
116
+ for tag, prob in tag_list:
117
+ lines.append(f"- {escape_tag(tag)} (Prob: {prob:.3f})")
118
+ lines.append("") # blank line between categories
119
+ return "\n".join(lines)
120
+
121
+ def tag_image(pil_image: Image.Image, output_format: str) -> str:
122
+ """Run inference on the image and return formatted tags based on the chosen output format."""
123
+ if pil_image is None:
124
+ return "Please upload an image."
125
+
126
+ refined_logits = run_inference(pil_image)
127
+ results_by_cat, prompt_tags_by_cat, all_artist_tags = get_tags(refined_logits, metadata)
128
+
129
  if output_format == "Prompt-style Tags":
130
+ return format_prompt_tags(prompt_tags_by_cat, all_artist_tags)
131
+ else:
132
+ return format_detailed_output(results_by_cat, all_artist_tags)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  # Build the Gradio Blocks UI
135
+ demo = gr.Blocks(theme="gradio/soft")
136
 
137
  with demo:
138
+ gr.Markdown(
139
+ "# 🏷️ Camie Tagger – Anime Image Tagging\n"
140
+ "This demo uses an ONNX model of Camie Tagger to label anime illustrations with tags. "
141
+ "Upload an image and click **Tag Image** to see predictions."
142
+ )
143
+ gr.Markdown(
144
+ "*(Note: The model will predict a large number of tags across categories like character, general, artist, etc. "
145
+ "You can choose a concise prompt-style output or a detailed category-wise breakdown.)*"
146
+ )
147
  with gr.Row():
 
148
  with gr.Column():
149
  image_in = gr.Image(type="pil", label="Input Image")
150
+ format_choice = gr.Radio(
151
+ choices=["Prompt-style Tags", "Detailed Output"],
152
+ value="Prompt-style Tags",
153
+ label="Output Format"
154
+ )
155
  tag_button = gr.Button("🔍 Tag Image")
 
156
  with gr.Column():
157
+ output_box = gr.Markdown("") # Markdown output for formatted results
158
+
159
  tag_button.click(fn=tag_image, inputs=[image_in, format_choice], outputs=output_box)
160
+
161
+ gr.Markdown(
162
+ "----\n"
163
+ "**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime) • "
164
+ "**Base Model:** Camais03/camie-tagger (61% F1 on 70k tags) • **ONNX Runtime:** for efficient CPU inference • "
165
+ "*Demo built with Gradio Blocks.*"
166
+ )
167
 
168
+ if __name__ == "__main__":
169
+ demo.launch()