"""Gradio app for running a custom YOLO model on images and videos."""

import tempfile
import time

import cv2
import gradio as gr
from ultralytics import YOLO


def load_model(model_file):
    """Load YOLO weights from an uploaded .pt file; return an error string on failure."""
    try:
        path = model_file.name if hasattr(model_file, "name") else model_file
        return YOLO(path)
    except Exception as e:
        return f"Error loading model: {e}"


def predict_image(model, image, conf):
    """Run inference on one image; return (annotated RGB frame, seconds, detection count)."""
    try:
        # gr.File yields a tempfile-like object; YOLO accepts a path, so unwrap it.
        source = image.name if hasattr(image, "name") else image

        start_time = time.time()
        results = model(source, conf=conf)
        process_time = time.time() - start_time

        # plot() returns a BGR array (OpenCV convention); Gradio displays RGB.
        annotated_frame = cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)

        num_detections = len(results[0].boxes) if results[0].boxes is not None else "N/A"
        return annotated_frame, process_time, num_detections
    except Exception as e:
        return f"Error during image inference: {e}", None, None

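# Quick local check without the UI (a sketch; "best.pt" and "sample.jpg" are
# placeholder paths rather than files shipped with this app):
#
#   model = YOLO("best.pt")
#   frame, seconds, count = predict_image(model, "sample.jpg", conf=0.5)
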
def predict_video(model, video_file, conf, frame_step):
    """Run inference on every `frame_step`-th frame; return (video path, seconds, detections)."""
    try:
        source = video_file.name if hasattr(video_file, "name") else video_file
        cap = cv2.VideoCapture(source)
        # Preserve the source frame rate; fall back to 20 fps if it is unavailable.
        fps = cap.get(cv2.CAP_PROP_FPS) or 20.0
        frames = []
        frame_count = 0
        total_detections = 0
        start_time = time.time()

        while True:
            success, frame = cap.read()
            if not success:
                break

            if frame_count % frame_step == 0:
                results = model(frame, conf=conf)
                if results[0].boxes is not None:
                    total_detections += len(results[0].boxes)
                frames.append(results[0].plot())
            else:
                # Skipped frames are passed through unannotated.
                frames.append(frame)
            frame_count += 1

        process_time = time.time() - start_time
        cap.release()

        if not frames:
            return "Error: No frames processed", None, None

        height, width, _ = frames[0].shape
        # Note: "mp4v" keeps the file small but may not play in every browser;
        # "avc1" (H.264) is an alternative when the local OpenCV build supports it.
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        out = cv2.VideoWriter(tmp.name, fourcc, fps, (width, height))
        for frame in frames:
            out.write(frame)
        out.release()

        return tmp.name, process_time, total_detections
    except Exception as e:
        return f"Error during video inference: {e}", None, None


def inference(model_file, input_media, media_type, conf, frame_step):
    """Dispatch to image or video prediction and package metadata for the response."""
    model = load_model(model_file)
    if isinstance(model, str):
        # load_model returned an error message instead of a model.
        return model, None, {"processing_time": None, "detections": None}

    if media_type == "Image":
        out_img, process_time, detections = predict_image(model, input_media, conf)
        return out_img, None, {"processing_time": process_time, "detections": detections}
    elif media_type == "Video":
        out_vid, process_time, detections = predict_video(model, input_media, conf, frame_step)
        return None, out_vid, {"processing_time": process_time, "detections": detections}
    else:
        return "Unsupported media type", None, {"processing_time": None, "detections": None}

# Input components.
model_file_input = gr.File(label="Upload Custom YOLO Model (.pt file)")
media_file_input = gr.File(label="Upload Image/Video File")
media_type_radio = gr.Radio(choices=["Image", "Video"], label="Select Media Type", value="Image")
confidence_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.5, label="Detection Confidence Threshold")
frame_step_slider = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Frame Step (for Video Processing)")

# Output components.
output_image = gr.Image(label="Annotated Image")
output_video = gr.Video(label="Annotated Video")
output_metadata = gr.JSON(label="Metadata")

iface = gr.Interface(
    fn=inference,
    inputs=[model_file_input, media_file_input, media_type_radio, confidence_slider, frame_step_slider],
    outputs=[output_image, output_video, output_metadata],
    title="Custom YOLO Model Inference for Real-Time Detection",
    description=(
        "Upload your custom YOLO model (detection, segmentation, or OBB) along with an image or video file "
        "to run inference. Adjust the detection confidence and frame step (for video) as needed. "
        "The app shows the processed image or video and returns metadata for real-time API integration. "
        "It is aimed at users who want to host a YOLO model on Hugging Face and query it for real-time "
        "object detection via the Gradio API."
    ),
)
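
# Once hosted (e.g., on Hugging Face Spaces), the app can be queried
# programmatically. A minimal sketch using `gradio_client`; the Space name
# "user/yolo-inference" and the file names below are placeholders:
#
#   from gradio_client import Client, handle_file
#
#   client = Client("user/yolo-inference")
#   image_out, video_out, metadata = client.predict(
#       handle_file("best.pt"),     # custom YOLO weights
#       handle_file("sample.jpg"),  # input media
#       "Image",                    # media type
#       0.5,                        # confidence threshold
#       1,                          # frame step (ignored for images)
#   )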

if __name__ == "__main__":
    iface.launch()