import onnxruntime as ort
import numpy as np
import json
from PIL import Image
# 1) Load ONNX model
session = ort.InferenceSession("camie_tagger_initial_v15.onnx", providers=["CPUExecutionProvider"])
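# Optional sanity check (not part of the original script): report what the exported
# graph declares for its inputs/outputs, so the preprocessing below can be verified.
# NodeArg objects in onnxruntime's Python API expose .name, .shape and .type.
for _arg in session.get_inputs():
    print(f"Model input : {_arg.name} shape={_arg.shape} type={_arg.type}")
for _arg in session.get_outputs():
    print(f"Model output: {_arg.name} shape={_arg.shape} type={_arg.type}")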
# 2) Preprocess your image (512x512, etc.)
def preprocess_image(img_path):
    """
    Loads and resizes an image to 512x512, converts it to float32 [0..1],
    and returns a (1,3,512,512) NumPy array (NCHW format).
    """
    img = Image.open(img_path).convert("RGB").resize((512, 512))
    x = np.array(img).astype(np.float32) / 255.0
    x = np.transpose(x, (2, 0, 1))  # HWC -> CHW
    x = np.expand_dims(x, 0)        # add batch dimension -> (1,3,512,512)
    return x
# Example input
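# A minimal, commented-out usage sketch (not part of the original script);
# "example.jpg" is a placeholder path:
#   x = preprocess_image("example.jpg")
#   assert x.shape == (1, 3, 512, 512) and x.dtype == np.float32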
def load_thresholds(threshold_json_path, mode="balanced"):
    """
    Loads thresholds from the given JSON file, using a particular mode
    (e.g. 'balanced', 'high_precision', 'high_recall') for each category.
    Returns:
        thresholds_by_category (dict): e.g. { "general": 0.328..., "character": 0.304..., ... }
        fallback_threshold (float): The overall threshold if category not found
    """
    with open(threshold_json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # The fallback threshold from the "overall" section for the chosen mode
    fallback_threshold = data["overall"][mode]["threshold"]
    # Build a dict of thresholds keyed by category
    thresholds_by_category = {}
    if "categories" in data:
        for cat_name, cat_modes in data["categories"].items():
            # If the chosen mode is present for that category, use it;
            # otherwise fall back to the "overall" threshold.
            if mode in cat_modes and "threshold" in cat_modes[mode]:
                thresholds_by_category[cat_name] = cat_modes[mode]["threshold"]
            else:
                thresholds_by_category[cat_name] = fallback_threshold
    return thresholds_by_category, fallback_threshold
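# For reference, the layout load_thresholds() assumes (inferred from the access
# pattern above; the real thresholds.json shipped with the model may carry extra
# fields alongside each threshold, and the numbers below are placeholders):
# {
#   "overall":    { "balanced": { "threshold": 0.33 }, "high_precision": { ... }, ... },
#   "categories": { "general": { "balanced": { "threshold": 0.33 } }, "character": { ... }, ... }
# }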
def inference(
    input_path,
    output_format="verbose",
    mode="balanced",
    threshold_json_path="thresholds.json",
    metadata_path="metadata.json"
):
    """
    Run inference on an image using the loaded ONNX model, then apply
    category-wise thresholds from `thresholds.json` for the chosen mode.
    Arguments:
        input_path (str)          : Path to the image file for inference.
        output_format (str)       : Either "verbose" or "as_prompt".
        mode (str)                : "balanced", "high_precision", or "high_recall".
        threshold_json_path (str) : Path to the JSON file with category thresholds.
        metadata_path (str)       : Path to the metadata JSON file with category info.
    Returns:
        str: The predicted tags in either verbose or comma-separated format.
    """
    # 1) Preprocess
    input_tensor = preprocess_image(input_path)
    # 2) Run inference
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: input_tensor})
    initial_logits, refined_logits = outputs  # shape: (1, 70527) each
    # 3) Convert logits to probabilities (sigmoid); only the refined head is used below
    refined_probs = 1 / (1 + np.exp(-refined_logits))  # shape: (1, 70527)
    # 4) Load metadata & retrieve threshold info
    with open(metadata_path, "r", encoding="utf-8") as f:
        metadata = json.load(f)
    idx_to_tag = metadata["idx_to_tag"]  # e.g. { "0": "brown_hair", "1": "blue_eyes", ... }
    tag_to_category = metadata.get("tag_to_category", {})
    # Load thresholds from thresholds.json using the specified mode
    thresholds_by_category, fallback_threshold = load_thresholds(threshold_json_path, mode)
    # 5) Collect predictions by category
    results_by_category = {}
    num_tags = refined_probs.shape[1]
    for i in range(num_tags):
        prob = float(refined_probs[0, i])
        tag_name = idx_to_tag[str(i)]  # str(i) because metadata uses string keys
        category = tag_to_category.get(tag_name, "general")
        # Determine the threshold to use for this category
        cat_threshold = thresholds_by_category.get(category, fallback_threshold)
        if prob >= cat_threshold:
            if category not in results_by_category:
                results_by_category[category] = []
            results_by_category[category].append((tag_name, prob))
    # 6) Depending on output_format, produce different return strings
    if output_format == "as_prompt":
        # Flatten all predicted tags across categories
        all_predicted_tags = []
        for cat, tags_list in results_by_category.items():
            # We only need the tag name in as_prompt format
            for tname, tprob in tags_list:
                # convert underscores to spaces
                tag_name_spaces = tname.replace("_", " ")
                all_predicted_tags.append(tag_name_spaces)
        # Create a comma-separated string
        prompt_string = ", ".join(all_predicted_tags)
        return prompt_string
    else:  # "verbose"
        # We'll build a multiline string describing the predictions
        lines = []
        lines.append("Predicted Tags by Category:\n")
        for cat, tags_list in results_by_category.items():
            lines.append(f"Category: {cat} | Predicted {len(tags_list)} tags")
            # Sort descending by probability
            for tname, tprob in sorted(tags_list, key=lambda x: x[1], reverse=True):
                lines.append(f"    Tag: {tname:30s} Prob: {tprob:.4f}")
            lines.append("")  # blank line after each category
        # Join lines with newlines
        verbose_output = "\n".join(lines)
        return verbose_output
if __name__ == "__main__":
    # Pass the path of the image you want to tag (left blank here as a placeholder).
    result = inference("", output_format="as_prompt")
    print(result)
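    # The same call in verbose mode groups tags per category with probabilities
    # (illustrative, commented out; "example.jpg" is a placeholder path):
    #   print(inference("example.jpg", output_format="verbose", mode="high_precision"))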