Eric P. Nusbaum commited on
Commit
ae5135e
·
1 Parent(s): 029bb24

Update to use ONNX

Browse files
Files changed (1) hide show
  1. app.py +88 -214
app.py CHANGED
@@ -1,241 +1,115 @@
1
- import gradio as gr
2
- import tensorflow as tf
3
  import numpy as np
 
 
4
  from PIL import Image, ImageDraw, ImageFont
5
- import os
6
-
7
- # Suppress TensorFlow logging for cleaner logs
8
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
9
 
10
- # Disable GPU usage explicitly to prevent TensorFlow from attempting to access GPU libraries
11
- os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
 
 
12
 
13
  # Load labels
14
- labels_path = os.path.join('tensorflow', 'labels.txt')
15
- if not os.path.exists(labels_path):
16
- raise FileNotFoundError(f"Labels file not found at {labels_path}")
17
- with open(labels_path, 'r') as f:
18
- labels = f.read().splitlines()
19
-
20
- # Function to load the frozen TensorFlow graph
21
- def load_frozen_graph(pb_file_path):
22
- if not os.path.exists(pb_file_path):
23
- raise FileNotFoundError(f"Model file not found at {pb_file_path}")
24
- with tf.io.gfile.GFile(pb_file_path, 'rb') as f:
25
- graph_def = tf.compat.v1.GraphDef()
26
- graph_def.ParseFromString(f.read())
27
-
28
- with tf.Graph().as_default() as graph:
29
- tf.import_graph_def(graph_def, name='')
30
- return graph
31
-
32
- # Load the TensorFlow model
33
- MODEL_DIR = 'tensorflow'
34
- MODEL_PATH = os.path.join(MODEL_DIR, 'model.pb')
35
- graph = load_frozen_graph(MODEL_PATH)
36
- sess = tf.compat.v1.Session(graph=graph)
37
-
38
- # Define tensor names based on your model's outputs
39
- try:
40
- input_tensor = graph.get_tensor_by_name('image_tensor:0')
41
- detected_boxes = graph.get_tensor_by_name('detected_boxes:0')
42
- detected_classes = graph.get_tensor_by_name('detected_classes:0')
43
- detected_scores = graph.get_tensor_by_name('detected_scores:0')
44
- except KeyError as e:
45
- raise KeyError(f"Tensor not found in the graph: {e}")
46
-
47
- # Define the target size based on your model's expected input
48
- TARGET_WIDTH = 320
49
- TARGET_HEIGHT = 320
50
-
51
- def preprocess_image(image):
52
- """
53
- Preprocess the input image:
54
- - Resize to target dimensions
55
- - Convert to numpy array
56
- - Normalize pixel values
57
- - Convert RGB to BGR if required by the model
58
- """
59
- image = image.resize((TARGET_WIDTH, TARGET_HEIGHT))
60
- image_np = np.array(image).astype(np.float32)
61
- image_np = image_np / 255.0 # Normalize to [0,1]
62
- if image_np.shape[-1] == 3:
63
- # Convert RGB to BGR if required by your model
64
- image_np = image_np[..., (2, 1, 0)]
65
- image_np = np.expand_dims(image_np, axis=0) # Add batch dimension
66
- return image_np
67
-
68
- def draw_boxes(image, boxes, classes, scores, threshold=0.5):
69
- """
70
- Draw bounding boxes and labels on the image.
71
- Args:
72
- image (PIL.Image): The original image.
73
- boxes (np.array): Array of bounding boxes.
74
- classes (np.array): Array of class IDs.
75
- scores (np.array): Array of confidence scores.
76
- threshold (float): Confidence threshold to filter detections.
77
- Returns:
78
- PIL.Image: Annotated image.
79
- """
80
  draw = ImageDraw.Draw(image)
81
  try:
82
- font = ImageFont.truetype("arial.ttf", 15)
83
  except IOError:
