Aumkeshchy2003 commited on
Commit
9b2d010
·
verified ·
1 Parent(s): 151b93e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -74
app.py CHANGED
@@ -4,122 +4,108 @@ import gradio as gr
4
  import cv2
5
  import time
6
  import os
7
- import onnxruntime
8
  from pathlib import Path
9
- from ultralytics import YOLO
10
 
11
- # Load YOLOv5 model without AutoShape
12
- model = torch.hub.load("ultralytics/yolov5", "yolov5n", source="local")
13
 
14
- # Set device
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
- model.to(device)
17
- model.eval()
18
 
19
- # Fuse layers for optimization
20
- model.fuse()
 
 
 
 
 
 
21
 
22
- # Export to ONNX format
23
- os.makedirs("models", exist_ok=True)
24
- model_path = Path("models/yolov5n.onnx")
 
25
 
26
- torch.onnx.export(
27
- model,
28
- torch.zeros(1, 3, 640, 640).to(device), # Input tensor
29
- str(model_path),
30
- export_params=True,
31
- opset_version=11,
32
- do_constant_folding=True,
33
- input_names=["images"],
34
- output_names=["output"],
35
- dynamic_axes={"images": {0: "batch_size"}, "output": {0: "batch_size"}}
36
- )
37
 
38
- # Load ONNX model for inference
39
- session = onnxruntime.InferenceSession(str(model_path), providers=['CUDAExecutionProvider'])
40
 
41
- # Generate random colors for each class
42
  np.random.seed(42)
43
- colors = np.random.uniform(0, 255, size=(80, 3))
44
 
45
  total_inference_time = 0
46
  inference_count = 0
47
 
48
  def detect_objects(image):
49
  global total_inference_time, inference_count
 
50
  if image is None:
51
  return None
52
-
53
- start_time = time.time()
54
-
55
- # Preprocess image
56
- original_shape = image.shape
57
- input_shape = (640, 640)
58
- image_resized = cv2.resize(image, input_shape)
59
- image_norm = image_resized.astype(np.float32) / 255.0
60
- image_transposed = np.transpose(image_norm, (2, 0, 1))
61
- image_batch = np.expand_dims(image_transposed, axis=0)
62
 
63
- # Get input name and run inference
64
- input_name = session.get_inputs()[0].name
65
- outputs = session.run(None, {input_name: image_batch})
66
 
67
- # Process detections
68
- detections = outputs[0][0] # First batch, all detections
69
 
70
- # Calculate timing
71
  inference_time = time.time() - start_time
72
  total_inference_time += inference_time
73
  inference_count += 1
74
  avg_inference_time = total_inference_time / inference_count
75
- fps = 1 / inference_time
76
 
77
- # Create a copy of the original image for visualization
78
- output_image = image.copy()
79
 
80
- # Scale factor for bounding box coordinates
81
- scale_x = original_shape[1] / input_shape[0]
82
- scale_y = original_shape[0] / input_shape[1]
83
-
84
- # Draw bounding boxes and labels
85
- for det in detections:
86
- x1, y1, x2, y2, conf, class_id = det[:6]
87
- if conf < 0.3: # Confidence threshold
88
- continue
89
-
90
- # Convert to original image coordinates
91
- x1, x2 = int(x1 * scale_x), int(x2 * scale_x)
92
- y1, y2 = int(y1 * scale_y), int(y2 * scale_y)
93
- class_id = int(class_id)
94
 
