CyberWaifu commited on
Commit
8da1280
·
verified ·
1 Parent(s): 598cad3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -89
app.py CHANGED
@@ -5,120 +5,157 @@ from PIL import Image
5
  import json
6
  from huggingface_hub import hf_hub_download
7
 
8
- # Load model and metadata at startup (same as before)
9
  MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
10
  MODEL_FILE = "camie_tagger_initial.onnx"
11
  META_FILE = "metadata.json"
12
- model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
13
- meta_path = hf_hub_download(repo_id=MODEL_REPO, filename=META_FILE, cache_dir=".")
14
- session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
15
- metadata = json.load(open(meta_path, "r", encoding="utf-8"))
16
- # Preprocessing function (same as before)
17
- def preprocess_image(pil_image: Image.Image) -> np.ndarray:
18
- img = pil_image.convert("RGB").resize((512, 512))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  arr = np.array(img).astype(np.float32) / 255.0
20
  arr = np.transpose(arr, (2, 0, 1))
21
  arr = np.expand_dims(arr, 0)
22
  return arr
23
 
24
- # Inference function with output format option
25
- def tag_image(pil_image: Image.Image, output_format: str) -> str:
26
- # Run model inference
27
- input_tensor = preprocess_image(pil_image)
28
- input_name = session.get_inputs()[0].name
29
- initial_logits, refined_logits = session.run(None, {input_name: input_tensor})
30
- probs = 1 / (1 + np.exp(-refined_logits))
31
- probs = probs[0]
32
  idx_to_tag = metadata["idx_to_tag"]
33
  tag_to_category = metadata.get("tag_to_category", {})
34
  category_thresholds = metadata.get("category_thresholds", {})
35
- default_threshold = 0.35
36
- results_by_cat = {} # to store tags per category (for verbose output)
37
- artist_tags_with_probs = []
38
- character_tags_with_probs = []
39
- general_tags_with_probs = []
40
- all_artist_tags_probs = [] # Store all artist tags and their probabilities
41
-
42
- # Collect tags above thresholds
43
  for idx, prob in enumerate(probs):
44
  tag = idx_to_tag[str(idx)]
45
  cat = tag_to_category.get(tag, "unknown")
46
  if cat == 'artist':
47
- all_artist_tags_probs.append((tag, float(prob))) # Store all artist tags
48
- thresh = category_thresholds.get(cat, default_threshold)
49
  if float(prob) >= thresh:
50
- # add to category dictionary
51
  results_by_cat.setdefault(cat, []).append((tag, float(prob)))
52
- if cat == 'artist':
53
- artist_tags_with_probs.append((tag, float(prob)))
54
- elif cat == 'character':
55
- character_tags_with_probs.append((tag, float(prob)))
56
- elif cat == 'general':
57
- general_tags_with_probs.append((tag, float(prob)))
58
 
59
- if output_format == "Prompt-style Tags":
60
- artist_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
61
- character_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
62
- general_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
63
 
