import torch
import numpy as np
import gradio as gr
import cv2
import time
import os
from pathlib import Path
from PIL import Image
from threading import Thread
from queue import Queue
# Create cache directory for models
os.makedirs("models", exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load YOLOv5 Nano model
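# (the nano variant is the smallest and fastest of the YOLOv5 family)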
model_path = Path("models/yolov5n.pt")
if model_path.exists():
    print(f"Loading model from cache: {model_path}")
else:
    print("Downloading YOLOv5n model and caching...")
# The hub "custom" entry point downloads yolov5n.pt into model_path when the
# file is missing and reuses the cached copy afterwards, so one call covers
# both cases. (Saving model.state_dict() and reloading it through this entry
# point does not work: the loader expects a full YOLOv5 checkpoint.)
model = torch.hub.load("ultralytics/yolov5", "custom", path=str(model_path)).to(device)
# Inference settings (speed/accuracy trade-off)
model.conf = 0.25  # Confidence threshold for reported detections
model.iou = 0.45  # IoU threshold for non-maximum suppression
model.classes = None  # Detect all classes
model.max_det = 100  # Cap detections per image to bound postprocessing cost
if device.type == "cuda":
    model.half()  # Use FP16 precision on GPU
else:
    torch.set_num_threads(os.cpu_count())  # Use all CPU cores
model.eval()
# Pre-generate colors for bounding boxes
np.random.seed(42)
colors = np.random.randint(0, 255, size=(len(model.names), 3), dtype=np.uint8)
# Async video processing
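# A worker thread consumes frames from frame_queue and returns
# (frame, detections, index) tuples on result_queue; process_video buffers
# results by index so output frames are written in source order.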
def process_frame(model, frame_queue, result_queue):
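    """Inference worker: pull (frame, index) items from frame_queue, run the
    model, and push (frame, detections, index) onto result_queue. A None item
    is the shutdown signal, which is echoed back on result_queue."""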
    while True:
        frame_data = frame_queue.get()  # Blocks until a frame is available
        if frame_data is None:  # Signal to stop
            result_queue.put(None)
            break
        frame, frame_index = frame_data
        # YOLOv5 expects RGB input; OpenCV delivers BGR frames
        img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # Smaller inference size for speed (reduced from 640 to 384)
        results = model(img, size=384)
        detections = results.xyxy[0].cpu().numpy()
        result_queue.put((frame, detections, frame_index))
def process_video(video_path):
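    """Run detection on every frame of video_path and write an annotated MP4.
    Returns the output path, an error string if the file cannot be opened,
    or None when no video was provided."""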
    # Gradio may pass None, an empty string, or a (file, ...) tuple
    if video_path is None or video_path == "":
        return None
    if isinstance(video_path, tuple) and len(video_path) >= 1:
        video_path = video_path[0]
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return "Error: Could not open video file."
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS) or 30  # Fall back to 30 if FPS is unreported
    # H.264 (avc1) keeps the output MP4 playable in browsers
    fourcc = cv2.VideoWriter_fourcc(*'avc1')
    output_path = "output_video.mp4"
    out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
    # Queues for async processing
    frame_queue = Queue(maxsize=10)
    result_queue = Queue()
    # Start the inference worker thread
    processing_thread = Thread(target=process_frame, args=(model, frame_queue, result_queue))
    processing_thread.daemon = True
    processing_thread.start()
    total_frames = 0
    start_time = time.time()
    processing_started = False
    frames_buffer = {}  # frame index -> (frame, detections), for in-order writes
    next_frame_to_write = 0
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if not processing_started:
                processing_started = True
                start_time = time.time()
            frame_queue.put((frame, total_frames))
            total_frames += 1
            # Drain any finished results into the reorder buffer
            while not result_queue.empty():
                result = result_queue.get()
                if result is None:
                    break
                processed_frame, detections, frame_idx = result
                frames_buffer[frame_idx] = (processed_frame, detections)
            # Write buffered frames in order
            while next_frame_to_write in frames_buffer:
                buffer_frame, buffer_detections = frames_buffer.pop(next_frame_to_write)
                # Draw bounding boxes
                for *xyxy, conf, cls in buffer_detections:
                    if conf < 0.35:  # Additional filtering
                        continue
                    x1, y1, x2, y2 = map(int, xyxy)
                    class_id = int(cls)
                    color = colors[class_id].tolist()
                    cv2.rectangle(buffer_frame, (x1, y1), (x2, y2), color, 2, lineType=cv2.LINE_AA)
                    label = f"{model.names[class_id]} {conf:.2f}"
                    # Black label text
                    cv2.putText(buffer_frame, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX,
                                0.7, (0, 0, 0), 2, cv2.LINE_AA)
                # Overlay the processing FPS so far
                elapsed = time.time() - start_time
                current_fps = next_frame_to_write / elapsed if elapsed > 0 else 0
                cv2.putText(buffer_frame, f"FPS: {current_fps:.2f}", (20, 40),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2, cv2.LINE_AA)
                out.write(buffer_frame)
                next_frame_to_write += 1
        # Signal the worker to finish, then drain the remaining results
        frame_queue.put(None)
        while True:
            result = result_queue.get()  # Blocks until the worker responds
            if result is None:
                break
            processed_frame, detections, frame_idx = result
            frames_buffer[frame_idx] = (processed_frame, detections)
            # Write remaining frames in order
            while next_frame_to_write in frames_buffer:
                buffer_frame, buffer_detections = frames_buffer.pop(next_frame_to_write)
                # Draw bounding boxes
                for *xyxy, conf, cls in buffer_detections:
                    if conf < 0.35:
                        continue
                    x1, y1, x2, y2 = map(int, xyxy)
                    class_id = int(cls)
                    color = colors[class_id].tolist()
                    cv2.rectangle(buffer_frame, (x1, y1), (x2, y2), color, 2, lineType=cv2.LINE_AA)
                    label = f"{model.names[class_id]} {conf:.2f}"
                    cv2.putText(buffer_frame, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX,
                                0.7, (0, 0, 0), 2, cv2.LINE_AA)
                # Overlay the processing FPS
                elapsed = time.time() - start_time
                current_fps = next_frame_to_write / elapsed if elapsed > 0 else 0
                cv2.putText(buffer_frame, f"FPS: {current_fps:.2f}", (20, 40),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2, cv2.LINE_AA)
                out.write(buffer_frame)
                next_frame_to_write += 1
    finally:
        cap.release()
        out.release()
    return output_path
def process_image(image):
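    """Run YOLOv5 on a PIL image and return it annotated with boxes and labels."""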
    if image is None:
        return None
    img = np.array(image)
    # Smaller inference size for speed
    results = model(img, size=512)
    detections = results.pred[0].cpu().numpy()
    for *xyxy, conf, cls in detections:
        x1, y1, x2, y2 = map(int, xyxy)
        class_id = int(cls)
        color = colors[class_id].tolist()
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 2, lineType=cv2.LINE_AA)
        label = f"{model.names[class_id]} {conf:.2f}"
        # Black label text
        cv2.putText(img, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2, cv2.LINE_AA)
    return Image.fromarray(img)
css = """
#title {
text-align: center;
color: #2C3E50;
font-size: 2.5rem;
margin: 1.5rem 0;
text-shadow: 1px 1px 2px rgba(0,0,0,0.1);
}
.gradio-container {
background-color: #F5F7FA;
}
.tab-item {
background-color: white;
border-radius: 10px;
padding: 20px;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
margin: 10px;
}
.button-row {
display: flex;
justify-content: space-around;
margin: 1rem 0;
}
#video-process-btn, #submit-btn {
background-color: #3498DB;
border: none;
}
#clear-btn {
background-color: #E74C3C;
border: none;
}
.output-container {
margin-top: 1.5rem;
border: 2px dashed #3498DB;
border-radius: 10px;
padding: 10px;
}
.footer {
text-align: center;
margin-top: 2rem;
font-size: 0.9rem;
color: #7F8C8D;
}
"""
with gr.Blocks(css=css, title="Video & Image Object Detection by YOLOv5") as demo:
    gr.Markdown("""# YOLOv5 Object Detection""", elem_id="title")
    with gr.Tabs():
        with gr.TabItem("Video Detection", elem_classes="tab-item"):
            with gr.Row():
                video_input = gr.Video(
                    label="Upload Video",
                    interactive=True,
                    elem_id="video-input"
                )
            with gr.Row(elem_classes="button-row"):
                process_button = gr.Button(
                    "Process Video",
                    variant="primary",
                    elem_id="video-process-btn"
                )
            with gr.Row(elem_classes="output-container"):
                video_output = gr.Video(
                    label="Processed Video",
                    elem_id="video-output"
                )
            process_button.click(
                fn=process_video,
                inputs=video_input,
                outputs=video_output
            )
        with gr.TabItem("Image Detection", elem_classes="tab-item"):
            with gr.Row():
                image_input = gr.Image(
                    type="pil",
                    label="Upload Image",
                    interactive=True
                )
            with gr.Row(elem_classes="button-row"):
                clear_button = gr.Button(
                    "Clear",
                    variant="secondary",
                    elem_id="clear-btn"
                )
                submit_button = gr.Button(
                    "Detect Objects",
                    variant="primary",
                    elem_id="submit-btn"
                )
            with gr.Row(elem_classes="output-container"):
                image_output = gr.Image(
                    label="Detected Objects",
                    elem_id="image-output"
                )
            clear_button.click(
                fn=lambda: None,
                inputs=None,
                outputs=image_output
            )
            submit_button.click(
                fn=process_image,
                inputs=image_input,
                outputs=image_output
            )
    gr.Markdown("""
    ### Powered by YOLOv5
    Upload an image or a video and YOLOv5 will detect, box, and label the objects it finds.
    """, elem_classes="footer")
if __name__ == "__main__":
    demo.launch()