CyberWaifu commited on
Commit
b403fe7
·
verified ·
1 Parent(s): 6e46e05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -129
app.py CHANGED
@@ -4,161 +4,121 @@ import numpy as np
4
  from PIL import Image
5
  import json
6
  from huggingface_hub import hf_hub_download
7
- from functools import partial # Import partial
8
 
9
- # --- Constants ---
10
  MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
11
  MODEL_FILE = "camie_tagger_initial.onnx"
12
  META_FILE = "metadata.json"
13
- IMAGE_SIZE = (512, 512)
14
- DEFAULT_THRESHOLD = 0.35
15
-
16
-
17
- # --- Helper Functions ---
18
- def download_model_and_metadata(repo_id: str, model_filename: str, meta_filename: str, cache_dir: str = "."):
19
- """Downloads the ONNX model and metadata from Hugging Face Hub."""
20
- model_path = hf_hub_download(repo_id=repo_id, filename=model_filename, cache_dir=cache_dir)
21
- meta_path = hf_hub_download(repo_id=repo_id, filename=meta_filename, cache_dir=cache_dir)
22
- return model_path, meta_path
23
-
24
- def load_model_session(model_path: str) -> ort.InferenceSession:
25
- """Loads the ONNX model inference session."""
26
- return ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
27
-
28
- def load_metadata(meta_path: str) -> dict:
29
- """Loads the metadata from the JSON file."""
30
- with open(meta_path, "r", encoding="utf-8") as f:
31
- return json.load(f)
32
-
33
- def preprocess_image(pil_image: Image.Image, image_size: tuple = IMAGE_SIZE) -> np.ndarray:
34
- """Preprocesses the PIL image to numpy array for model input."""
35
- img = pil_image.convert("RGB").resize(image_size)
36
  arr = np.array(img).astype(np.float32) / 255.0
37
  arr = np.transpose(arr, (2, 0, 1))
38
  arr = np.expand_dims(arr, 0)
39
  return arr
40
 
41
- def apply_sigmoid(logits: np.ndarray) -> np.ndarray:
42
- """Applies sigmoid function to logits to get probabilities."""
43
- return 1 / (1 + np.exp(-logits))
44
-
45
- def extract_tags_from_probabilities(probs: np.ndarray, metadata: dict, threshold: float = DEFAULT_THRESHOLD) -> dict:
46
- """Extracts tags and probabilities from the model output probabilities."""
 
 
47
  idx_to_tag = metadata["idx_to_tag"]
48
  tag_to_category = metadata.get("tag_to_category", {})
49
  category_thresholds = metadata.get("category_thresholds", {})
50
- results_by_cat = {}
51
- all_artist_tags_probs = []
52
-
 
 
 
 
 
53
  for idx, prob in enumerate(probs):
54
  tag = idx_to_tag[str(idx)]
55
  cat = tag_to_category.get(tag, "unknown")
56
  if cat == 'artist':
57
- all_artist_tags_probs.append((tag, float(prob)))
58
- thresh = category_thresholds.get(cat, threshold)
59
  if float(prob) >= thresh:
 
60
  results_by_cat.setdefault(cat, []).append((tag, float(prob)))
 
 
 
 
 
 
61
 
62
- return results_by_cat, all_artist_tags_probs
63
-
64
- def format_prompt_style_output(results_by_cat: dict, all_artist_tags_probs: list) -> str:
65
- """Formats the output as a comma-separated prompt-style string."""
66
- artist_tags_with_probs = results_by_cat.get('artist', [])
67
- character_tags_with_probs = results_by_cat.get('character', [])
68
- general_tags_with_probs = results_by_cat.get('general', [])
69
-
70
- artist_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
71
- character_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
72
- general_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
73
 