64
- artist_prompt_tags = [tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") for tag, prob in artist_tags_with_probs]
65
- character_prompt_tags = [tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") for tag, prob in character_tags_with_probs]
66
- general_prompt_tags = [tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") for tag, prob in general_tags_with_probs]
 
 
67
 
68
- prompt_tags = artist_prompt_tags + character_prompt_tags + general_prompt_tags
 
 
69
 
70
- # Ensure at least one artist tag if any artist tags were predicted at all, even below threshold
71
- if not artist_prompt_tags and all_artist_tags_probs:
72
- best_artist_tag, best_artist_prob = max(all_artist_tags_probs, key=lambda item: item[1])
 
 
 
 
 
 
73
  prompt_tags = [best_artist_tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)")] + prompt_tags
74
 
 
 
 
75
 
76
- if not prompt_tags:
77
- return "No tags predicted."
78
- return ", ".join(prompt_tags)
79
- else: # Detailed output
80
- if not results_by_cat:
81
- return "No tags predicted for this image."
82
 
83
- # Ensure artist tag in detailed output even if below threshold
84
- if 'artist' not in results_by_cat and all_artist_tags_probs:
85
- best_artist_tag, best_artist_prob = max(all_artist_tags_probs, key=lambda item: item[1])
86
  results_by_cat['artist'] = [(best_artist_tag, best_artist_prob)]
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- lines = []
90
- lines.append("**Predicted Tags by Category:** \n") # (Markdown newline: two spaces + newline)
91
- for cat, tag_list in results_by_cat.items():
92
- # sort tags in this category by probability descending
93
- tag_list.sort(key=lambda x: x[1], reverse=True)
94
- lines.append(f"**Category: {cat}** – {len(tag_list)} tags")
95
- for tag, prob in tag_list:
96
- tag_pretty = tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") # Escape parentheses here with raw string
97
- lines.append(f"- {tag_pretty} (Prob: {prob:.3f})")
98
- lines.append("") # blank line between categories
99
- return "\n".join(lines)
100
-
101
- # Build the Gradio Blocks UI
102
- demo = gr.Blocks(theme="gradio/soft") # using a built-in theme for nicer styling
103
-
104
- with demo:
105
- # Header Section
106
- gr.Markdown("# 🏷️ Camie Tagger – Anime Image Tagging\nThis demo uses an ONNX model of Camie Tagger to label anime illustrations with tags. Upload an image and click **Tag Image** to see predictions.")
107
- gr.Markdown("*(Note: The model will predict a large number of tags across categories like character, general, artist, etc. You can choose a concise prompt-style output or a detailed category-wise breakdown.)*")
108
- # Input/Output Section
109
- with gr.Row():
110
- # Left column: Image input and format selection
111
- with gr.Column():
112
- image_in = gr.Image(type="pil", label="Input Image")
113
- format_choice = gr.Radio(choices=["Prompt-style Tags", "Detailed Output"], value="Prompt-style Tags", label="Output Format")
114
- tag_button = gr.Button("🔍 Tag Image")
115
- # Right column: Output display
116
- with gr.Column():
117
- output_box = gr.Markdown("") # will display the result in Markdown (supports bold, lists, etc.)
118
- # Link the button click to the function
119
- tag_button.click(fn=tag_image, inputs=[image_in, format_choice], outputs=output_box)
120
- # Footer/Info
121
- gr.Markdown("----\n**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime)   •   **Base Model:** Camais03/camie-tagger (61% F1 on 70k tags)   •   **ONNX Runtime:** for efficient CPU inference​:contentReference[oaicite:6]{index=6}   •   *Demo built with Gradio Blocks.*")
122
-
123
- # Launch the app (automatically handled in Spaces)
124
- demo.launch()
 
5
  import json
6
  from huggingface_hub import hf_hub_download
7
 
8
+ # --- Constants ---
9
  MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
10
  MODEL_FILE = "camie_tagger_initial.onnx"
11
  META_FILE = "metadata.json"
12
+ IMAGE_SIZE = (512, 512)
13
+ DEFAULT_THRESHOLD = 0.35
14
+
15
+
16
+ # --- Helper Functions ---
17
+ def download_model_and_metadata(repo_id: str, model_filename: str, meta_filename: str, cache_dir: str = "."):
18
+ """Downloads the ONNX model and metadata from Hugging Face Hub."""
19
+ model_path = hf_hub_download(repo_id=repo_id, filename=model_filename, cache_dir=cache_dir)
20
+ meta_path = hf_hub_download(repo_id=repo_id, filename=meta_filename, cache_dir=cache_dir)
21
+ return model_path, meta_path
22
+
23
+ def load_model_session(model_path: str) -> ort.InferenceSession:
24
+ """Loads the ONNX model inference session."""
25
+ return ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
26
+
27
+ def load_metadata(meta_path: str) -> dict:
28
+ """Loads the metadata from the JSON file."""
29
+ with open(meta_path, "r", encoding="utf-8") as f:
30
+ return json.load(f)
31
+
32
+ def preprocess_image(pil_image: Image.Image, image_size: tuple = IMAGE_SIZE) -> np.ndarray:
33
+ """Preprocesses the PIL image to numpy array for model input."""
34
+ img = pil_image.convert("RGB").resize(image_size)
35
  arr = np.array(img).astype(np.float32) / 255.0
36
  arr = np.transpose(arr, (2, 0, 1))
37
  arr = np.expand_dims(arr, 0)
38
  return arr
39
 
40
+ def apply_sigmoid(logits: np.ndarray) -> np.ndarray:
41
+ """Applies sigmoid function to logits to get probabilities."""
42
+ return 1 / (1 + np.exp(-logits))
43
+
44
+ def extract_tags_from_probabilities(probs: np.ndarray, metadata: dict, threshold: float = DEFAULT_THRESHOLD) -> dict:
45
+ """Extracts tags and probabilities from the model output probabilities."""
 
 
46
  idx_to_tag = metadata["idx_to_tag"]
47
  tag_to_category = metadata.get("tag_to_category", {})
48
  category_thresholds = metadata.get("category_thresholds", {})
49
+ results_by_cat = {}
50
+ all_artist_tags_probs = []
51
+
 
 
 
 
 
52
  for idx, prob in enumerate(probs):
53
  tag = idx_to_tag[str(idx)]
54
  cat = tag_to_category.get(tag, "unknown")
55
  if cat == 'artist':
56
+ all_artist_tags_probs.append((tag, float(prob)))
57
+ thresh = category_thresholds.get(cat, threshold)
58
  if float(prob) >= thresh:
 
59
  results_by_cat.setdefault(cat, []).append((tag, float(prob)))
 
 
 
 
 
 
60
 
61
+ return results_by_cat, all_artist_tags_probs
 
 
 
62
 
63
+ def format_prompt_style_output(results_by_cat: dict, all_artist_tags_probs: list) -> str:
64
+ """Formats the output as a comma-separated prompt-style string."""
65
+ artist_tags_with_probs = results_by_cat.get('artist', [])
66
+ character_tags_with_probs = results_by_cat.get('character', [])
67
+ general_tags_with_probs = results_by_cat.get('general', [])
68
 
69
+ artist_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
70
+ character_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
71
+ general_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
72
 
73
+ artist_prompt_tags = [tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") for tag, prob in artist_tags_with_probs]
74
+ character_prompt_tags = [tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") for tag, prob in character_tags_with_probs]
75
+ general_prompt_tags = [tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") for tag, prob in general_tags_with_probs]
76
+
77
+ prompt_tags = artist_prompt_tags + character_prompt_tags + general_prompt_tags
78
+
79
+ if not artist_prompt_tags and all_artist_tags_probs:
80
+ best_artist_tag, best_artist_prob = max(all_artist_tags_probs, key=lambda item: item[1]) if all_artist_tags_probs else (None, None)
81
+ if best_artist_tag: # Check if best_artist_tag is not None
82
  prompt_tags = [best_artist_tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)")] + prompt_tags
83
 
84
+ if not prompt_tags:
85
+ return "No tags predicted."
86
+ return ", ".join(prompt_tags)
87
 
88
+ def format_detailed_output(results_by_cat: dict, all_artist_tags_probs: list) -> str:
89
+ """Formats the output as a detailed markdown string with categories and probabilities."""
90
+ if not results_by_cat:
91
+ return "No tags predicted for this image."
 
 
92
 
93
+ if 'artist' not in results_by_cat and all_artist_tags_probs:
94
+ best_artist_tag, best_artist_prob = max(all_artist_tags_probs, key=lambda item: item[1]) if all_artist_tags_probs else (None, None)
95
+ if best_artist_tag: # Check if best_artist_tag is not None
96
  results_by_cat['artist'] = [(best_artist_tag, best_artist_prob)]
97
 
98
+ lines = []
99
+ lines.append("**Predicted Tags by Category:** \n")
100
+ for cat, tag_list in results_by_cat.items():
101
+ tag_list.sort(key=lambda x: x[1], reverse=True)
102
+ lines.append(f"**Category: {cat}** – {len(tag_list)} tags")
103
+ for tag, prob in tag_list:
104
+ tag_pretty = tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)")
105
+ lines.append(f"- {tag_pretty} (Prob: {prob:.3f})")
106
+ lines.append("")
107
+ return "\n".join(lines)
108
+
109
+ # --- Inference Function ---
110
+ def tag_image(pil_image: Image.Image, output_format: str, session: ort.InferenceSession, metadata: dict) -> str:
111
+ """Tags the image and formats the output based on the selected format."""
112
+ if pil_image is None:
113
+ return "Please upload an image."
114
+
115
+ input_tensor = preprocess_image(pil_image)
116
+ input_name = session.get_inputs()[0].name
117
+ initial_logits, refined_logits = session.run(None, {input_name: input_tensor})
118
+ probs = apply_sigmoid(refined_logits)[0] # Apply sigmoid and get probabilities for the first (and only) image in batch
119
+
120
+ results_by_cat, all_artist_tags_probs = extract_tags_from_probabilities(probs, metadata)
121
+
122
+ if output_format == "Prompt-style Tags":
123
+ return format_prompt_style_output(results_by_cat, all_artist_tags_probs)
124
+ else: # Detailed Output
125
+ return format_detailed_output(results_by_cat, all_artist_tags_probs)
126
+
127
+
128
+ # --- Gradio UI ---
129
+ def create_gradio_interface(session: ort.InferenceSession, metadata: dict) -> gr.Blocks:
130
+ """Creates the Gradio Blocks interface."""
131
+ demo = gr.Blocks(theme="gradio/soft")
132
+
133
+ with demo:
134
+ gr.Markdown("# 🏷️ Camie Tagger – Anime Image Tagging\nThis demo uses an ONNX model of Camie Tagger to label anime illustrations with tags. Upload an image and click **Tag Image** to see predictions.")
135
+ gr.Markdown("*(Note: The model will predict a large number of tags across categories like character, general, artist, etc. You can choose a concise prompt-style output or a detailed category-wise breakdown.)*")
136
+
137
+ with gr.Row():
138
+ with gr.Column():
139
+ image_in = gr.Image(type="pil", label="Input Image")
140
+ format_choice = gr.Radio(choices=["Prompt-style Tags", "Detailed Output"], value="Prompt-style Tags", label="Output Format")
141
+ tag_button = gr.Button("🔍 Tag Image")
142
+ with gr.Column():
143
+ output_box = gr.Markdown("")
144
+
145
+ tag_button.click(
146
+ fn=tag_image,
147
+ inputs=[image_in, format_choice],
148
+ outputs=output_box,
149
+ extra_args=[session, metadata] # Pass session and metadata as extra arguments
150
+ )
151
+
152
+ gr.Markdown("----\n**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime)   •   **Base Model:** Camais03/camie-tagger (61% F1 on 70k tags)   •   **ONNX Runtime:** for efficient CPU inference   •   *Demo built with Gradio Blocks.*")
153
+ return demo
154
 
155
+ # --- Main Script ---
156
+ if __name__ == "__main__":
157
+ model_path, meta_path = download_model_and_metadata(MODEL_REPO, MODEL_FILE, META_FILE)
158
+ session = load_model_session(model_path)
159
+ metadata = load_metadata(meta_path)
160
+ demo = create_gradio_interface(session, metadata)
161
+ demo.launch()