import onnxruntime as ort
import numpy as np
import json
from PIL import Image


# Load the ONNX tagger model once at module import time so every call to
# inference() reuses the same session; execution is restricted to CPU.
# NOTE(review): assumes "camie_tagger_initial_v15.onnx" is in the working
# directory — loading fails at import otherwise.
session = ort.InferenceSession("camie_tagger_initial_v15.onnx", providers=["CPUExecutionProvider"])
def preprocess_image(img_path):
    """
    Load an image from disk and turn it into the model's input tensor.

    The image is forced to RGB, resized to 512x512, scaled into [0, 1] as
    float32, and rearranged from HWC to NCHW, yielding a (1, 3, 512, 512)
    NumPy array.
    """
    rgb = Image.open(img_path).convert("RGB")
    resized = rgb.resize((512, 512))
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    # HWC -> CHW, then prepend the batch axis.
    return pixels.transpose(2, 0, 1)[np.newaxis, ...]
def load_thresholds(threshold_json_path, mode="balanced"):
    """
    Read per-category tag thresholds for the requested mode from a JSON file.

    Args:
        threshold_json_path: Path to the thresholds JSON file.
        mode: Preset name, e.g. 'balanced', 'high_precision', 'high_recall'.

    Returns:
        tuple: (thresholds_by_category, fallback_threshold), where the dict
        maps category names (e.g. "general", "character") to float cutoffs
        and the fallback is the overall threshold used when a category has
        no entry for the chosen mode.
    """
    with open(threshold_json_path, "r", encoding="utf-8") as fh:
        config = json.load(fh)

    fallback_threshold = config["overall"][mode]["threshold"]

    # Categories missing the chosen mode (or its threshold) inherit the
    # overall fallback; an absent "categories" section yields an empty dict.
    thresholds_by_category = {
        cat_name: cat_modes.get(mode, {}).get("threshold", fallback_threshold)
        for cat_name, cat_modes in config.get("categories", {}).items()
    }
    return thresholds_by_category, fallback_threshold
def _sigmoid(logits):
    """Element-wise logistic sigmoid, mapping logits to probabilities in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-logits))


def _collect_tags_by_category(probs, idx_to_tag, tag_to_category,
                              thresholds_by_category, fallback_threshold):
    """
    Group tags whose probability meets their category's threshold.

    Args:
        probs: (1, num_tags) array of per-tag probabilities.
        idx_to_tag (dict): Maps stringified tag index -> tag name.
        tag_to_category (dict): Maps tag name -> category; unknown tags
            default to "general".
        thresholds_by_category (dict): Per-category probability cutoffs.
        fallback_threshold (float): Cutoff for categories with no entry.

    Returns:
        dict: category name -> list of (tag_name, probability) pairs.
    """
    results_by_category = {}
    num_tags = probs.shape[1]
    for i in range(num_tags):
        prob = float(probs[0, i])
        tag_name = idx_to_tag[str(i)]
        category = tag_to_category.get(tag_name, "general")
        cat_threshold = thresholds_by_category.get(category, fallback_threshold)
        if prob >= cat_threshold:
            results_by_category.setdefault(category, []).append((tag_name, prob))
    return results_by_category


def _format_as_prompt(results_by_category):
    """Join every predicted tag (underscores -> spaces) into one comma-separated string."""
    all_predicted_tags = [
        tname.replace("_", " ")
        for tags_list in results_by_category.values()
        for tname, _tprob in tags_list
    ]
    return ", ".join(all_predicted_tags)


def _format_verbose(results_by_category):
    """Render a human-readable per-category report, tags sorted by descending probability."""
    lines = ["Predicted Tags by Category:\n"]
    for cat, tags_list in results_by_category.items():
        lines.append(f"Category: {cat} | Predicted {len(tags_list)} tags")
        for tname, tprob in sorted(tags_list, key=lambda x: x[1], reverse=True):
            lines.append(f"  Tag: {tname:30s} Prob: {tprob:.4f}")
        lines.append("")
    return "\n".join(lines)


def inference(
    input_path,
    output_format="verbose",
    mode="balanced",
    threshold_json_path="thresholds.json",
    metadata_path="metadata.json"
):
    """
    Run inference on an image using the loaded ONNX model, then apply
    category-wise thresholds from `threshold_json_path` for the chosen mode.

    Arguments:
        input_path (str)          : Path to the image file for inference.
        output_format (str)       : Either "verbose" or "as_prompt".
        mode (str)                : "balanced", "high_precision", or "high_recall".
        threshold_json_path (str) : Path to the JSON file with category thresholds.
        metadata_path (str)       : Path to the metadata JSON file with category info.

    Returns:
        str: The predicted tags in either verbose or comma-separated format.
    """
    input_tensor = preprocess_image(input_path)

    # The model emits two heads (initial_logits, refined_logits); only the
    # refined head is used for the final predictions.
    input_name = session.get_inputs()[0].name
    initial_logits, refined_logits = session.run(None, {input_name: input_tensor})
    refined_probs = _sigmoid(refined_logits)

    with open(metadata_path, "r", encoding="utf-8") as f:
        metadata = json.load(f)
    idx_to_tag = metadata["idx_to_tag"]
    # tag_to_category may be absent from the metadata; every tag then falls
    # back to the "general" category.
    tag_to_category = metadata.get("tag_to_category", {})

    thresholds_by_category, fallback_threshold = load_thresholds(threshold_json_path, mode)

    results_by_category = _collect_tags_by_category(
        refined_probs, idx_to_tag, tag_to_category,
        thresholds_by_category, fallback_threshold,
    )

    if output_format == "as_prompt":
        return _format_as_prompt(results_by_category)
    # Any other value (including the default "verbose") gets the report form.
    return _format_verbose(results_by_category)
if __name__ == "__main__":
    import sys

    # The original hard-coded an empty image path, which always failed at
    # Image.open(); take the path (and optional output format) from argv.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <image_path> [as_prompt|verbose]")
    image_path = sys.argv[1]
    fmt = sys.argv[2] if len(sys.argv) > 2 else "as_prompt"
    result = inference(image_path, output_format=fmt)
    print(result)