74
- artist_prompt_tags = [tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") for tag, prob in artist_tags_with_probs]
75
- character_prompt_tags = [tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") for tag, prob in character_tags_with_probs]
76
- general_prompt_tags = [tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") for tag, prob in general_tags_probs]
77
 
78
- prompt_tags = artist_prompt_tags + character_prompt_tags + general_prompt_tags
79
 
80
- if not artist_prompt_tags and all_artist_tags_probs:
81
- best_artist_tag, best_artist_prob = max(all_artist_tags_probs, key=lambda item: item[1]) if all_artist_tags_probs else (None, None)
82
- if best_artist_tag: # Check if best_artist_tag is not None
83
  prompt_tags = [best_artist_tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)")] + prompt_tags
84
 
85
- if not prompt_tags:
86
- return "No tags predicted."
87
- return ", ".join(prompt_tags)
88
 
89
- def format_detailed_output(results_by_cat: dict, all_artist_tags_probs: list) -> str:
90
- """Formats the output as a detailed markdown string with categories and probabilities."""
91
- if not results_by_cat:
92
- return "No tags predicted for this image."
 
 
93
 
94
- if 'artist' not in results_by_cat and all_artist_tags_probs:
95
- best_artist_tag, best_artist_prob = max(all_artist_tags_probs, key=lambda item: item[1]) if all_artist_tags_probs else (None, None)
96
- if best_artist_tag: # Check if best_artist_tag is not None
97
  results_by_cat['artist'] = [(best_artist_tag, best_artist_prob)]
98
 
99
- lines = []
100
- lines.append("**Predicted Tags by Category:** \n")
101
- for cat, tag_list in results_by_cat.items():
102
- tag_list.sort(key=lambda x: x[1], reverse=True)
103
- lines.append(f"**Category: {cat}** – {len(tag_list)} tags")
104
- for tag, prob in tag_list:
105
- tag_pretty = tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)")
106
- lines.append(f"- {tag_pretty} (Prob: {prob:.3f})")
107
- lines.append("")
108
- return "\n".join(lines)
109
-
110
- # --- Inference Function ---
111
- def tag_image(pil_image: Image.Image, output_format: str, session: ort.InferenceSession, metadata: dict) -> str:
112
- """Tags the image and formats the output based on the selected format."""
113
- if pil_image is None:
114
- return "Please upload an image."
115
-
116
- input_tensor = preprocess_image(pil_image)
117
- input_name = session.get_inputs()[0].name
118
- initial_logits, refined_logits = session.run(None, {input_name: input_tensor})
119
- probs = apply_sigmoid(refined_logits)[0] # Apply sigmoid and get probabilities for the first (and only) image in batch
120
-
121
- results_by_cat, all_artist_tags_probs = extract_tags_from_probabilities(probs, metadata)
122
-
123
- if output_format == "Prompt-style Tags":
124
- return format_prompt_style_output(results_by_cat, all_artist_tags_probs)
125
- else: # Detailed Output
126
- return format_detailed_output(results_by_cat, all_artist_tags_probs)
127
-
128
-
129
- # --- Gradio UI ---
130
- def create_gradio_interface(session: ort.InferenceSession, metadata: dict) -> gr.Blocks:
131
- """Creates the Gradio Blocks interface."""
132
- demo = gr.Blocks(theme="gradio/soft")
133
-
134
- with demo:
135
- gr.Markdown("# 🏷️ Camie Tagger – Anime Image Tagging\nThis demo uses an ONNX model of Camie Tagger to label anime illustrations with tags. Upload an image and click **Tag Image** to see predictions.")
136
- gr.Markdown("*(Note: The model will predict a large number of tags across categories like character, general, artist, etc. You can choose a concise prompt-style output or a detailed category-wise breakdown.)*")
137
-
138
- with gr.Row():
139
- with gr.Column():
140
- image_in = gr.Image(type="pil", label="Input Image")
141
- format_choice = gr.Radio(choices=["Prompt-style Tags", "Detailed Output"], value="Prompt-style Tags", label="Output Format")
142
- tag_button = gr.Button("🔍 Tag Image")
143
- with gr.Column():
144
- output_box = gr.Markdown("")
145
-
146
- # Create a partial function with session and metadata pre-filled
147
- tag_image_with_model = partial(tag_image, session=session, metadata=metadata)
148
-
149
- tag_button.click(
150
- fn=tag_image_with_model, # Use the partially applied function
151
- inputs=[image_in, format_choice],
152
- outputs=output_box,
153
- )
154
-
155
- gr.Markdown("----\n**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime)   •   **Base Model:** Camais03/camie-tagger (61% F1 on 70k tags)   •   **ONNX Runtime:** for efficient CPU inference   •   *Demo built with Gradio Blocks.*")
156
- return demo
157
 
158
- # --- Main Script ---
159
- if __name__ == "__main__":
160
- model_path, meta_path = download_model_and_metadata(MODEL_REPO, MODEL_FILE, META_FILE)
161
- session = load_model_session(model_path)
162
- metadata = load_metadata(meta_path)
163
- demo = create_gradio_interface(session, metadata)
164
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from PIL import Image
5
  import json
6
  from huggingface_hub import hf_hub_download
 
7
 
8
+ # Load model and metadata at startup (same as before)
9
  MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
10
  MODEL_FILE = "camie_tagger_initial.onnx"
11
  META_FILE = "metadata.json"
12
+ model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
13
+ meta_path = hf_hub_download(repo_id=MODEL_REPO, filename=META_FILE, cache_dir=".")
14
+ session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
15
+ metadata = json.load(open(meta_path, "r", encoding="utf-8"))
16
+ # Preprocessing function (same as before)
17
+ def preprocess_image(pil_image: Image.Image) -> np.ndarray:
18
+ img = pil_image.convert("RGB").resize((512, 512))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  arr = np.array(img).astype(np.float32) / 255.0
20
  arr = np.transpose(arr, (2, 0, 1))
21
  arr = np.expand_dims(arr, 0)
22
  return arr
23
 
24
+ # Inference function with output format option
25
+ def tag_image(pil_image: Image.Image, output_format: str) -> str:
26
+ # Run model inference
27
+ input_tensor = preprocess_image(pil_image)
28
+ input_name = session.get_inputs()[0].name
29
+ initial_logits, refined_logits = session.run(None, {input_name: input_tensor})
30
+ probs = 1 / (1 + np.exp(-refined_logits))
31
+ probs = probs[0]
32
  idx_to_tag = metadata["idx_to_tag"]
33
  tag_to_category = metadata.get("tag_to_category", {})
34
  category_thresholds = metadata.get("category_thresholds", {})
35
+ default_threshold = 0.35
36
+ results_by_cat = {} # to store tags per category (for verbose output)
37
+ artist_tags_with_probs = []
38
+ character_tags_with_probs = []
39
+ general_tags_with_probs = []
40
+ all_artist_tags_probs = [] # Store all artist tags and their probabilities
41
+
42
+ # Collect tags above thresholds
43
  for idx, prob in enumerate(probs):
44
  tag = idx_to_tag[str(idx)]
45
  cat = tag_to_category.get(tag, "unknown")
46
  if cat == 'artist':
47
+ all_artist_tags_probs.append((tag, float(prob))) # Store all artist tags
48
+ thresh = category_thresholds.get(cat, default_threshold)
49
  if float(prob) >= thresh:
50
+ # add to category dictionary
51
  results_by_cat.setdefault(cat, []).append((tag, float(prob)))
52
+ if cat == 'artist':
53
+ artist_tags_with_probs.append((tag, float(prob)))
54
+ elif cat == 'character':
55
+ character_tags_with_probs.append((tag, float(prob)))
56
+ elif cat == 'general':
57
+ general_tags_with_probs.append((tag, float(prob)))
58
 
59
+ if output_format == "Prompt-style Tags":
60
+ artist_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
61
+ character_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
62
+ general_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
 
 
 
 
 
 
 
63
 
64
+ artist_prompt_tags = [tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") for tag, prob in artist_tags_with_probs]
65
+ character_prompt_tags = [tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") for tag, prob in character_tags_with_probs]
66
+ general_prompt_tags = [tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") for tag, prob in general_tags_with_probs]
67
 
68
+ prompt_tags = artist_prompt_tags + character_prompt_tags + general_prompt_tags
69
 
70
+ # Ensure at least one artist tag if any artist tags were predicted at all, even below threshold
71
+ if not artist_prompt_tags and all_artist_tags_probs:
72
+ best_artist_tag, best_artist_prob = max(all_artist_tags_probs, key=lambda item: item[1])
73
  prompt_tags = [best_artist_tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)")] + prompt_tags
74
 
 
 
 
75
 
76
+ if not prompt_tags:
77
+ return "No tags predicted."
78
+ return ", ".join(prompt_tags)
79
+ else: # Detailed output
80
+ if not results_by_cat:
81
+ return "No tags predicted for this image."
82
 
83
+ # Ensure artist tag in detailed output even if below threshold
84
+ if 'artist' not in results_by_cat and all_artist_tags_probs:
85
+ best_artist_tag, best_artist_prob = max(all_artist_tags_probs, key=lambda item: item[1])
86
  results_by_cat['artist'] = [(best_artist_tag, best_artist_prob)]
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
+ lines = []
90
+ lines.append("**Predicted Tags by Category:** \n") # (Markdown newline: two spaces + newline)
91
+ for cat, tag_list in results_by_cat.items():
92
+ # sort tags in this category by probability descending
93
+ tag_list.sort(key=lambda x: x[1], reverse=True)
94
+ lines.append(f"**Category: {cat}** – {len(tag_list)} tags")
95
+ for tag, prob in tag_list:
96
+ tag_pretty = tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)") # Escape parentheses here with raw string
97
+ lines.append(f"- {tag_pretty} (Prob: {prob:.3f})")
98
+ lines.append("") # blank line between categories
99
+ return "\n".join(lines)
100
+
101
+ # Build the Gradio Blocks UI
102
+ demo = gr.Blocks(theme="gradio/soft") # using a built-in theme for nicer styling
103
+
104
+ with demo:
105
+ # Header Section
106
+ gr.Markdown("# 🏷️ Camie Tagger – Anime Image Tagging\nThis demo uses an ONNX model of Camie Tagger to label anime illustrations with tags. Upload an image and click **Tag Image** to see predictions.")
107
+ gr.Markdown("*(Note: The model will predict a large number of tags across categories like character, general, artist, etc. You can choose a concise prompt-style output or a detailed category-wise breakdown.)*")
108
+ # Input/Output Section
109
+ with gr.Row():
110
+ # Left column: Image input and format selection
111
+ with gr.Column():
112
+ image_in = gr.Image(type="pil", label="Input Image")
113
+ format_choice = gr.Radio(choices=["Prompt-style Tags", "Detailed Output"], value="Prompt-style Tags", label="Output Format")
114
+ tag_button = gr.Button("🔍 Tag Image")
115
+ # Right column: Output display
116
+ with gr.Column():
117
+ output_box = gr.Markdown("") # will display the result in Markdown (supports bold, lists, etc.)
118
+ # Link the button click to the function
119
+ tag_button.click(fn=tag_image, inputs=[image_in, format_choice], outputs=output_box)
120
+ # Footer/Info
121
+ gr.Markdown("----\n**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime) • **Base Model:** Camais03/camie-tagger (61% F1 on 70k tags) • **ONNX Runtime:** for efficient CPU inference​:contentReference[oaicite:6]{index=6} • *Demo built with Gradio Blocks.*")
122
+
123
+ # Launch the app (automatically handled in Spaces)
124
+ demo.launch()