import json

import gradio as gr
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image

# Constants
MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
MODEL_FILE = "camie_tagger_initial.onnx"
META_FILE = "metadata.json"
IMAGE_SIZE = (512, 512)
DEFAULT_THRESHOLD = 0.35  # Initial value of the "Default Threshold" slider

# Download model and metadata from Hugging Face Hub (cached in the working directory)
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
meta_path = hf_hub_download(repo_id=MODEL_REPO, filename=META_FILE, cache_dir=".")

# Initialize ONNX Runtime session and load metadata
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
# The input name is fixed for the session, so look it up once rather than per request.
_INPUT_NAME = session.get_inputs()[0].name

with open(meta_path, "r", encoding="utf-8") as f:
    metadata = json.load(f)

# Character substitutions applied by escape_tag.
_TAG_ESCAPES = str.maketrans({"_": " ", "(": r"\\(", ")": r"\\)"})


def escape_tag(tag: str) -> str:
    r"""Turn a raw tag name into display text for the Markdown output.

    Underscores become spaces and every parenthesis is prefixed with two
    backslashes (``(`` -> ``\\(``, ``)`` -> ``\\)``). When rendered as
    Markdown, ``\\`` displays as one backslash, so readers see ``\(``.
    NOTE(review): presumably intended so tags can be pasted into prompt
    syntax that uses ``\(`` escapes - confirm before changing.
    Other Markdown metacharacters (``*``, ``[``, ...) are not escaped.
    """
    return tag.translate(_TAG_ESCAPES)


def preprocess_image(pil_image: Image.Image) -> np.ndarray:
    """Convert a PIL image into the model's input tensor.

    The image is converted to RGB, resized to IMAGE_SIZE, scaled to [0, 1]
    as float32, transposed from (H, W, 3) to (3, H, W), and given a leading
    batch axis, yielding shape (1, 3, H, W).
    """
    img = pil_image.convert("RGB").resize(IMAGE_SIZE)
    arr = np.asarray(img, dtype=np.float32) / 255.0
    arr = np.transpose(arr, (2, 0, 1))
    return np.expand_dims(arr, 0)


def run_inference(pil_image: Image.Image) -> np.ndarray:
    """Preprocess the image, run the ONNX model and return the refined logits.

    The session is expected to produce exactly two outputs (initial and
    refined logits); only the refined logits of the single batch item are
    returned.
    """
    input_tensor = preprocess_image(pil_image)
    # Only refined_logits are used (initial_logits is ignored)
    _, refined_logits = session.run(None, {_INPUT_NAME: input_tensor})
    return refined_logits[0]


def _sigmoid(logits: np.ndarray) -> np.ndarray:
    """Numerically stable logistic function: never exponentiates a positive value."""
    z = np.exp(-np.abs(logits))
    return np.where(logits >= 0, 1 / (1 + z), z / (1 + z))


def mcut_threshold(probs: np.ndarray) -> float:
    """Compute the MCut threshold from the given probabilities.

    The probabilities are sorted in descending order, the largest gap between
    consecutive values is located, and the midpoint of that gap is returned.

    Uses the MCut method described in:
    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding
    Strategy for Multi-label Classification.

    With fewer than two probabilities there is no gap to split: a single
    value is returned as-is, and empty input yields 0.0.
    """
    sorted_probs = np.sort(np.asarray(probs).ravel())[::-1]
    if sorted_probs.size == 0:
        return 0.0
    if sorted_probs.size == 1:
        return float(sorted_probs[0])
    diffs = sorted_probs[:-1] - sorted_probs[1:]
    t = int(diffs.argmax())
    return float((sorted_probs[t] + sorted_probs[t + 1]) / 2)


def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: float):
    """Compute probabilities from logits and collect tag predictions.

    Each tag index is mapped to a tag via ``metadata["idx_to_tag"]`` (keyed by
    the index as a string) and to a category via
    ``metadata["tag_to_category"]`` ("unknown" if absent). A tag is kept when
    its probability is >= the category's threshold from
    ``metadata["category_thresholds"]``, falling back to ``default_threshold``
    for categories without an entry.

    Returns:
        results_by_cat: Dictionary mapping each category to a list of
            (tag, probability) above its threshold.
        prompt_tags_by_cat: Dictionary for prompt-style output (character and
            general tags above threshold).
        all_artist_tags: All artist tags (with probabilities) regardless of
            threshold.
    """
    probs = _sigmoid(refined_logits)
    idx_to_tag = metadata["idx_to_tag"]
    tag_to_category = metadata.get("tag_to_category", {})
    category_thresholds = metadata.get("category_thresholds", {})

    results_by_cat = {}
    # For prompt-style output, only include character and general tags (artists handled separately)
    prompt_tags_by_cat = {"character": [], "general": []}
    all_artist_tags = []

    for idx, raw_prob in enumerate(probs):
        prob = float(raw_prob)
        tag = idx_to_tag[str(idx)]
        cat = tag_to_category.get(tag, "unknown")
        thresh = category_thresholds.get(cat, default_threshold)
        if cat == "artist":
            all_artist_tags.append((tag, prob))
        if prob >= thresh:
            results_by_cat.setdefault(cat, []).append((tag, prob))
            if cat in prompt_tags_by_cat:
                prompt_tags_by_cat[cat].append((tag, prob))
    return results_by_cat, prompt_tags_by_cat, all_artist_tags


