Aumkeshchy2003 committed
Commit 2420aaa · verified · 1 Parent(s): a83113c

Update app.py

Files changed (1):
  app.py (+247, -56)

app.py CHANGED
@@ -1,83 +1,274 @@
-import cv2
 import torch
 import numpy as np
 import gradio as gr
 import time
 import os
 from pathlib import Path
-import onnxruntime as ort
-
-# Set device for ONNX Runtime
-providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if torch.cuda.is_available() else ['CPUExecutionProvider']
-session = ort.InferenceSession("models/yolov5n.onnx", providers=providers)
-
-# Load model class names
-class_names = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light"]  # Modify based on model
-
-# Generate random colors for classes
 np.random.seed(42)
-colors = np.random.uniform(0, 255, size=(len(class_names), 3))
-
-def preprocess(image):
-    image = cv2.resize(image, (640, 640))
-    image = image.transpose((2, 0, 1)) / 255.0  # Normalize
-    image = np.expand_dims(image, axis=0).astype(np.float32)
-    return image
-
 def detect_objects(image):
     start_time = time.time()
-    image_input = preprocess(image)
-    outputs = session.run(None, {session.get_inputs()[0].name: image_input})
-    detections = outputs[0][0]
     output_image = image.copy()
 
-    for det in detections:
-        x1, y1, x2, y2, conf, cls = map(int, det[:6])
-        if conf > 0.6:  # Confidence threshold
-            color = colors[cls].tolist()
-            cv2.rectangle(output_image, (x1, y1), (x2, y2), color, 2)
-            label = f"{class_names[cls]} {conf:.2f}"
-            cv2.putText(output_image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
-
-    fps = 1 / (time.time() - start_time)
-    cv2.putText(output_image, f"FPS: {fps:.2f}", (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
     return output_image
 
-def real_time_detection():
     cap = cv2.VideoCapture(0)
-    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
-    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
-    cap.set(cv2.CAP_PROP_FPS, 60)
-
-    while cap.isOpened():
-        start_time = time.time()
-        ret, frame = cap.read()
-        if not ret:
-            break
-        output_frame = detect_objects(frame)
-        cv2.imshow("Real-Time Object Detection", output_frame)
-        if cv2.waitKey(1) & 0xFF == ord('q'):
-            break
-        print(f"FPS: {1 / (time.time() - start_time):.2f}")
-    cap.release()
-    cv2.destroyAllWindows()
-
-with gr.Blocks(title="YOLOv5 Real-Time Object Detection") as demo:
     gr.Markdown("""
-    # Real-Time Object Detection with YOLOv5
-    **Upload an image or run real-time detection**
     """)
 
-    with gr.Row():
-        with gr.Column():
-            input_image = gr.Image(label="Upload Image", type="numpy")
-            detect_button = gr.Button("Detect Objects")
-            start_rt_button = gr.Button("Start Real-Time Detection")
-
-        with gr.Column():
-            output_image = gr.Image(label="Detection Results", type="numpy")
-
-    detect_button.click(detect_objects, inputs=input_image, outputs=output_image)
-    start_rt_button.click(lambda: real_time_detection(), None, None)
-
-    demo.launch()

 import torch
 import numpy as np
 import gradio as gr
+import cv2
 import time
 import os
+import threading
+import atexit  # used below to register worker-thread cleanup at interpreter exit
+from queue import Queue
 from pathlib import Path

+# Create a cache directory for models
+os.makedirs("models", exist_ok=True)
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+
+model_path = Path("models/yolov5x.pt")
+if model_path.exists():
+    print(f"Loading model from cache: {model_path}")
+    # Load the cached checkpoint via YOLOv5's "custom" hub entry point;
+    # the named "yolov5x" entry does not accept source=/path= arguments.
+    model = torch.hub.load("ultralytics/yolov5", "custom", path=str(model_path)).to(device)
+else:
+    print("Downloading YOLOv5x model (torch.hub also caches it)...")
+    model = torch.hub.load("ultralytics/yolov5", "yolov5x", pretrained=True).to(device)
+    # Note: saving model.state_dict() here would not round-trip through the
+    # "custom" loader above, which expects a full YOLOv5 checkpoint; to use
+    # the cached branch, place a complete yolov5x.pt checkpoint in models/.
+
+# Model configuration for a better speed/accuracy trade-off
+model.conf = 0.5      # Confidence threshold, slightly lower for real-time
+model.iou = 0.45      # NMS IoU threshold, slightly lower for real-time
+model.classes = None  # Detect all classes
+model.max_det = 20    # Limit detections per image for speed
+
+if device.type == "cuda":
+    model.half()  # Half precision on CUDA
+else:
+    torch.set_num_threads(os.cpu_count())
+
+model.eval()
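+
+# Optional warm-up: a minimal sketch, assuming the hub model's AutoShape
+# wrapper accepts an HWC uint8 numpy array (as torch.hub YOLOv5 models do).
+# One dummy inference keeps the first real frame from paying one-time
+# CUDA/cuDNN initialization cost.
+_ = model(np.zeros((480, 640, 3), dtype=np.uint8), size=640)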
+
+# Precompute colors for bounding boxes
 np.random.seed(42)
+colors = np.random.uniform(0, 255, size=(len(model.names), 3))
+
+# Performance tracking
+total_inference_time = 0
+inference_count = 0
+fps_queue = Queue(maxsize=30)  # Store the last 30 FPS values for smoothing
+
+# Threading variables
+processing_lock = threading.Lock()
+stop_event = threading.Event()
+frame_queue = Queue(maxsize=2)  # Small queue to avoid lag
+result_queue = Queue(maxsize=2)
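+# Design note: webcam_feed() below is the producer (raw frames into
+# frame_queue) and process_frame_thread() is the consumer (annotated frames
+# into result_queue); with maxsize=2, stale frames are dropped under load
+# instead of accumulating latency.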
 
 def detect_objects(image):
+    """Process a single image for object detection"""
+    global total_inference_time, inference_count
+
+    if image is None:
+        return None
+
     start_time = time.time()
     output_image = image.copy()
+    input_size = 640
+
+    # Run inference without gradient tracking
+    with torch.no_grad():
+        results = model(image, size=input_size)
+
+    inference_time = time.time() - start_time
+    total_inference_time += inference_time
+    inference_count += 1
+    avg_inference_time = total_inference_time / inference_count
+
+    detections = results.pred[0].cpu().numpy()
+
+    # Draw detections
+    for *xyxy, conf, cls in detections:
+        x1, y1, x2, y2 = map(int, xyxy)
+        class_id = int(cls)
+        color = colors[class_id].tolist()
+
+        # Bounding box
+        cv2.rectangle(output_image, (x1, y1), (x2, y2), color, 3, lineType=cv2.LINE_AA)
+
+        # Label with class name and confidence
+        label = f"{model.names[class_id]} {conf:.2f}"
+        font_scale, font_thickness = 0.9, 2
+        (w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness)
+
+        cv2.rectangle(output_image, (x1, y1 - h - 10), (x1 + w + 10, y1), color, -1)
+        cv2.putText(output_image, label, (x1 + 5, y1 - 5),
+                    cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), font_thickness, lineType=cv2.LINE_AA)
+
+    fps = 1 / max(inference_time, 1e-6)  # Guard against a zero time delta
+
+    # Semi-transparent FPS overlay
+    overlay = output_image.copy()
+    cv2.rectangle(overlay, (10, 10), (300, 80), (0, 0, 0), -1)
+    output_image = cv2.addWeighted(overlay, 0.6, output_image, 0.4, 0)
+    cv2.putText(output_image, f"FPS: {fps:.2f}", (20, 40),
+                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, lineType=cv2.LINE_AA)
+    cv2.putText(output_image, f"Avg FPS: {1/avg_inference_time:.2f}", (20, 70),
+                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, lineType=cv2.LINE_AA)
 
     return output_image
 
+def process_frame_thread():
+    """Background thread for processing frames"""
+    while not stop_event.is_set():
+        if not frame_queue.empty():
+            frame = frame_queue.get()
+
+            # Skip if there's a processing lock (an image upload is running)
+            if processing_lock.locked():
+                result_queue.put(frame)  # Return the frame unprocessed
+                continue
+
+            # Process the frame; no gradients are needed for inference
+            with torch.no_grad():
+                input_size = 384  # Smaller input size for real-time processing
+                results = model(frame['image'], size=input_size)
+
+            # Calculate FPS from the frame's capture timestamp
+            inference_time = time.time() - frame.get('timestamp', time.time())
+            current_fps = 1 / inference_time if inference_time > 0 else 30
+
+            # Update the rolling FPS average; drop the oldest sample when the
+            # queue is full, otherwise put() would block after 30 samples
+            if fps_queue.full():
+                fps_queue.get()
+            fps_queue.put(current_fps)
+            avg_fps = sum(fps_queue.queue) / fps_queue.qsize()
+
+            # Draw detections
+            output = frame['image'].copy()
+            detections = results.pred[0].cpu().numpy()
+
+            for *xyxy, conf, cls in detections:
+                x1, y1, x2, y2 = map(int, xyxy)
+                class_id = int(cls)
+                color = colors[class_id].tolist()
+
+                # Draw rectangle and label
+                cv2.rectangle(output, (x1, y1), (x2, y2), color, 2, lineType=cv2.LINE_AA)
+
+                label = f"{model.names[class_id]} {conf:.2f}"
+                font_scale, font_thickness = 0.6, 1  # Smaller for real-time
+                (w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness)
+
+                cv2.rectangle(output, (x1, y1 - h - 5), (x1 + w + 5, y1), color, -1)
+                cv2.putText(output, label, (x1 + 3, y1 - 3),
+                            cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), font_thickness, lineType=cv2.LINE_AA)
+
+            # Add FPS counter
+            cv2.rectangle(output, (10, 10), (210, 80), (0, 0, 0), -1)
+            cv2.putText(output, f"FPS: {current_fps:.1f}", (20, 40),
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2, lineType=cv2.LINE_AA)
+            cv2.putText(output, f"Avg FPS: {avg_fps:.1f}", (20, 70),
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2, lineType=cv2.LINE_AA)
+
+            # Put the processed frame in the result queue
+            result_queue.put({'image': output, 'fps': current_fps})
+        else:
+            time.sleep(0.001)  # Small sleep to prevent CPU spinning
+
+def webcam_feed():
+    """Generator function for the webcam feed"""
+    # Start the processing thread if it is not already running
+    if not any(thread.name == "frame_processor" for thread in threading.enumerate()):
+        stop_event.clear()
+        processor = threading.Thread(target=process_frame_thread, name="frame_processor", daemon=True)
+        processor.start()
+
+    # Open the webcam
     cap = cv2.VideoCapture(0)
+    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
+    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
 
+    try:
+        while True:
+            success, frame = cap.read()
+            if not success:
+                break
+
+            # OpenCV captures BGR; convert to RGB for the model and the UI
+            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+            # Put the frame in the queue for processing
+            if not frame_queue.full():
+                frame_queue.put({'image': frame, 'timestamp': time.time()})
+
+            # Get a processed frame from the result queue
+            if not result_queue.empty():
+                result = result_queue.get()
+                yield result['image']
+            else:
+                # If no processed frame is available, yield the raw frame
+                yield frame
+
+            # Control the frame rate to not overwhelm the system
+            time.sleep(0.01)
+    finally:
+        cap.release()
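+
+# Note: with the queue enabled (demo.queue() at the bottom of the file),
+# Gradio runs generator callbacks as streams, pushing each yielded frame to
+# the bound Image output; that is how webcam_feed is wired into the UI below.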
+
+def process_uploaded_image(image):
+    """Process an uploaded image (kept separate from the real-time path)"""
+    with processing_lock:  # Hold the lock to pause real-time processing
+        return detect_objects(image)
+
+# Set up the Gradio interface
+example_images = ["spring_street_after.jpg", "pexels-hikaique-109919.jpg"]
+# Keep only examples that exist on disk so gr.Examples caching cannot fail
+example_images = [p for p in example_images if os.path.exists(p)]
+os.makedirs("examples", exist_ok=True)
+
+with gr.Blocks(title="YOLOv5 Object Detection - Real-time & Image Upload") as demo:
     gr.Markdown("""
+    # YOLOv5 Object Detection
+    ## Real-time webcam detection and image upload processing
     """)
 
+    with gr.Tabs():
+        with gr.TabItem("Real-time Detection"):
+            gr.Markdown("""
+            ### Real-time Object Detection
+            Continuous object detection on your webcam feed.
+            """)
+            webcam_output = gr.Image(label="Real-time Detection", type="numpy")
 
+        with gr.TabItem("Image Upload"):
+            gr.Markdown("""
+            ### Image Upload Detection
+            Upload an image to detect objects.
+            """)
+            with gr.Row():
+                with gr.Column(scale=1):
+                    input_image = gr.Image(label="Input Image", type="numpy")
+                    submit_button = gr.Button("Submit", variant="primary")
+                    clear_button = gr.Button("Clear")
+
+                with gr.Column(scale=1):
+                    output_image = gr.Image(label="Detected Objects", type="numpy")
+
+            gr.Examples(
+                examples=example_images,
+                inputs=input_image,
+                outputs=output_image,
+                fn=process_uploaded_image,
+                cache_examples=True
+            )
+
+    # Set up event handlers
+    submit_button.click(fn=process_uploaded_image, inputs=input_image, outputs=output_image)
+    clear_button.click(lambda: (None, None), None, [input_image, output_image])
 
+    # Connect the webcam feed. gr.Image.update() at build time is a no-op, so
+    # the generator is bound with demo.load instead; the _js hook keeps the
+    # webcam tab image refreshing on the client side.
+    demo.load(webcam_feed, None, webcam_output, _js="""
+    () => {
+        // Keep the webcam tab refreshing at high frequency
+        setInterval(() => {
+            if (document.querySelector('.tabitem:first-child').style.display !== 'none') {
+                const webcamImg = document.querySelector('.tabitem:first-child img');
+                if (webcamImg) {
+                    const src = webcamImg.src;
+                    webcamImg.src = src.includes('?') ? src.split('?')[0] + '?t=' + Date.now() : src + '?t=' + Date.now();
+                }
+            }
+        }, 33);  // ~30 FPS refresh rate
+        return [];
+    }
+    """)
+
+# Stop the worker thread when the app process exits; registered via atexit so
+# it runs on normal interpreter shutdown
+def cleanup():
+    stop_event.set()
+    print("Cleaning up threads...")
+
+atexit.register(cleanup)
+
+demo.queue()  # Enable the queue so the generator callback can stream frames
+demo.launch()