File size: 5,833 Bytes
a7ab59e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca9b012
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import onnxruntime as ort
import numpy as np
import json
from PIL import Image

# 1) Load ONNX model
# Created once at import time, CPU-only; reused by every inference() call below.
session = ort.InferenceSession("camie_tagger_initial_v15.onnx", providers=["CPUExecutionProvider"])

# 2) Preprocess your image (defaults to 512x512)
def preprocess_image(img_path, size=(512, 512)):
    """Load an image and convert it to a model-ready NCHW tensor.

    Opens the image at ``img_path``, converts it to RGB, resizes it to
    ``size``, scales pixels to float32 in [0, 1], and reorders the axes
    from HWC to CHW with a leading batch dimension.

    Args:
        img_path (str): Path to the image file.
        size (tuple): (width, height) target size; defaults to (512, 512),
            the resolution the bundled model expects.

    Returns:
        np.ndarray: float32 array of shape (1, 3, size[1], size[0]).
    """
    img = Image.open(img_path).convert("RGB").resize(size)
    x = np.array(img).astype(np.float32) / 255.0
    x = np.transpose(x, (2, 0, 1))  # HWC -> CHW
    x = np.expand_dims(x, 0)        # add batch dimension -> (1, 3, H, W)
    return x

def load_thresholds(threshold_json_path, mode="balanced"):
    """Read per-category tag thresholds from a JSON file.

    The JSON must contain an "overall" section with one entry per mode
    (e.g. 'balanced', 'high_precision', 'high_recall') and may contain a
    "categories" section mapping category names to per-mode thresholds.

    Args:
        threshold_json_path (str): Path to the thresholds JSON file.
        mode (str): Which threshold mode to select for every category.

    Returns:
        tuple: (thresholds_by_category, fallback_threshold) — a dict of
        category name -> float threshold, and the overall threshold used
        for any category without an entry for `mode`.
    """
    with open(threshold_json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Overall threshold for the chosen mode; used wherever a category
    # does not define its own value.
    fallback_threshold = data["overall"][mode]["threshold"]

    thresholds_by_category = {}
    for cat_name, cat_modes in data.get("categories", {}).items():
        mode_entry = cat_modes.get(mode, {})
        thresholds_by_category[cat_name] = mode_entry.get(
            "threshold", fallback_threshold
        )

    return thresholds_by_category, fallback_threshold

def inference(
    input_path,
    output_format="verbose",
    mode="balanced",
    threshold_json_path="thresholds.json",
    metadata_path="metadata.json"
):
    """Run tagging inference on one image and format the predicted tags.

    Preprocesses the image, runs the module-level ONNX `session`, applies a
    sigmoid to the refined logits, and keeps every tag whose probability
    meets its category's threshold (loaded from `threshold_json_path` via
    `load_thresholds`).

    Args:
        input_path (str)          : Path to the image file for inference.
        output_format (str)       : "verbose" (per-category report) or
                                    "as_prompt" (comma-separated tag list
                                    with underscores turned into spaces).
        mode (str)                : "balanced", "high_precision", or
                                    "high_recall".
        threshold_json_path (str) : Path to the JSON file with category
                                    thresholds (default "thresholds.json").
        metadata_path (str)       : Path to the metadata JSON file with
                                    `idx_to_tag` / `tag_to_category` info.

    Returns:
        str: The predicted tags in either verbose or comma-separated format.
    """
    # 1) Preprocess
    input_tensor = preprocess_image(input_path)

    # 2) Run inference. The model emits two heads; only the refined head
    # drives the final predictions (initial_logits is unpacked but unused).
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: input_tensor})
    initial_logits, refined_logits = outputs  # each shape: (1, num_tags)

    # 3) Convert logits to probabilities via the sigmoid
    refined_probs = 1 / (1 + np.exp(-refined_logits))

    # 4) Load metadata & retrieve threshold info
    with open(metadata_path, "r", encoding="utf-8") as f:
        metadata = json.load(f)

    idx_to_tag = metadata["idx_to_tag"]  # keys are string indices, e.g. "0"
    tag_to_category = metadata.get("tag_to_category", {})

    # Thresholds for the requested mode, keyed by category
    thresholds_by_category, fallback_threshold = load_thresholds(threshold_json_path, mode)

    # 5) Collect predictions by category
    results_by_category = {}
    num_tags = refined_probs.shape[1]

    for i in range(num_tags):
        prob = float(refined_probs[0, i])
        tag_name = idx_to_tag[str(i)]  # str(i) because metadata uses string keys
        category = tag_to_category.get(tag_name, "general")

        # Per-category threshold, falling back to the overall one
        cat_threshold = thresholds_by_category.get(category, fallback_threshold)

        if prob >= cat_threshold:
            results_by_category.setdefault(category, []).append((tag_name, prob))

    # 6) Depending on output_format, produce different return strings
    if output_format == "as_prompt":
        # Flatten all predicted tags across categories; only the names are
        # needed, with underscores converted to spaces.
        all_predicted_tags = []
        for cat, tags_list in results_by_category.items():
            for tname, tprob in tags_list:
                all_predicted_tags.append(tname.replace("_", " "))

        # Create a comma-separated string
        return ", ".join(all_predicted_tags)

    else:  # "verbose"
        # Build a multiline report, one section per category, tags sorted
        # by descending probability.
        lines = []
        lines.append("Predicted Tags by Category:\n")
        for cat, tags_list in results_by_category.items():
            lines.append(f"Category: {cat} | Predicted {len(tags_list)} tags")
            for tname, tprob in sorted(tags_list, key=lambda x: x[1], reverse=True):
                lines.append(f"  Tag: {tname:30s}  Prob: {tprob:.4f}")
            lines.append("")  # blank line after each category
        return "\n".join(lines)

if __name__ == "__main__":
    import sys

    # Take the image path from the command line; the previous hard-coded
    # empty string always failed to open an image.
    if len(sys.argv) < 2:
        print("usage: python script.py <image_path> [verbose|as_prompt]")
        sys.exit(1)
    image_path = sys.argv[1]
    fmt = sys.argv[2] if len(sys.argv) > 2 else "as_prompt"
    result = inference(image_path, output_format=fmt)
    print(result)