84
  font = ImageFont.load_default()
85
 
86
- # If there are no detections at all
87
- if boxes.shape[0] == 0 or boxes.shape[1] == 0:
88
- # Return the original image without annotation
89
- return image
90
 
91
- # Otherwise, proceed to draw bounding boxes
92
  for box, cls, score in zip(boxes[0], classes[0], scores[0]):
93
- if score < threshold:
94
  continue
95
-
96
- # Convert box coordinates from normalized to absolute
97
  ymin, xmin, ymax, xmax = box
98
- left = xmin * image.width
99
- right = xmax * image.width
100
- top = ymin * image.height
101
- bottom = ymax * image.height
102
-
103
- # Draw rectangle
104
- draw.rectangle([(left, top), (right, bottom)], outline="red", width=2)
105
-
106
- # Prepare label
107
- cls_index = int(cls) - 1 # If your classes are 1-indexed
108
- if cls_index < 0 or cls_index >= len(labels):
109
- label_str = f"cls_{int(cls)}: {score:.2f}"
110
- else:
111
- label_str = f"{labels[cls_index]}: {score:.2f}"
112
-
113
- # Calculate text size using textbbox
114
- text_bbox = draw.textbbox((0, 0), label_str, font=font)
115
- text_width = text_bbox[2] - text_bbox[0]
116
- text_height = text_bbox[3] - text_bbox[1]
117
-
118
- # Draw label background
119
- draw.rectangle([(left, top - text_height - 4),
120
- (left + text_width + 4, top)], fill="red")
121
-
122
- # Draw text
123
- draw.text((left + 2, top - text_height - 2),
124
- label_str, fill="white", font=font)
125
-
126
  return image
127
 
128
- def predict(image):
129
- """
130
- Perform inference on the input image and return the annotated image.
131
- Args:
132
- image (PIL.Image): Uploaded image.
133
- Returns:
134
- PIL.Image: Annotated image with bounding boxes and labels.
135
- """
136
- try:
137
- # Preprocess the image
138
- input_array = preprocess_image(image)
139
 
140
- # Debug prints
141
- print(f"[DEBUG] Input shape to model: {input_array.shape}")
 
142
 
143
- # Run inference
144
- boxes, classes, scores = sess.run(
145
- [detected_boxes, detected_classes, detected_scores],
146
- feed_dict={input_tensor: input_array}
147
- )
148
 
149
- # Debug prints
150
- print(f"[DEBUG] boxes shape: {boxes.shape}, classes shape: {classes.shape}, scores shape: {scores.shape}")
151
-
152
- # Check if the output arrays have zero detections
153
- # e.g. boxes could have shape (1, 0, 4) if no detections are found
154
- if boxes.size == 0 or classes.size == 0 or scores.size == 0:
155
- print("[DEBUG] No detections returned by the model.")
156
- return _draw_no_detection_message(image)
157
-
158
- # Annotate the image with bounding boxes and labels
159
- annotated_image = draw_boxes(image.copy(), boxes, classes, scores, threshold=0.5)
160
- print("[DEBUG] Annotation completed.")
161
- return annotated_image
162
-
163
- except Exception as e:
164
- # Log the exception for debugging
165
- print(f"Exception during prediction: {e}")
166
-
167
- # Return an error image with the error message
168
- return _draw_error_message()
169
-
170
- def _draw_no_detection_message(image):
171
- """Draws a simple 'No detections found' message on the image."""
172
- draw = ImageDraw.Draw(image)
173
- try:
174
- font = ImageFont.truetype("arial.ttf", 20)
175
- except IOError:
176
- font = ImageFont.load_default()
177
- message = "No detections found."
178
- text_bbox = draw.textbbox((0, 0), message, font=font)
179
- text_width = text_bbox[2] - text_bbox[0]
180
- text_height = text_bbox[3] - text_bbox[1]
181
-
182
- # Center the message
183
- x = (image.width - text_width) / 2
184
- y = (image.height - text_height) / 2
185
- draw.rectangle(
186
- [(x - 10, y - 10), (x + text_width + 10, y + text_height + 10)],
187
- fill="black"
188
- )
189
- draw.text((x, y), message, fill="white", font=font)
190
- return image
191
 
