File size: 6,131 Bytes
de40de9
 
 
 
 
 
cb58d33
de40de9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb58d33
 
de40de9
cb58d33
 
 
 
 
 
de40de9
f61c335
 
633e7c4
f61c335
 
 
 
 
 
 
 
4e8cd1a
f61c335
 
 
 
 
 
 
 
 
 
e7ef62f
f61c335
 
 
 
196b516
e7ef62f
0e2b9ca
e7ef62f
f61c335
 
 
 
 
 
 
 
 
 
32eca4a
 
 
 
 
 
 
 
9bad826
32eca4a
 
bf53212
32eca4a
196b516
 
 
 
 
 
633e7c4
7a9ce47
de40de9
 
 
 
 
cb58d33
de40de9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb58d33
de40de9
 
 
076eeb1
de40de9
 
 
 
 
 
 
9bad826
076eeb1
 
 
 
de40de9
 
 
 
 
 
cb58d33
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import os
import time  # Import time for benchmarking
from functools import lru_cache

import gradio as gr
import numpy as np
import onnx
import onnxruntime
from PIL import Image, ImageDraw, ImageFont

# Constants
PROB_THRESHOLD = 0.5  # Minimum probability to show results
MODEL_PATH = os.path.join("onnx", "model.onnx")
LABELS_PATH = os.path.join("onnx", "labels.txt")

# Load class labels, one per line; a detection's class index selects the line.
# An explicit encoding keeps decoding independent of the host's locale default.
with open(LABELS_PATH, "r", encoding="utf-8") as f:
    LABELS = f.read().strip().split("\n")

class Model:
    """Wrapper around an ONNX Runtime session for single-image inference.

    On construction it records the model's single input (spatial size, name,
    dtype) and its output names. It also reads model metadata to decide whether
    the network expects BGR channel order and/or raw 0-255 pixel values.
    """

    def __init__(self, model_filepath):
        """Create an inference session for the ONNX model at *model_filepath*.

        Args:
            model_filepath: Path to the ``.onnx`` file.

        Raises:
            AssertionError: If the model does not declare exactly one input.
        """
        self.session = onnxruntime.InferenceSession(model_filepath)
        inputs = self.session.get_inputs()
        assert len(inputs) == 1
        model_input = inputs[0]
        # Input is treated as NCHW (see the transpose in predict); keep only (H, W).
        self.input_shape = model_input.shape[2:]  # (H, W)
        self.input_name = model_input.name
        # Map the ONNX tensor type to a NumPy dtype; anything unrecognized falls back to float32.
        self.input_type = {'tensor(float)': np.float32, 'tensor(float16)': np.float16}.get(
            model_input.type, np.float32
        )
        self.output_names = [o.name for o in self.session.get_outputs()]

        # Preprocessing flags driven by the model's metadata properties.
        self.is_bgr = False
        self.is_range255 = False
        onnx_model = onnx.load(model_filepath)
        for metadata in onnx_model.metadata_props:
            if metadata.key == 'Image.BitmapPixelFormat' and metadata.value == 'Bgr8':
                self.is_bgr = True
            elif metadata.key == 'Image.NominalPixelRange' and metadata.value == 'NominalRange_0_255':
                self.is_range255 = True

    def predict(self, image: Image.Image):
        """Preprocess *image*, run the model, and time the inference call.

        Args:
            image: PIL image in any mode; it is converted to RGB before use.

        Returns:
            A tuple ``(outputs, execution_time)``. ``outputs`` maps each output
            name to the array returned by ``session.run``. ``execution_time`` is
            the duration of ``session.run`` in milliseconds (also printed).
        """
        # PIL's resize takes (width, height), whereas input_shape is (H, W).
        height, width = self.input_shape
        # Force 3 channels so RGBA/grayscale/palette images yield an (H, W, 3) array.
        image_resized = image.convert("RGB").resize((width, height))
        input_array = np.array(image_resized, dtype=np.float32)[np.newaxis, :, :, :]
        input_array = input_array.transpose((0, 3, 1, 2))  # (N, C, H, W)
        if self.is_bgr:
            input_array = input_array[:, (2, 1, 0), :, :]
        if not self.is_range255:
            input_array = input_array / 255.0  # Normalize to [0,1]

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        outputs = self.session.run(self.output_names, {self.input_name: input_array.astype(self.input_type)})
        execution_time = (time.perf_counter() - start_time) * 1000  # Convert to milliseconds
        print(f"Inference time: {execution_time:.2f} ms")

        return dict(zip(self.output_names, outputs)), execution_time

# Common DejaVu location on Linux; if it cannot be loaded, PIL's default font is used.
_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
_MIN_FONT_SIZE = 10
_MAX_FONT_SIZE = 201  # the fitting search stops once the size exceeds 200


@lru_cache(maxsize=256)
def _load_font(size):
    """Return the TrueType font at *size* (cached), or None if the file cannot be loaded."""
    try:
        return ImageFont.truetype(_FONT_PATH, size=size)
    except OSError:
        return None


