AngelBottomless's picture
Create app.py
f7165bd verified
raw
history blame contribute delete
3.49 kB
import json
from pathlib import Path

import gradio as gr
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image
# Load the ONNX model and metadata once at startup so every request reuses them.
MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
MODEL_FILE = "camie_tagger_initial.onnx"  # using the smaller initial model for speed
META_FILE = "metadata.json"

# Download model and metadata from the HF Hub (cache_dir="." caches them in the
# working directory of the Space).
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
meta_path = hf_hub_download(repo_id=MODEL_REPO, filename=META_FILE, cache_dir=".")

# CPU-only inference session shared by all requests.
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

# Read the metadata through a context manager so the file handle is closed
# as soon as parsing finishes instead of lingering until garbage collection.
with open(meta_path, "r", encoding="utf-8") as meta_file:
    metadata = json.load(meta_file)
# Preprocessing: turn an image into the batched float32 CHW tensor passed to the model.
def preprocess_image(pil_image: Image.Image, size: int = 512) -> np.ndarray:
    """Convert an image into a scaled, batched array.

    Args:
        pil_image: Source image. It is converted to RGB before resizing.
        size: Edge length, in pixels, of the square the image is resized to.
            Defaults to 512, the value that used to be hard-coded.

    Returns:
        A float32 array of shape ``(1, 3, size, size)`` with values in [0, 1].

    Note:
        The image is resized straight to ``size x size``; the aspect ratio of
        the input is not preserved.
    """
    resized = pil_image.convert("RGB").resize((size, size))
    # Scale 8-bit channel values to [0, 1] as float32.
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    # HWC -> CHW, then prepend the batch axis -> (1, 3, size, size).
    return np.expand_dims(np.transpose(pixels, (2, 0, 1)), 0)
# Inference: run the ONNX model and collect the tags whose probability clears their threshold.
def predict_tags(pil_image: Image.Image) -> str:
    """Predict tags for an image and return them as one display string.

    Args:
        pil_image: Image to tag. ``None`` (submitted form without an upload)
            is tolerated and short-circuits with a prompt message.

    Returns:
        The predicted tags, sorted alphabetically and joined with ``", "``,
        with underscores replaced by spaces; ``"No tags found."`` when no tag
        clears its threshold; ``"Please upload an image."`` when
        ``pil_image`` is ``None``.

    Raises:
        KeyError: If a tag index that clears the threshold has no entry in
            ``metadata["idx_to_tag"]``.
    """
    if pil_image is None:
        return "Please upload an image."

    # 1. Preprocess the image into the model input array.
    input_tensor = preprocess_image(pil_image)

    # 2. Run the model. When the export provides two outputs they are
    #    (initial_logits, refined_logits) and the refined ones are used; an
    #    export with a single output is used as-is instead of failing on
    #    tuple unpacking.
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: input_tensor})
    logits = outputs[1] if len(outputs) > 1 else outputs[0]

    # 3. Sigmoid (multi-label), then drop the batch dimension. Overflow in
    #    exp() for very negative logits still yields the correct 0.0, so only
    #    the warning is silenced.
    with np.errstate(over="ignore"):
        probs = 1.0 / (1.0 + np.exp(-np.asarray(logits)))
    probs = probs[0]

    # 4. Thresholding: category-specific threshold when available, else default.
    idx_to_tag = metadata["idx_to_tag"]  # index (as string) -> tag
    tag_to_category = metadata.get("tag_to_category", {})  # tag -> category
    category_thresholds = metadata.get("category_thresholds", {})
    default_threshold = 0.325

    # Nothing below the smallest threshold in play can be selected, so let
    # NumPy discard those indices and run the Python loop on the rest only.
    lowest_threshold = min([default_threshold, *category_thresholds.values()])
    predicted_tags = []
    for idx in np.flatnonzero(probs >= lowest_threshold):
        tag = idx_to_tag[str(idx)]
        category = tag_to_category.get(tag, "unknown")
        if probs[idx] >= category_thresholds.get(category, default_threshold):
            # Replace underscores with spaces for readability.
            predicted_tags.append(tag.replace("_", " "))

    # 5. Return the tags as a comma-separated, alphabetically sorted string.
    if not predicted_tags:
        return "No tags found."
    return ", ".join(sorted(predicted_tags))
# Create a simple Gradio interface.
# The example filenames are placeholders: only offer the ones that are really
# present next to the app, since pointing the interface at a missing file can
# break example loading. With none present, no examples section is requested.
EXAMPLE_FILES = ("example1.jpg", "example2.png")
available_examples = [[name] for name in EXAMPLE_FILES if Path(name).is_file()]

demo = gr.Interface(
    fn=predict_tags,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Textbox(label="Predicted Tags", lines=3),
    title="Camie Tagger (ONNX) – Simple Demo",
    description="Upload an anime/manga illustration to get relevant tags predicted by the Camie Tagger model.",
    examples=available_examples or None,
)
# Start the Gradio server for the interface.
# NOTE(review): whether HF Spaces would start the app without this call is not
# verified here — keep the explicit launch unless that is confirmed.
demo.launch()