AngelBottomless's picture
Upload 9 files
a7ab59e verified
raw
history blame contribute delete
5.83 kB
import onnxruntime as ort
import numpy as np
import json
from PIL import Image
# 1) Load the ONNX model once, at import time, restricted to the CPU execution provider.
session = ort.InferenceSession(
    "camie_tagger_initial_v15.onnx",
    providers=["CPUExecutionProvider"],
)
# 2) Image preprocessing (RGB, 512x512, float32 in [0, 1], NCHW)
def preprocess_image(img_path):
    """
    Load an image file and convert it into the model's input tensor.

    The image is converted to RGB, resized to 512x512, scaled to float32
    values in [0, 1], and returned in NCHW layout with a batch dimension.

    Arguments:
        img_path (str | os.PathLike): Path to the image file.

    Returns:
        np.ndarray: float32 array of shape (1, 3, 512, 512).
    """
    # Use a context manager so the underlying file is always released:
    # Pillow only auto-closes it after loading for single-frame images.
    with Image.open(img_path) as img:
        # convert() loads the pixels and returns an independent image, so it
        # stays valid after the source file is closed.
        rgb = img.convert("RGB").resize((512, 512))
    x = np.asarray(rgb, dtype=np.float32) / 255.0
    x = np.transpose(x, (2, 0, 1))  # HWC -> CHW
    return np.expand_dims(x, 0)  # add batch dimension -> (1, 3, 512, 512)