def format_prompt_tags(prompt_tags_by_cat: dict, all_artist_tags: list) -> str:
    """Format the tags for prompt-style output.

    Only the top artist tag is shown (regardless of threshold), followed by
    all character tags and then all general tags, each group sorted by
    probability (descending). The inputs are not modified.

    Returns a comma-separated string of escaped tags, or "No tags predicted."
    when there is nothing to show.
    """
    prompt_tags = []
    # Always select the best artist tag from all_artist_tags, regardless of threshold.
    if all_artist_tags:
        best_artist, _ = max(all_artist_tags, key=lambda item: item[1])
        best_artist_tag = escape_tag(best_artist)
        if best_artist_tag:
            prompt_tags.append(best_artist_tag)
    for cat in ("character", "general"):
        ranked = sorted(prompt_tags_by_cat.get(cat, []), key=lambda x: x[1], reverse=True)
        prompt_tags.extend(escape_tag(tag) for tag, _ in ranked)
    return ", ".join(prompt_tags) if prompt_tags else "No tags predicted."


def format_detailed_output(results_by_cat: dict, all_artist_tags: list) -> str:
    """Format the tags for detailed output.

    Lists every category with its tags sorted by probability (descending).
    If no artist tag cleared its threshold, the single most probable artist
    tag is still included under "artist". The inputs are not modified.

    Returns a Markdown-formatted string listing tags by category.
    """
    by_cat = dict(results_by_cat)  # shallow copy: never mutate the caller's dict
    # Include an artist tag even if below threshold. This must happen before the
    # emptiness check so the fallback also applies when nothing else passed.
    if "artist" not in by_cat and all_artist_tags:
        by_cat["artist"] = [max(all_artist_tags, key=lambda item: item[1])]
    if not by_cat:
        return "No tags predicted for this image."

    lines = ["**Predicted Tags by Category:** \n"]
    for cat, tag_list in by_cat.items():
        ranked = sorted(tag_list, key=lambda x: x[1], reverse=True)
        lines.append(f"**Category: {cat}** – {len(ranked)} tags")
        lines.extend(f"- {escape_tag(tag)} (Prob: {prob:.3f})" for tag, prob in ranked)
        lines.append("")  # blank line between categories
    return "\n".join(lines)


def tag_image(pil_image: Image.Image, output_format: str, threshold: float, mcut_enabled: bool) -> str:
    """Run inference on the image and return tags formatted per output_format.

    ``threshold`` (the slider value) is used as the fallback threshold for
    categories that have no entry in ``metadata["category_thresholds"]``;
    per-category thresholds from the metadata take precedence. If
    ``mcut_enabled`` is True, the fallback is instead the MCut threshold
    computed over all tag probabilities.
    """
    if pil_image is None:
        return "Please upload an image."
    refined_logits = run_inference(pil_image)

    # If MCut is enabled, override the slider threshold using the MCut method.
    # Probabilities are only needed here for that case; get_tags computes its own.
    if mcut_enabled:
        computed_threshold = mcut_threshold(_sigmoid(refined_logits))
    else:
        computed_threshold = threshold

    results_by_cat, prompt_tags_by_cat, all_artist_tags = get_tags(
        refined_logits, metadata, default_threshold=computed_threshold
    )
    if output_format == "Prompt-style Tags":
        return format_prompt_tags(prompt_tags_by_cat, all_artist_tags)
    return format_detailed_output(results_by_cat, all_artist_tags)


# Build the Gradio Blocks UI
with gr.Blocks(theme="gradio/soft") as demo:
    gr.Markdown(
        "# 🏷️ Camie Tagger – Anime Image Tagging\n"
        "This demo uses an ONNX model of Camie Tagger to label anime illustrations with tags. "
        "Upload an image, adjust the threshold, and click **Tag Image** to see predictions."
    )
    gr.Markdown(
        "*(Note: In prompt-style output, only the top artist tag is displayed along with all character and general tags. "
        "If MCut is enabled, its computed threshold overrides the default slider value.)*"
    )
    with gr.Row():
        with gr.Column():
            image_in = gr.Image(type="pil", label="Input Image")
            format_choice = gr.Radio(
                choices=["Prompt-style Tags", "Detailed Output"],
                value="Prompt-style Tags",
                label="Output Format",
            )
            threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=DEFAULT_THRESHOLD,
                label="Default Threshold",
            )
            mcut_checkbox = gr.Checkbox(value=False, label="Use MCut threshold")
            tag_button = gr.Button("🔍 Tag Image")
        with gr.Column():
            output_box = gr.Markdown("")  # Markdown output for formatted results

    # Pass the threshold_slider and mcut_checkbox values into the tag_image function
    tag_button.click(
        fn=tag_image,
        inputs=[image_in, format_choice, threshold_slider, mcut_checkbox],
        outputs=output_box,
    )

    gr.Markdown(
        "----\n"
        "**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime) • "
        "**Base Model:** Camais03/camie-tagger (61% F1 on 70k tags) • **ONNX Runtime:** for efficient CPU inference • "
        "*Demo built with Gradio Blocks.*"
    )

if __name__ == "__main__":
    demo.launch()