File size: 7,511 Bytes
bc492d2 a7ab59e bc492d2 a7ab59e bc492d2 a7ab59e bc492d2 a7ab59e bc492d2 a7ab59e bc492d2 a7ab59e bc492d2 a7ab59e bc492d2 a7ab59e bc492d2 a7ab59e bc492d2 a7ab59e bc492d2 a7ab59e bc492d2 a7ab59e bc492d2 a7ab59e bc492d2 a7ab59e bc492d2 a7ab59e bc492d2 a7ab59e bc492d2 a7ab59e bc492d2 a7ab59e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
import onnxruntime as ort
import numpy as np
import json
from PIL import Image
def preprocess_image(img_path, target_size=512, keep_aspect=True):
    """
    Load the image at *img_path* as RGB and produce a model-ready tensor.

    The image is resized to target_size x target_size. With keep_aspect=True
    the aspect ratio is preserved and the result is centered on a black
    square canvas; otherwise it is stretched directly.

    Returns:
        np.ndarray of shape (1, 3, target_size, target_size), float32,
        pixel values scaled into [0, 1].
    """
    image = Image.open(img_path).convert("RGB")

    if not keep_aspect:
        # Direct stretch to the target square.
        image = image.resize((target_size, target_size), Image.Resampling.LANCZOS)
    else:
        width, height = image.size
        ratio = width / height
        # The longer side becomes target_size; the other is scaled to match.
        if ratio > 1:
            resized_w, resized_h = target_size, int(target_size / ratio)
        else:
            resized_w, resized_h = int(target_size * ratio), target_size
        image = image.resize((resized_w, resized_h), Image.Resampling.LANCZOS)

        # Center the resized image on a black square canvas.
        canvas = Image.new("RGB", (target_size, target_size), (0, 0, 0))
        offset = ((target_size - resized_w) // 2, (target_size - resized_h) // 2)
        canvas.paste(image, offset)
        image = canvas

    # HWC uint8 -> CHW float32 in [0,1], then prepend the batch axis.
    tensor = np.asarray(image, dtype="float32") / 255.0
    tensor = tensor.transpose(2, 0, 1)[np.newaxis, ...]
    return tensor
# Threshold configuration loading
def load_thresholds(threshold_json_path, mode="balanced"):
    """
    Read per-category tag thresholds from a JSON file for a given mode
    (e.g. 'balanced', 'high_precision', 'high_recall').

    Returns:
        tuple:
            dict mapping category name -> threshold float
            (e.g. {"general": 0.328, "character": 0.304, ...});
            float fallback threshold taken from the "overall" section,
            used for any category lacking a mode-specific value.
    """
    with open(threshold_json_path, "r", encoding="utf-8") as fh:
        config = json.load(fh)

    # The "overall" entry always supplies the default for the chosen mode.
    default_threshold = config["overall"][mode]["threshold"]

    per_category = {}
    for name, mode_table in config.get("categories", {}).items():
        # Prefer a category-specific threshold for this mode when present;
        # otherwise fall back to the overall default.
        entry = mode_table.get(mode, {})
        per_category[name] = entry.get("threshold", default_threshold)

    return per_category, default_threshold
def onnx_inference(
    img_paths,
    onnx_path="camie_refined_no_flash.onnx",
    metadata_file="metadata.json",
    threshold_json_path="thresholds.json",
    mode="balanced",
    target_size=512,
    keep_aspect=True
):
    """
    Loads the ONNX model, runs inference on a list of image paths,
    and applies category-wise thresholds from thresholds.json (per the chosen mode).

    Args:
        img_paths           : List of paths to images.
        onnx_path           : Path to the exported ONNX model file.
        metadata_file       : Path to metadata.json containing idx_to_tag,
                              tag_to_category, etc.
        threshold_json_path : Path to thresholds.json with per-category thresholds.
        mode                : "balanced", "high_precision", or "high_recall".
        target_size         : Final size of preprocessed images (512 by default).
        keep_aspect         : If True, preserve aspect ratio when resizing, pad black.

    Returns:
        A list of dicts, one per input image:
            {
              "initial_logits":    np.ndarray of shape (N_tags,),
              "refined_logits":    np.ndarray of shape (N_tags,),
              "predicted_indices": list of tag indices above threshold,
              "predicted_tags":    list of predicted tag strings,
            }
    """
    # 1) Initialize ONNX runtime session.
    # For GPU usage, pass providers=["CUDAExecutionProvider"] instead.
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

    # 2) Load metadata: index -> tag name, tag name -> category.
    with open(metadata_file, "r", encoding="utf-8") as f:
        metadata = json.load(f)
    # NOTE: idx_to_tag is keyed by stringified indices, e.g. {"0": "brown_hair", ...}
    idx_to_tag = metadata["idx_to_tag"]
    tag_to_category = metadata.get("tag_to_category", {})

    # Load thresholds for the chosen mode.
    thresholds_by_category, fallback_threshold = load_thresholds(threshold_json_path, mode)

    # 3) Preprocess each image and stack into one batch (batch_size, 3, H, W).
    batch_input = np.concatenate(
        [preprocess_image(p, target_size=target_size, keep_aspect=keep_aspect)
         for p in img_paths],
        axis=0,
    )

    # 4) Run inference. The model emits [initial_tags, refined_tags],
    # each of shape (batch_size, N_tags).
    input_name = session.get_inputs()[0].name  # typically "image" or "input"
    initial_preds, refined_preds = session.run(None, {input_name: batch_input})

    # 5) Build the per-tag threshold vector ONCE (it is identical for every
    # image), instead of re-resolving dict lookups per image per tag.
    n_tags = refined_preds.shape[1]
    tag_names = [idx_to_tag[str(i)] for i in range(n_tags)]
    tag_thresholds = np.array(
        [thresholds_by_category.get(tag_to_category.get(name, "general"),
                                    fallback_threshold)
         for name in tag_names],
        dtype=refined_preds.dtype,
    )

    # Sigmoid over the whole batch, then a vectorized threshold comparison.
    refined_probs = 1.0 / (1.0 + np.exp(-refined_preds))  # (batch_size, N_tags)
    above = refined_probs >= tag_thresholds                # broadcast over rows

    batch_results = []
    for i in range(refined_preds.shape[0]):
        predicted_indices = np.nonzero(above[i])[0].tolist()
        batch_results.append({
            "initial_logits": initial_preds[i, :],
            "refined_logits": refined_preds[i, :],
            "predicted_indices": predicted_indices,
            "predicted_tags": [tag_names[j] for j in predicted_indices],
        })
    return batch_results
if __name__ == "__main__":
    # Example usage: tag a single image with the balanced thresholds.
    images = ["images.png"]
    results = onnx_inference(
        img_paths=images,
        onnx_path="camie_refined_no_flash_v15.onnx",
        metadata_file="metadata.json",
        threshold_json_path="thresholds.json",
        mode="balanced",  # or "high_recall", "high_precision"
        target_size=512,
        keep_aspect=True
    )
    for img_path, res in zip(images, results):
        # Report how many tags cleared their category threshold,
        # then print the full predicted tag list.
        sample_tags = res['predicted_tags']
        print(f"Image: {img_path}")
        print(f"  # of predicted tags above threshold: {len(res['predicted_indices'])}")
        print(" Sample predicted tags:", sample_tags)
        print()