192
- def _draw_error_message():
193
- """Creates a red image with a centered error message."""
194
- error_image = Image.new('RGB', (500, 500), color=(255, 0, 0))
195
- draw = ImageDraw.Draw(error_image)
196
- try:
197
- font = ImageFont.truetype("arial.ttf", 20)
198
- except IOError:
199
- font = ImageFont.load_default()
200
- error_text = "Error during prediction."
201
- text_bbox = draw.textbbox((0, 0), error_text, font=font)
202
- text_width = text_bbox[2] - text_bbox[0]
203
- text_height = text_bbox[3] - text_bbox[1]
204
-
205
- draw.rectangle(
206
- [
207
- ((500 - text_width) / 2 - 10, (500 - text_height) / 2 - 10),
208
- ((500 + text_width) / 2 + 10, (500 + text_height) / 2 + 10)
209
- ],
210
- fill="black"
211
- )
212
- draw.text(
213
- ((500 - text_width) / 2, (500 - text_height) / 2),
214
- error_text,
215
- fill="white",
216
- font=font
217
- )
218
- return error_image
219
-
220
- # Define Gradio interface using the new API
221
- title = "JunkWaxHero 🦸‍♂️ - Baseball Card Set Identifier"
222
- description = "Upload an image of a baseball card, and JunkWaxHero will identify the set it belongs to with high accuracy."
223
-
224
- # Verify that example images exist to prevent FileNotFoundError
225
- example_images = ["examples/card1.jpg", "examples/card2.jpg", "examples/card3.jpg"]
226
- valid_examples = [img for img in example_images if os.path.exists(img)]
227
- if not valid_examples:
228
- valid_examples = None # Remove examples if none exist
229
 
 
230
  iface = gr.Interface(
231
- fn=predict,
232
  inputs=gr.Image(type="pil"),
233
- outputs=gr.Image(type="pil"),
234
- title=title,
235
- description=description,
236
- examples=valid_examples,
237
- flagging_mode="never" # Use new Gradio parameter
238
  )
239
 
240
  if __name__ == "__main__":
241
  iface.launch()
 
 
1
+ import os
 
2
  import numpy as np
3
+ import onnx
4
+ import onnxruntime
5
  from PIL import Image, ImageDraw, ImageFont
6
+ import gradio as gr
 
 
 
7
 
8
+ # Constants
9
+ PROB_THRESHOLD = 0.5 # Minimum probability to show results
10
+ MODEL_PATH = os.path.join("onnx", "model.onnx")
11
+ LABELS_PATH = os.path.join("onnx", "labels.txt")
12
 
13
  # Load labels