95
- # Draw rectangle and label
96
- color = tuple(map(int, colors[class_id]))
97
- cv2.rectangle(output_image, (x1, y1), (x2, y2), color, 2)
98
- label = f"Class {class_id} {conf:.2f}"
99
- cv2.putText(output_image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
100
 
101
- # Display FPS
102
- cv2.putText(output_image, f"FPS: {fps:.2f}", (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
103
- cv2.putText(output_image, f"Avg FPS: {1/avg_inference_time:.2f}", (20, 70), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
 
 
 
 
 
 
 
104
 
105
  return output_image
106
 
107
- # Gradio Interface
108
  example_images = ["spring_street_after.jpg", "pexels-hikaique-109919.jpg"]
109
  os.makedirs("examples", exist_ok=True)
110
 
111
  with gr.Blocks(title="Optimized YOLOv5 Object Detection") as demo:
112
- gr.Markdown("# **Optimized YOLOv5 Object Detection** 🚀")
 
 
 
113
 
114
  with gr.Row():
115
  with gr.Column(scale=1):
116
  input_image = gr.Image(label="Input Image", type="numpy")
117
- submit_button = gr.Button("Detect Objects", variant="primary")
118
  clear_button = gr.Button("Clear")
119
-
120
  with gr.Column(scale=1):
121
  output_image = gr.Image(label="Detected Objects", type="numpy")
122
-
123
  gr.Examples(
124
  examples=example_images,
125
  inputs=input_image,
@@ -127,7 +113,7 @@ with gr.Blocks(title="Optimized YOLOv5 Object Detection") as demo:
127
  fn=detect_objects,
128
  cache_examples=True
129
  )
130
-
131
  submit_button.click(fn=detect_objects, inputs=input_image, outputs=output_image)
132
  clear_button.click(lambda: (None, None), None, [input_image, output_image])
133
 
 
4
  import cv2
5
  import time
6
  import os
 
7
  from pathlib import Path
 
8
 
9
+ # Create cache directory for models
10
+ os.makedirs("models", exist_ok=True)
11
 
 
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ print(f"Using device: {device}")
 
14
 
15
+ model_path = Path("models/yolov5n.pt")
16
+ if model_path.exists():
17
+ print(f"Loading model from cache: {model_path}")
18
+ model = torch.hub.load("ultralytics/yolov5", "yolov5x", pretrained=True, source="local", path=str(model_path)).to(device)
19
+ else:
20
+ print("Downloading YOLOv5n model and caching...")
21
+ model = torch.hub.load("ultralytics/yolov5", "yolov5x", pretrained=True).to(device)
22
+ torch.save(model.state_dict(), model_path)
23
 
24
+ # Model configurations
25
+ model.conf = 0.3
26
+ model.iou = 0.3
27
+ model.classes = None
28
 
29
+ if device.type == "cuda":
30
+ model.half()
31
+ else:
32
+ torch.set_num_threads(os.cpu_count())
 
 
 
 
 
 
 
33
 
34
+ model.eval()
 
35
 
 
36
  np.random.seed(42)
37
+ colors = np.random.uniform(0, 255, size=(len(model.names), 3))
38
 
39
  total_inference_time = 0
40
  inference_count = 0
41
 
42
  def detect_objects(image):
43
  global total_inference_time, inference_count
44
+
45
  if image is None:
46
  return None
 
 
 
 
 
 
 
 
 
 
47
 
48
+ start_time = time.time()
49
+ output_image = image.copy()
50
+ input_size = 640
51
 
52
+ with torch.no_grad():
53
+ results = model(image, size=input_size)
54
 
 
55
  inference_time = time.time() - start_time
56
  total_inference_time += inference_time
57
  inference_count += 1
58
  avg_inference_time = total_inference_time / inference_count
 
59
 
60
+ detections = results.pred[0].cpu().numpy()
 
61
 
62
+ for *xyxy, conf, cls in detections:
63
+ x1, y1, x2, y2 = map(int, xyxy)
64
+ class_id = int(cls)
65
+ color = colors[class_id].tolist()
66
+
67
+ # Thicker bounding boxes
68
+ cv2.rectangle(output_image, (x1, y1), (x2, y2), color, 3, lineType=cv2.LINE_AA)
69
+
70
+ label = f"{model.names[class_id]} {conf:.2f}"
71
+ font_scale, font_thickness = 0.9, 2
72
+ (w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness)
 
 
 
73
 
74
+ cv2.rectangle(output_image, (x1, y1 - h - 10), (x1 + w + 10, y1), color, -1)
75
+ cv2.putText(output_image, label, (x1 + 5, y1 - 5),
76
+ cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), font_thickness, lineType=cv2.LINE_AA)
 
 
77
 
78
+ fps = 1 / inference_time
79
+
80
+ # Stylish FPS display
81
+ overlay = output_image.copy()
82
+ cv2.rectangle(overlay, (10, 10), (300, 80), (0, 0, 0), -1)
83
+ output_image = cv2.addWeighted(overlay, 0.6, output_image, 0.4, 0)
84
+ cv2.putText(output_image, f"FPS: {fps:.2f}", (20, 40),
85
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, lineType=cv2.LINE_AA)
86
+ cv2.putText(output_image, f"Avg FPS: {1/avg_inference_time:.2f}", (20, 70),
87
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, lineType=cv2.LINE_AA)
88
 
89
  return output_image
90
 
 
91
  example_images = ["spring_street_after.jpg", "pexels-hikaique-109919.jpg"]
92
  os.makedirs("examples", exist_ok=True)
93
 
94
  with gr.Blocks(title="Optimized YOLOv5 Object Detection") as demo:
95
+ gr.Markdown("""
96
+ # Optimized YOLOv5 Object Detection
97
+ Detects objects using YOLOv5 with enhanced visualization and FPS tracking.
98
+ """)
99
 
100
  with gr.Row():
101
  with gr.Column(scale=1):
102
  input_image = gr.Image(label="Input Image", type="numpy")
103
+ submit_button = gr.Button("Submit", variant="primary")
104
  clear_button = gr.Button("Clear")
105
+
106
  with gr.Column(scale=1):
107
  output_image = gr.Image(label="Detected Objects", type="numpy")
108
+
109
  gr.Examples(
110
  examples=example_images,
111
  inputs=input_image,
 
113
  fn=detect_objects,
114
  cache_examples=True
115
  )
116
+
117
  submit_button.click(fn=detect_objects, inputs=input_image, outputs=output_image)
118
  clear_button.click(lambda: (None, None), None, [input_image, output_image])
119