import onnxruntime as ort
import numpy as np
import json
from PIL import Image

def preprocess_image(img_path, target_size=512, keep_aspect=True):
    """

    Load an image from img_path, convert to RGB,

    and resize/pad to (target_size, target_size).

    Scales pixel values to [0,1] and returns a (1,3,target_size,target_size) float32 array.

    """
    img = Image.open(img_path).convert("RGB")
    
    if keep_aspect:
        # Preserve aspect ratio, pad black
        w, h = img.size
        aspect = w / h
        if aspect > 1:
            new_w = target_size
            new_h = int(new_w / aspect)
        else:
            new_h = target_size
            new_w = int(new_h * aspect)

        # Resize with Lanczos
        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
        # Pad to a square
        background = Image.new("RGB", (target_size, target_size), (0, 0, 0))
        paste_x = (target_size - new_w) // 2
        paste_y = (target_size - new_h) // 2
        background.paste(img, (paste_x, paste_y))
        img = background
    else:
        # Simple direct resize to (target_size, target_size)
        img = img.resize((target_size, target_size), Image.Resampling.LANCZOS)
    
    # Convert to numpy array
    arr = np.array(img).astype("float32") / 255.0  # scale to [0,1]
    # Transpose from HWC -> CHW
    arr = np.transpose(arr, (2, 0, 1))
    # Add batch dimension: (1,3,512,512)
    arr = np.expand_dims(arr, axis=0)
    return arr
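
# A quick sanity check (sketch; "sample.png" is a hypothetical local file):
#   x = preprocess_image("sample.png")
#   print(x.shape, x.dtype)  # -> (1, 3, 512, 512) float32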

def load_thresholds(threshold_json_path, mode="balanced"):
    """

    Loads thresholds from the given JSON file, using a particular mode

    (e.g. 'balanced', 'high_precision', 'high_recall') for each category.



    Returns:

        thresholds_by_category (dict): e.g. { "general": 0.328..., "character": 0.304..., ... }

        fallback_threshold (float): The overall threshold if category not found

    """
    with open(threshold_json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # The fallback threshold from the "overall" section for the chosen mode
    fallback_threshold = data["overall"][mode]["threshold"]

    # Build a dict of thresholds keyed by category
    thresholds_by_category = {}
    if "categories" in data:
        for cat_name, cat_modes in data["categories"].items():
            # If the chosen mode is present for that category, use it;
            # otherwise fall back to the "overall" threshold.
            if mode in cat_modes and "threshold" in cat_modes[mode]:
                thresholds_by_category[cat_name] = cat_modes[mode]["threshold"]
            else:
                thresholds_by_category[cat_name] = fallback_threshold

    return thresholds_by_category, fallback_threshold
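
# For reference, a minimal thresholds.json shape this loader expects
# (placeholder numbers, not the model's actual thresholds):
# {
#   "overall":    { "balanced": { "threshold": 0.33 } },
#   "categories": { "general":   { "balanced": { "threshold": 0.33 } },
#                   "character": { "balanced": { "threshold": 0.30 } } }
# }
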
def onnx_inference(
    img_paths,
    onnx_path="camie_refined_no_flash.onnx",
    metadata_file="metadata.json",
    threshold_json_path="thresholds.json",
    mode="balanced",
    target_size=512,
    keep_aspect=True
):
    """

    Loads the ONNX model, runs inference on a list of image paths,

    and applies category-wise thresholds from threshold.json (per the chosen mode).



    Args:

      img_paths           : List of paths to images.

      onnx_path           : Path to the exported ONNX model file.

      metadata_file       : Path to metadata.json that contains idx_to_tag, tag_to_category, etc.

      threshold_json_path : Path to thresholds.json containing category-wise threshold info.

      mode                : "balanced", "high_precision", or "high_recall".

      target_size         : Final size of preprocessed images (512 by default).

      keep_aspect         : If True, preserve aspect ratio when resizing, pad with black.



    Returns:

      A list of dicts, one per input image, each containing:

        {

          "initial_logits": np.ndarray of shape (N_tags,),

          "refined_logits": np.ndarray of shape (N_tags,),

          "predicted_indices": list of tag indices that exceeded threshold,

          "predicted_tags": list of predicted tag strings,

          ...

        }

    """
    # 1) Initialize ONNX runtime session
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    # For GPU usage, you could do e.g.:
    # session = ort.InferenceSession(onnx_path, providers=["CUDAExecutionProvider"])
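    # To list the providers available in this onnxruntime build:
    #   print(ort.get_available_providers())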

    # 2) Pre-load metadata
    with open(metadata_file, "r", encoding="utf-8") as f:
        metadata = json.load(f)
    idx_to_tag = metadata["idx_to_tag"]          # e.g. { "0": "brown_hair", "1": "blue_eyes", ... }
    tag_to_category = metadata.get("tag_to_category", {})

    # Load thresholds from thresholds.json using the specified mode
    thresholds_by_category, fallback_threshold = load_thresholds(threshold_json_path, mode)

    # 3) Preprocess each image into a batch
    batch_tensors = []
    for img_path in img_paths:
        x = preprocess_image(img_path, target_size=target_size, keep_aspect=keep_aspect)
        batch_tensors.append(x)
    # Concatenate along the batch dimension => shape (batch_size, 3, H, W)
    batch_input = np.concatenate(batch_tensors, axis=0)

    # 4) Run inference
    input_name = session.get_inputs()[0].name  # typically "image" or "input"
    outputs = session.run(None, {input_name: batch_input})
    # Typically we get [initial_tags, refined_tags] as output
    initial_preds, refined_preds = outputs  # shapes => (batch_size, N_tags)

    # 5) Convert logits -> probabilities -> apply category-specific thresholds
    batch_results = []
    for i in range(initial_preds.shape[0]):
        init_logit = initial_preds[i, :]   # shape (N_tags,)
        ref_logit = refined_preds[i, :]    # shape (N_tags,)
        ref_prob  = 1.0 / (1.0 + np.exp(-ref_logit))  # shape (N_tags,)

        predicted_indices = []
        predicted_tags = []

        # Check each tag against the category threshold
        for idx in range(ref_logit.shape[0]):
            tag_name = idx_to_tag[str(idx)]  # Convert index->string->tag name
            category = tag_to_category.get(tag_name, "general")  # fallback to "general" if missing
            cat_threshold = thresholds_by_category.get(category, fallback_threshold)

            if ref_prob[idx] >= cat_threshold:
                predicted_indices.append(idx)
                predicted_tags.append(tag_name)

        # Build result for this image
        result_dict = {
            "initial_logits": init_logit,
            "refined_logits": ref_logit,
            "predicted_indices": predicted_indices,
            "predicted_tags": predicted_tags,
        }
        batch_results.append(result_dict)

    return batch_results
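
# Performance note: the per-tag Python loop in step 5 is easy to follow but
# costs O(N_tags) interpreted work per image. A vectorized sketch of the same
# thresholding (the category-threshold lookup can be precomputed once, outside
# the per-image loop):
#
#   tag_names = [idx_to_tag[str(i)] for i in range(len(idx_to_tag))]
#   cat_thresholds = np.array([
#       thresholds_by_category.get(tag_to_category.get(t, "general"),
#                                  fallback_threshold)
#       for t in tag_names
#   ])
#   predicted_indices = np.nonzero(ref_prob >= cat_thresholds)[0].tolist()
#   predicted_tags = [tag_names[i] for i in predicted_indices]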

if __name__ == "__main__":
    # Example usage
    images = ["images.png"]
    results = onnx_inference(
        img_paths=images,
        onnx_path="camie_refined_no_flash_v15.onnx",
        metadata_file="metadata.json",
        threshold_json_path="thresholds.json",
        mode="balanced",     # or "balanced", "high_precision"
        target_size=512,
        keep_aspect=True
    )

    for i, res in enumerate(results):
        print(f"Image: {images[i]}")
        print(f"  # of predicted tags above threshold: {len(res['predicted_indices'])}")
        # Show the first 10 predicted tags (if available)
        sample_tags = res['predicted_tags'][:10]
        print("  Sample predicted tags:", sample_tags)
        print()
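
        # Possible extension (sketch): rank predicted tags by refined probability.
        #   probs = 1.0 / (1.0 + np.exp(-res["refined_logits"]))
        #   ranked = sorted(zip(res["predicted_tags"], probs[res["predicted_indices"]]),
        #                   key=lambda t: -t[1])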