# 3) Per-category threshold loading
def load_thresholds(threshold_json_path, mode="balanced"):
    """
    Load per-category thresholds from a JSON file for one threshold mode.

    The file must contain an "overall" section mapping each mode name
    (e.g. 'balanced', 'high_precision', 'high_recall') to an object with a
    "threshold" value. An optional "categories" section maps each category
    name to the same per-mode structure.

    Arguments:
        threshold_json_path (str): Path to the thresholds JSON file.
        mode (str): Mode whose thresholds are used. Defaults to "balanced".

    Returns:
        thresholds_by_category (dict): e.g. { "general": 0.328..., "character": 0.304..., ... }
            A category with no entry for ``mode`` (or no "threshold" inside
            that entry) is assigned ``fallback_threshold``.
        fallback_threshold (float): The "overall" threshold for ``mode``;
            intended for categories that do not appear in the file at all.

    Raises:
        KeyError: If the file has no "overall" threshold for ``mode``. The
            message names the file and lists the modes that are available.
    """
    with open(threshold_json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # The "overall" threshold for the chosen mode doubles as the fallback.
    try:
        fallback_threshold = data["overall"][mode]["threshold"]
    except KeyError as err:
        # Keep KeyError (callers may catch it) but make the failure actionable.
        available = sorted(data.get("overall", {}))
        raise KeyError(
            f"no 'overall' threshold for mode {mode!r} in "
            f"{threshold_json_path!r}; available modes: {available}"
        ) from err

    # Use the category's own threshold for this mode when present,
    # otherwise fall back to the "overall" threshold.
    thresholds_by_category = {
        cat_name: (
            cat_modes[mode]["threshold"]
            if mode in cat_modes and "threshold" in cat_modes[mode]
            else fallback_threshold
        )
        for cat_name, cat_modes in data.get("categories", {}).items()
    }
    return thresholds_by_category, fallback_threshold
def inference(
    input_path,
    output_format="verbose",
    mode="balanced",
    threshold_json_path="thresholds.json",
    metadata_path="metadata.json"
):
    """
    Tag an image with the loaded ONNX model and format the predicted tags.

    Runs the module-level ``session`` on the preprocessed image, converts the
    refined logits to probabilities with a sigmoid, and keeps every tag whose
    probability is at least the threshold of its category, as loaded from
    ``threshold_json_path`` for the chosen ``mode``.

    Arguments:
        input_path (str)          : Path to the image file for inference.
        output_format (str)       : "as_prompt" for a comma-separated tag list;
                                    any other value (e.g. "verbose") produces
                                    the verbose per-category report.
        mode (str)                : Threshold mode, e.g. "balanced",
                                    "high_precision", or "high_recall".
        threshold_json_path (str) : Path to the JSON file with category thresholds.
        metadata_path (str)       : Path to the metadata JSON file. It must have
                                    "idx_to_tag" (string index -> tag name) and
                                    may have "tag_to_category".

    Returns:
        str: The predicted tags, either as a multi-line report grouped by
        category or as a comma-separated prompt string.
    """
    # 1) Preprocess
    input_tensor = preprocess_image(input_path)

    # 2) Run inference. The model returns two outputs (initial and refined
    #    logits, each (batch, num_tags)); only the refined head is used.
    input_name = session.get_inputs()[0].name
    _initial_logits, refined_logits = session.run(None, {input_name: input_tensor})

    # 3) Logits -> probabilities. Very negative logits overflow exp() to inf,
    #    which still gives the correct probability of 0, so the warning is noise.
    with np.errstate(over="ignore"):
        refined_probs = 1 / (1 + np.exp(-refined_logits))

    # 4) Load metadata & per-category thresholds
    with open(metadata_path, "r", encoding="utf-8") as f:
        metadata = json.load(f)
    idx_to_tag = metadata["idx_to_tag"]  # e.g. { "0": "brown_hair", "1": "blue_eyes", ... }
    tag_to_category = metadata.get("tag_to_category", {})
    thresholds_by_category, fallback_threshold = load_thresholds(threshold_json_path, mode)

    # 5) Collect predictions (first image in the batch) by category
    results_by_category = _collect_predictions(
        refined_probs[0],
        idx_to_tag,
        tag_to_category,
        thresholds_by_category,
        fallback_threshold,
    )

    # 6) Format according to output_format
    if output_format == "as_prompt":
        return _format_as_prompt(results_by_category)
    return _format_verbose(results_by_category)


def _collect_predictions(probs, idx_to_tag, tag_to_category,
                         thresholds_by_category, fallback_threshold):
    """
    Group tags whose probability meets their category's threshold.

    Arguments:
        probs: 1-D sequence of per-tag probabilities for a single image.
        idx_to_tag (dict): Maps str(tag index) -> tag name.
        tag_to_category (dict): Maps tag name -> category; unmapped tags
            are treated as "general".
        thresholds_by_category (dict): Maps category -> threshold.
        fallback_threshold (float): Threshold for categories not in
            ``thresholds_by_category``.

    Returns:
        dict: category -> list of (tag_name, prob). Categories appear in the
        order of their first hit, and tags in index order.
    """
    results_by_category = {}
    for i, p in enumerate(probs):
        prob = float(p)
        tag_name = idx_to_tag[str(i)]  # metadata uses string keys
        category = tag_to_category.get(tag_name, "general")
        if prob >= thresholds_by_category.get(category, fallback_threshold):
            results_by_category.setdefault(category, []).append((tag_name, prob))
    return results_by_category


def _format_as_prompt(results_by_category):
    """Flatten all predicted tags into one comma-separated string, with underscores turned into spaces."""
    return ", ".join(
        tag_name.replace("_", " ")
        for tags in results_by_category.values()
        for tag_name, _prob in tags
    )


def _format_verbose(results_by_category):
    """Render one section per category, with tags sorted by descending probability."""
    lines = ["Predicted Tags by Category:\n"]
    for category, tags in results_by_category.items():
        lines.append(f"Category: {category} | Predicted {len(tags)} tags")
        for tag_name, prob in sorted(tags, key=lambda t: t[1], reverse=True):
            lines.append(f" Tag: {tag_name:30s} Prob: {prob:.4f}")
        lines.append("")  # blank line after each category
    return "\n".join(lines)
if __name__ == "__main__":
    # Command-line entry point. The image path used to be hard-coded as "",
    # which can never be opened; take it (and the options) from argv instead.
    import argparse

    parser = argparse.ArgumentParser(
        description="Predict tags for an image with the ONNX tagger model."
    )
    parser.add_argument("image", help="Path to the image file to tag.")
    parser.add_argument(
        "--format",
        dest="output_format",
        choices=["as_prompt", "verbose"],
        default="as_prompt",
        help="Output style (default: as_prompt).",
    )
    parser.add_argument(
        "--mode",
        default="balanced",
        help="Threshold mode from the thresholds file, e.g. balanced, "
             "high_precision, high_recall (default: balanced).",
    )
    parser.add_argument(
        "--thresholds",
        default="thresholds.json",
        help="Path to the thresholds JSON file (default: thresholds.json).",
    )
    parser.add_argument(
        "--metadata",
        default="metadata.json",
        help="Path to the metadata JSON file (default: metadata.json).",
    )
    args = parser.parse_args()

    result = inference(
        args.image,
        output_format=args.output_format,
        mode=args.mode,
        threshold_json_path=args.thresholds,
        metadata_path=args.metadata,
    )
    print(result)