def _fit_font(draw, text, target_height):
    """Return the smallest font (size 10..201) whose rendered *text* is at least *target_height* px tall.

    If no size reaches the target, the size-201 font is returned. Falls back to
    PIL's default font when the TrueType file is unavailable.
    """
    font = None
    for size in range(_MIN_FONT_SIZE, _MAX_FONT_SIZE + 1):
        font = _load_font(size)
        if font is None:
            return ImageFont.load_default()
        text_bbox = draw.textbbox((0, 0), text, font=font)
        if text_bbox[3] - text_bbox[1] >= target_height:
            return font
    return font


def _first_batch(outputs, key):
    """Return ``outputs[key][0]``, or an empty list when *key* is missing or empty."""
    value = outputs.get(key)
    if value is None or len(value) == 0:
        return []
    return value[0]


def draw_boxes(image: Image.Image, outputs: dict):
    """Draw labelled bounding boxes for confident detections onto *image*.

    Args:
        image: PIL image to annotate. It is drawn on in place and returned.
        outputs: Mapping from output name to model output, as returned by
            ``Model.predict``. Reads the first batch entry of
            ``detected_boxes``, ``detected_classes`` and ``detected_scores``;
            missing outputs are treated as "no detections".

    Returns:
        The same ``image`` object with boxes and labels drawn on it.
    """
    draw = ImageDraw.Draw(image, "RGBA")  # RGBA so the label background can be semi-transparent
    image_width, image_height = image.size

    boxes = _first_batch(outputs, 'detected_boxes')
    classes = _first_batch(outputs, 'detected_classes')
    scores = _first_batch(outputs, 'detected_scores')

    for box, cls, score in zip(boxes, classes, scores):
        if score < PROB_THRESHOLD:
            continue
        label = LABELS[int(cls)]

        # Assuming box format: [ymin, xmin, ymax, xmax] normalized [0,1] — TODO confirm against the model export
        ymin, xmin, ymax, xmax = box
        left = xmin * image_width
        right = xmax * image_width
        top = ymin * image_height
        bottom = ymax * image_height

        # Draw bounding box
        draw.rectangle([left, top, right, bottom], outline="red", width=4)

        text = f"{label}: {score:.2f}"

        # Label strip spans the box width; its height is 5% of the box height.
        text_width = right - left
        text_height = (bottom - top) // 20

        # Place the strip just above the box, clamped to the top of the image.
        label_top = max(top - text_height - 10, 0)
        label_left = left

        # Draw semi-transparent rectangle behind text
        draw.rectangle(
            [label_left, label_top, label_left + text_width + 10, label_top + text_height + 10],
            fill=(255, 0, 0, 160)  # Semi-transparent red
        )

        # Grow the font until the text is as tall as the label strip (fonts are cached).
        font = _fit_font(draw, text, text_height)

        draw.text(
            (label_left + 5, label_top + 5),
            text,
            fill="black",
            font=font
        )

    return image

# Build the ONNX model wrapper once at import time; detect_objects reuses this instance.
model = Model(model_filepath=MODEL_PATH)

def detect_objects(image):
    """Run detection on *image* and return an annotated copy plus a text summary.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submits without uploading an image.

    Returns:
        ``(annotated_image, summary)``. ``summary`` lists each detection whose
        score is not below ``PROB_THRESHOLD``, followed by the inference time.
        For a ``None`` input, returns ``(None, "No image provided.")``.
    """
    if image is None:
        return None, "No image provided."

    outputs, execution_time = model.predict(image)
    annotated_image = draw_boxes(image.copy(), outputs)

    def first_batch(key):
        # A missing or empty output means "no detections" rather than an IndexError.
        value = outputs.get(key)
        return [] if value is None or len(value) == 0 else value[0]

    # Zip all three outputs (as draw_boxes does) so both agree on which detections exist.
    detections = []
    for _box, cls, score in zip(
        first_batch('detected_boxes'),
        first_batch('detected_classes'),
        first_batch('detected_scores'),
    ):
        if score < PROB_THRESHOLD:
            continue
        detections.append(f"{LABELS[int(cls)]}: {score:.2f}")

    detection_summary = "\n".join(detections) if detections else "No objects detected."
    detection_summary += f"\n\nInference Time: {execution_time:.2f} ms"

    return annotated_image, detection_summary

# Gradio UI: one image in, annotated image + detection summary out.
# The title/description emoji were mis-decoded UTF-8 (mojibake); restored to ⚾ and 🔗.
iface = gr.Interface(
    fn=detect_objects,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),
        gr.Textbox(label="Detections")
    ],
    title="JunkWaxHero ⚾ - Baseball Card Set Detection (ONNX Model)",
    description=(
        "Upload an image to identify the set of the baseball card (1980-1999).\n\n"
        "[🔗 Model Card & Repository](https://huggingface.co/enusbaum/JunkWaxHero)"
    ),
    examples=["examples/card1.jpg", "examples/card2.jpg", "examples/card3.jpg"],
    theme="default",  # You can choose other themes if desired
    allow_flagging="never"  # Disable flagging if not needed
)

if __name__ == "__main__":
    iface.launch()