Eric P. Nusbaum commited on
Commit
32eca4a
·
1 Parent(s): 196b516

Update Space

Browse files
Files changed (1) hide show
  1. app.py +13 -96
app.py CHANGED
@@ -1,53 +1,3 @@
1
- import os
2
- import numpy as np
3
- import onnx
4
- import onnxruntime
5
- from PIL import Image, ImageDraw, ImageFont
6
- import gradio as gr
7
-
8
- # Constants
9
- PROB_THRESHOLD = 0.5 # Minimum probability to show results
10
- MODEL_PATH = os.path.join("onnx", "model.onnx")
11
- LABELS_PATH = os.path.join("onnx", "labels.txt")
12
-
13
- # Load labels
14
- with open(LABELS_PATH, "r") as f:
15
- LABELS = f.read().strip().split("\n")
16
-
17
- class Model:
18
- def __init__(self, model_filepath):
19
- self.session = onnxruntime.InferenceSession(model_filepath)
20
- assert len(self.session.get_inputs()) == 1
21
- self.input_shape = self.session.get_inputs()[0].shape[2:] # (H, W)
22
- self.input_name = self.session.get_inputs()[0].name
23
- self.input_type = {'tensor(float)': np.float32, 'tensor(float16)': np.float16}.get(
24
- self.session.get_inputs()[0].type, np.float32
25
- )
26
- self.output_names = [o.name for o in self.session.get_outputs()]
27
-
28
- self.is_bgr = False
29
- self.is_range255 = False
30
- onnx_model = onnx.load(model_filepath)
31
- for metadata in onnx_model.metadata_props:
32
- if metadata.key == 'Image.BitmapPixelFormat' and metadata.value == 'Bgr8':
33
- self.is_bgr = True
34
- elif metadata.key == 'Image.NominalPixelRange' and metadata.value == 'NominalRange_0_255':
35
- self.is_range255 = True
36
-
37
- def predict(self, image: Image.Image):
38
- # Preprocess image
39
- image_resized = image.resize(self.input_shape)
40
- input_array = np.array(image_resized, dtype=np.float32)[np.newaxis, :, :, :]
41
- input_array = input_array.transpose((0, 3, 1, 2)) # (N, C, H, W)
42
- if self.is_bgr:
43
- input_array = input_array[:, (2, 1, 0), :, :]
44
- if not self.is_range255:
45
- input_array = input_array / 255.0 # Normalize to [0,1]
46
-
47
- # Run inference
48
- outputs = self.session.run(self.output_names, {self.input_name: input_array.astype(self.input_type)})
49
- return {name: outputs[i] for i, name in enumerate(self.output_names)}
50
-
51
  def draw_boxes(image: Image.Image, outputs: dict):
52
  draw = ImageDraw.Draw(image, "RGBA") # Use RGBA for transparency
53
 
@@ -77,7 +27,7 @@ def draw_boxes(image: Image.Image, outputs: dict):
77
 
78
  # Set label box dimensions
79
  text_width = right - left
80
- text_height = (bottom - top) // 20 # 5% of the bounding box height
81
 
82
  # Calculate label background position
83
  label_top = max(top - text_height - 10, 0)
@@ -89,10 +39,19 @@ def draw_boxes(image: Image.Image, outputs: dict):
89
  fill=(255, 0, 0, 160) # Semi-transparent red
90
  )
91
 
92
- # Use default font for text
93
- font = ImageFont.load_default()
 
 
 
 
 
 
 
 
 
94
 
95
- # Draw text
96
  draw.text(
97
  (label_left + 5, label_top + 5),
98
  text,
@@ -101,45 +60,3 @@ def draw_boxes(image: Image.Image, outputs: dict):
101
  )
102
 
103
  return image
104
-
105
- # Initialize model
106
- model = Model(MODEL_PATH)
107
-
108
- def detect_objects(image):
109
- outputs = model.predict(image)
110
- annotated_image = draw_boxes(image.copy(), outputs)
111
-
112
- # Prepare detection summary
113
- detections = []
114
- boxes = outputs.get('detected_boxes', [])
115
- classes = outputs.get('detected_classes', [])
116
- scores = outputs.get('detected_scores', [])
117
-
118
- for box, cls, score in zip(boxes[0], classes[0], scores[0]):
119
- if score < PROB_THRESHOLD:
120
- continue
121
- label = LABELS[int(cls)]
122
- detections.append(f"{label}: {score:.2f}")
123
-
124
- detection_summary = "\n".join(detections) if detections else "No objects detected."
125
-
126
- return annotated_image, detection_summary
127
-
128
- # Gradio Interface
129
- iface = gr.Interface(
130
- fn=detect_objects,
131
- inputs=gr.Image(type="pil"),
132
- outputs=[
133
- gr.Image(type="pil", label="Detected Objects"),
134
- gr.Textbox(label="Detections")
135
- ],
136
- title="JunkWaxHero - Baseball Card Set Detection (ONNX Model)",
137
- description="Upload an image to itentify the set of the baseball card.",
138
- examples=["examples/card1.jpg", "examples/card2.jpg", "examples/card3.jpg"],
139
- theme="default", # You can choose other themes if desired
140
- allow_flagging="never" # Disable flagging if not needed
141
- # Removed 'layout' parameter
142
- )
143
-
144
- if __name__ == "__main__":
145
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def draw_boxes(image: Image.Image, outputs: dict):
2
  draw = ImageDraw.Draw(image, "RGBA") # Use RGBA for transparency
3
 
 
27
 
28
  # Set label box dimensions
29
  text_width = right - left
30
+ text_height = (bottom - top) // 10 # 10% of the bounding box height
31
 
32
  # Calculate label background position
33
  label_top = max(top - text_height - 10, 0)
 
39
  fill=(255, 0, 0, 160) # Semi-transparent red
40
  )
41
 
42
+ # Dynamically scale font size
43
+ font_size = 10 # Start with a small font size
44
+ font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf" # Common path on Linux
45
+
46
+ while True:
47
+ font = ImageFont.truetype(font_path, size=font_size)
48
+ text_bbox = draw.textbbox((0, 0), text, font=font)
49
+ text_pixel_height = text_bbox[3] - text_bbox[1]
50
+ if text_pixel_height >= text_height or font_size > 100: # Cap font size to prevent infinite loops
51
+ break
52
+ font_size += 1
53
 
54
+ # Draw text with the scaled font
55
  draw.text(
56
  (label_left + 5, label_top + 5),
57
  text,
 
60
  )
61
 
62
  return image