# Source: Hugging Face Space file by AngelBottomless
# (verified commit 7ec5b17, "Thanks to GPT"; 5.33 kB)
import json
import os

import gradio as gr
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image
# Load the ONNX model and its tag metadata once at startup.
# Both files are fetched with hf_hub_download and cached in the current directory.
MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
MODEL_FILE = "camie_tagger_initial.onnx"
META_FILE = "metadata.json"
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
meta_path = hf_hub_download(repo_id=MODEL_REPO, filename=META_FILE, cache_dir=".")
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
# Use a context manager so the metadata file handle is closed deterministically
# instead of being left for the garbage collector.
with open(meta_path, "r", encoding="utf-8") as meta_file:
    metadata = json.load(meta_file)
# Preprocessing: PIL image -> float32 model input with a leading batch axis
def preprocess_image(pil_image: Image.Image) -> np.ndarray:
    """Convert a PIL image into the tensor fed to the ONNX session.

    The image is converted to RGB and resized to exactly 512x512 (a stretch,
    not a crop). Pixel values are divided by 255, the layout is reordered
    from HWC to CHW, and a batch axis of size 1 is prepended, giving an
    array of shape (1, 3, 512, 512).
    """
    resized = pil_image.convert("RGB").resize((512, 512))
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    channels_first = pixels.transpose(2, 0, 1)
    return channels_first[np.newaxis, ...]
# Inference function with output format option
_DEFAULT_THRESHOLD = 0.325  # fallback for categories absent from metadata["category_thresholds"]


def _sigmoid(logits: np.ndarray) -> np.ndarray:
    """Map raw logits to probabilities in [0, 1]."""
    # np.exp(-x) overflows to inf for very negative logits. 1 / (1 + inf) is
    # exactly the correct limit (0.0), so only the RuntimeWarning is silenced.
    with np.errstate(over="ignore"):
        return 1.0 / (1.0 + np.exp(-logits))


def _collect_tags(probs: np.ndarray) -> list:
    """Return ``(tag, prob, category)`` for every tag at or above its category threshold.

    ``probs[i]`` is the probability of ``metadata["idx_to_tag"][str(i)]``.
    Results are returned in model-index order.
    """
    idx_to_tag = metadata["idx_to_tag"]
    tag_to_category = metadata.get("tag_to_category", {})
    category_thresholds = metadata.get("category_thresholds", {})
    hits = []
    # tolist() yields plain Python floats, avoiding a NumPy scalar per element.
    for idx, prob in enumerate(probs.tolist()):
        tag = idx_to_tag[str(idx)]
        cat = tag_to_category.get(tag, "unknown")
        if prob >= category_thresholds.get(cat, _DEFAULT_THRESHOLD):
            hits.append((tag, prob, cat))
    return hits


def _format_prompt(hits: list) -> str:
    """Render hits as a comma-separated prompt, most confident tag first."""
    if not hits:
        return "No tags predicted."
    # Sort on the stored probability directly. sorted() is stable (also with
    # reverse=True), so equally probable tags keep model-index order.
    ranked = sorted(hits, key=lambda hit: hit[1], reverse=True)
    return ", ".join(tag.replace("_", " ") for tag, _prob, _cat in ranked)


def _format_detailed(hits: list) -> str:
    """Render hits as Markdown grouped by category, each group sorted by probability."""
    if not hits:
        return "No tags predicted for this image."
    by_cat = {}
    for tag, prob, cat in hits:
        by_cat.setdefault(cat, []).append((tag, prob))
    lines = ["**Predicted Tags by Category:** \n"]
    for cat, tag_list in by_cat.items():  # categories in order of first appearance
        tag_list.sort(key=lambda pair: pair[1], reverse=True)
        lines.append(f"**Category: {cat}** – {len(tag_list)} tags")
        lines.extend(f"- {tag.replace('_', ' ')} (Prob: {prob:.3f})" for tag, prob in tag_list)
        lines.append("")  # blank line between categories
    return "\n".join(lines)


def tag_image(pil_image: Image.Image, output_format: str = "Prompt-style Tags") -> str:
    """Tag an image with the Camie Tagger ONNX model.

    Args:
        pil_image: Image to tag. If it is ``None`` (nothing uploaded yet), a
            short message asking for an image is returned instead of crashing.
        output_format: ``"Prompt-style Tags"`` gives a comma-separated tag
            string. Any other value gives a Markdown listing grouped by
            category. Defaults to the prompt style, so callers that pass only
            the image still work.

    Returns:
        The formatted tags, or a short message when no tag passes its threshold.
    """
    if pil_image is None:
        return "Please upload an image first."
    input_tensor = preprocess_image(pil_image)
    input_name = session.get_inputs()[0].name
    # The session returns two outputs; only the second (refined) logits are used.
    _initial_logits, refined_logits = session.run(None, {input_name: input_tensor})
    probs = _sigmoid(refined_logits[0])  # drop the batch axis added in preprocessing
    hits = _collect_tags(probs)
    if output_format == "Prompt-style Tags":
        return _format_prompt(hits)
    return _format_detailed(hits)
# Build the Gradio Blocks UI
EXAMPLE_FILES = ["example1.jpg", "example2.png"]  # optional sample images next to this script

with gr.Blocks(theme=gr.themes.Soft()) as demo:  # built-in theme for nicer styling
    # Header section
    gr.Markdown("# 🏷️ Camie Tagger – Anime Image Tagging\nThis demo uses an ONNX model of Camie Tagger to label anime illustrations with tags. Upload an image and click **Tag Image** to see predictions.")
    gr.Markdown("*(Note: The model will predict a large number of tags across categories like character, general, artist, etc. You can choose a concise prompt-style output or a detailed category-wise breakdown.)*")
    # Input/output section
    with gr.Row():
        # Left column: image input and output-format selection
        with gr.Column():
            image_in = gr.Image(type="pil", label="Input Image")
            format_choice = gr.Radio(choices=["Prompt-style Tags", "Detailed Output"], value="Prompt-style Tags", label="Output Format")
            tag_button = gr.Button("🔍 Tag Image")
        # Right column: result rendered as Markdown (bold, lists, etc.)
        with gr.Column():
            output_box = gr.Markdown("")
    # Example images. Only files that actually exist are listed, because with
    # cache_examples=True Gradio runs tag_image on each example to cache its
    # output, and a missing file would break startup. Each example row also
    # supplies the output format, since tag_image takes (image, output_format).
    available_examples = [
        [path, "Prompt-style Tags"] for path in EXAMPLE_FILES if os.path.exists(path)
    ]
    if available_examples:
        gr.Examples(
            examples=available_examples,
            inputs=[image_in, format_choice],
            outputs=output_box,
            fn=tag_image,
            cache_examples=True,
        )
    # Link the button click to the tagging function
    tag_button.click(fn=tag_image, inputs=[image_in, format_choice], outputs=output_box)
    # Footer / info
    gr.Markdown("----\n**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime) • **Base Model:** Camais03/camie-tagger (61% F1 on 70k tags) • **ONNX Runtime:** for efficient CPU inference • *Demo built with Gradio Blocks.*")

# Launch the app (automatically handled in Spaces)
demo.launch()