14
+ with open(LABELS_PATH, "r") as f:
15
+ LABELS = f.read().strip().split("\n")
16
+
17
+ class Model:
18
+ def __init__(self, model_filepath):
19
+ self.session = onnxruntime.InferenceSession(model_filepath)
20
+ assert len(self.session.get_inputs()) == 1
21
+ self.input_shape = self.session.get_inputs()[0].shape[2:] # (H, W)
22
+ self.input_name = self.session.get_inputs()[0].name
23
+ self.input_type = {'tensor(float)': np.float32, 'tensor(float16)': np.float16}.get(
24
+ self.session.get_inputs()[0].type, np.float32
25
+ )
26
+ self.output_names = [o.name for o in self.session.get_outputs()]
27
+
28
+ self.is_bgr = False
29
+ self.is_range255 = False
30
+ onnx_model = onnx.load(model_filepath)
31
+ for metadata in onnx_model.metadata_props:
32
+ if metadata.key == 'Image.BitmapPixelFormat' and metadata.value == 'Bgr8':
33
+ self.is_bgr = True
34
+ elif metadata.key == 'Image.NominalPixelRange' and metadata.value == 'NominalRange_0_255':
35
+ self.is_range255 = True
36
+
37
+ def predict(self, image: Image.Image):
38
+ # Preprocess image
39
+ image_resized = image.resize(self.input_shape)
40
+ input_array = np.array(image_resized, dtype=np.float32)[np.newaxis, :, :, :]
41
+ input_array = input_array.transpose((0, 3, 1, 2)) # (N, C, H, W)
42
+ if self.is_bgr:
43
+ input_array = input_array[:, (2, 1, 0), :, :]
44
+ if not self.is_range255:
45
+ input_array = input_array / 255.0 # Normalize to [0,1]
46
+
47
+ # Run inference
48
+ outputs = self.session.run(self.output_names, {self.input_name: input_array.astype(self.input_type)})
49
+ return {name: outputs[i] for i, name in enumerate(self.output_names)}
50
+
51
+ def draw_boxes(image: Image.Image, outputs: dict):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  draw = ImageDraw.Draw(image)
53
  try:
54
+ font = ImageFont.truetype("arial.ttf", size=16)
55
  except IOError:
56
  font = ImageFont.load_default()
57
 
58
+ boxes = outputs.get('detected_boxes', [])
59
+ classes = outputs.get('detected_classes', [])
60
+ scores = outputs.get('detected_scores', [])
 
61
 
 
62
  for box, cls, score in zip(boxes[0], classes[0], scores[0]):
63
+ if score < PROB_THRESHOLD:
64
  continue
65
+ label = LABELS[int(cls)]
66
+ # Assuming box format: [ymin, xmin, ymax, xmax] normalized [0,1]
67
  ymin, xmin, ymax, xmax = box
68
+ width, height = image.size
69
+ left = xmin * width
70
+ right = xmax * width
71
+ top = ymin * height
72
+ bottom = ymax * height
73
+ draw.rectangle([left, top, right, bottom], outline="red", width=2)
74
+ text = f"{label}: {score:.2f}"
75
+ text_size = draw.textsize(text, font=font)
76
+ draw.rectangle([left, top - text_size[1], left + text_size[0], top], fill="red")
77
+ draw.text((left, top - text_size[1]), text, fill="white", font=font)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  return image
79
 
80
+ # Initialize model
81
+ model = Model(MODEL_PATH)
 
 
 
 
 
 
 
 
 
82
 
83
+ def detect_objects(image):
84
+ outputs = model.predict(image)
85
+ annotated_image = draw_boxes(image.copy(), outputs)
86
 
87
+ # Prepare detection summary
88
+ detections = []
89
+ boxes = outputs.get('detected_boxes', [])
90
+ classes = outputs.get('detected_classes', [])
91
+ scores = outputs.get('detected_scores', [])
92
 
93
+ for box, cls, score in zip(boxes[0], classes[0], scores[0]):
94
+ if score < PROB_THRESHOLD:
95
+ continue
96
+ label = LABELS[int(cls)]
97
+ detections.append(f"{label}: {score:.2f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
+ detection_summary = "\n".join(detections) if detections else "No objects detected."
100
+
101
+ return annotated_image, detection_summary
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
+ # Gradio Interface
104
  iface = gr.Interface(
105
+ fn=detect_objects,
106
  inputs=gr.Image(type="pil"),
107
+ outputs=[gr.Image(type="pil", label="Detected Objects"), gr.Textbox(label="Detections")],
108
+ title="Object Detection with ONNX Model",
109
+ description="Upload an image to detect objects using the ONNX model.",
110
+ examples=["examples/card1.jpg", "examples/card2.jpg", "examples/card3.jpg"]
 
111
  )
112
 
113
  if __name__ == "__main__":
114
  iface.launch()
115
+