Aumkeshchy2003 commited on
Commit
054b852
·
verified ·
1 Parent(s): 6cfc8c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -50
app.py CHANGED
@@ -12,67 +12,124 @@ os.makedirs("models", exist_ok=True)
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  print(f"Using device: {device}")
14
 
15
- # Load YOLOv5n model
16
  model_path = Path("models/yolov5n.pt")
17
- if not model_path.exists():
18
- torch.hub.download_url_to_file("https://github.com/ultralytics/yolov5/releases/download/v6.2/yolov5n.pt", "models/yolov5n.pt")
 
 
 
 
 
19
 
20
- model = torch.hub.load("ultralytics/yolov5", "custom", path=str(model_path)).to(device)
 
 
 
21
 
22
- # Model optimizations
23
- model.conf = 0.5
24
- model.iou = 0.45
25
  if device.type == "cuda":
26
- model.half()
27
  else:
28
- model.float()
29
- torch.set_num_threads(2)
30
 
31
  model.eval()
32
 
33
- colors = np.random.rand(80, 3) * 255 # COCO classes
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  def detect_objects(image):
 
 
36
  if image is None:
37
  return None
38
-
39
- start = time.perf_counter()
40
-
41
- # Preprocess
42
- im = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
43
- im = cv2.resize(im, (320, 320))
44
- tensor = torch.from_numpy(im).to(device)
45
- tensor = tensor.half() if device.type == "cuda" else tensor.float()
46
- tensor = tensor.permute(2, 0, 1).unsqueeze(0) / 255
47
-
48
- # Inference
49
- with torch.no_grad():
50
- pred = model(tensor)[0]
51
-
52
- # Post-process
53
- pred = pred.float() if device.type == "cpu" else pred.half()
54
- pred = non_max_suppression(pred, model.conf, model.iou)[0]
55
-
56
- # Visualization
57
- output = image.copy()
58
- if pred is not None:
59
- pred[:, :4] = scale_coords(tensor.shape[2:], pred[:, :4], image.shape).round()
60
- for *xyxy, conf, cls in pred:
61
- x1, y1, x2, y2 = map(int, xyxy)
62
- cv2.rectangle(output, (x1, y1), (x2, y2), colors[int(cls)].tolist(), 2)
63
-
64
- # FPS counter
65
- fps = 1 / (time.perf_counter() - start)
66
- cv2.putText(output, f"FPS: {fps:.1f}", (10, 30),
67
- cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
68
-
69
- return output
70
-
71
- with gr.Blocks() as demo:
72
- gr.Markdown("# Real-Time Object Detection")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  with gr.Row():
74
- input_img = gr.Image(label="Input", streaming=True) # Modified webcam handling
75
- output_img = gr.Image(label="Output")
76
- input_img.change(detect_objects, input_img, output_img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- demo.launch()
 
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  print(f"Using device: {device}")
14
 
15
+ # Use YOLOv5 Nano for better speed
16
  model_path = Path("models/yolov5n.pt")
17
+ if model_path.exists():
18
+ print(f"Loading model from cache: {model_path}")
19
+ model = torch.hub.load("ultralytics/yolov5", "custom", path=str(model_path), source="local").to(device)
20
+ else:
21
+ print("Downloading YOLOv5n model and caching...")
22
+ model = torch.hub.load("ultralytics/yolov5", "yolov5n", pretrained=True).to(device)
23
+ torch.save(model.state_dict(), model_path)
24
 
25
+ # Optimize model for speed
26
+ model.conf = 0.3 # Lower confidence threshold
27
+ model.iou = 0.3 # Non-Maximum Suppression IoU threshold
28
+ model.classes = None # Detect all classes
29
 
 
 
 
30
  if device.type == "cuda":
31
+ model.half() # Use FP16 for faster inference
32
  else:
33
+ torch.set_num_threads(os.cpu_count())
 
34
 
35
  model.eval()
36
 
37
+ # Pre-generate colors for bounding boxes
38
+ np.random.seed(42)
39
+ colors = np.random.uniform(0, 255, size=(len(model.names), 3))
40
+
41
+ # Track FPS
42
+ total_inference_time = 0
43
+ inference_count = 0
44
+
45
+ def preprocess_image(image):
46
+ """ Prepares image for YOLOv5 detection. """
47
+ input_size = 640
48
+ image = cv2.resize(image, (input_size, input_size))
49
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # Convert to BGR for OpenCV
50
+ return image
51
 
52
  def detect_objects(image):
53
+ global total_inference_time, inference_count
54
+
55
  if image is None:
56
  return None
57
+
58
+ start_time = time.time()
59
+
60
+ # Preprocess image
61
+ image = preprocess_image(image)
62
+
63
+ with torch.inference_mode(): # Faster than torch.no_grad()
64
+ results = model(image, size=640)
65
+
66
+ inference_time = time.time() - start_time
67
+ total_inference_time += inference_time
68
+ inference_count += 1
69
+ avg_inference_time = total_inference_time / inference_count
70
+
71
+ detections = results.pred[0].cpu().numpy()
72
+
73
+ output_image = image.copy()
74
+
75
+ for *xyxy, conf, cls in detections:
76
+ x1, y1, x2, y2 = map(int, xyxy)
77
+ class_id = int(cls)
78
+ color = colors[class_id].tolist()
79
+
80
+ # Draw bounding box
81
+ cv2.rectangle(output_image, (x1, y1), (x2, y2), color, 3, lineType=cv2.LINE_AA)
82
+
83
+ label = f"{model.names[class_id]} {conf:.2f}"
84
+ font_scale, font_thickness = 0.9, 2
85
+ (w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness)
86
+
87
+ # Label background
88
+ cv2.rectangle(output_image, (x1, y1 - h - 10), (x1 + w + 10, y1), color, -1)
89
+ cv2.putText(output_image, label, (x1 + 5, y1 - 5),
90
+ cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), font_thickness, lineType=cv2.LINE_AA)
91
+
92
+ fps = 1 / inference_time
93
+
94
+ # Display FPS
95
+ overlay = output_image.copy()
96
+ cv2.rectangle(overlay, (10, 10), (300, 80), (0, 0, 0), -1)
97
+ output_image = cv2.addWeighted(overlay, 0.6, output_image, 0.4, 0)
98
+ cv2.putText(output_image, f"FPS: {fps:.2f}", (20, 40),
99
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, lineType=cv2.LINE_AA)
100
+ cv2.putText(output_image, f"Avg FPS: {1/avg_inference_time:.2f}", (20, 70),
101
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, lineType=cv2.LINE_AA)
102
+
103
+ return output_image
104
+
105
+ # Gradio UI
106
+ example_images = ["spring_street_after.jpg", "pexels-hikaique-109919.jpg"]
107
+ os.makedirs("examples", exist_ok=True)
108
+
109
+ with gr.Blocks(title="Optimized YOLOv5 Object Detection") as demo:
110
+ gr.Markdown("""
111
+ # Optimized YOLOv5 Object Detection
112
+ Detects objects using YOLOv5 with enhanced visualization and FPS tracking.
113
+ """)
114
+
115
  with gr.Row():
116
+ with gr.Column(scale=1):
117
+ input_image = gr.Image(label="Input Image", type="numpy")
118
+ submit_button = gr.Button("Submit", variant="primary")
119
+ clear_button = gr.Button("Clear")
120
+
121
+ with gr.Column(scale=1):
122
+ output_image = gr.Image(label="Detected Objects", type="numpy")
123
+
124
+ gr.Examples(
125
+ examples=example_images,
126
+ inputs=input_image,
127
+ outputs=output_image,
128
+ fn=detect_objects,
129
+ cache_examples=True
130
+ )
131
+
132
+ submit_button.click(fn=detect_objects, inputs=input_image, outputs=output_image)
133
+ clear_button.click(lambda: (None, None), None, [input_image, output_image])
134
 
135
+ demo.launch()