import gradio as gr
from ultralytics import YOLO
import cv2
import tempfile
import time
import numpy as np
# Load a custom YOLO model from the uploaded file.
def load_model(model_file):
    try:
        # gr.File may hand back a tempfile-like object or a plain path string,
        # depending on the Gradio version; handle both.
        model_path = model_file.name if hasattr(model_file, "name") else model_file
        model = YOLO(model_path)
        return model
    except Exception as e:
        return f"Error loading model: {e}"

# Run inference on an image and return the annotated image as an np.ndarray.
def predict_image(model, image, conf):
    try:
        # gr.File yields a file object or a path; YOLO accepts a path string.
        source = image.name if hasattr(image, "name") else image
        start_time = time.time()
        # Run inference with the chosen confidence threshold.
        results = model(source, conf=conf)
        process_time = time.time() - start_time
        # Get the annotated image using the model's built-in plotting.
        annotated_frame = results[0].plot()
        # Convert BGR (OpenCV default) to RGB so gr.Image displays correct colors.
        annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
        # Count detections when boxes are available (boxes is None for some task types, e.g. OBB).
        num_detections = len(results[0].boxes) if results[0].boxes is not None else "N/A"
        return annotated_frame, process_time, num_detections
    except Exception as e:
        return f"Error during image inference: {e}", None, None

# Run inference on a video by processing selected frames and return a processed video file.
def predict_video(model, video_file, conf, frame_step):
    try:
        video_path = video_file.name if hasattr(video_file, "name") else video_file
        cap = cv2.VideoCapture(video_path)
        # Use the source frame rate for the output; fall back to 20 fps if OpenCV reports 0.
        fps = cap.get(cv2.CAP_PROP_FPS) or 20.0
        frames = []
        frame_count = 0
        start_time = time.time()
        while True:
            success, frame = cap.read()
            if not success:
                break
            # Only process every nth frame, as determined by frame_step.
            if frame_count % frame_step == 0:
                results = model(frame, conf=conf)
                annotated_frame = results[0].plot()
                frames.append(annotated_frame)
            else:
                # Keep the original frame so the output video retains its full length.
                frames.append(frame)
            frame_count += 1
        process_time = time.time() - start_time
        cap.release()
        if not frames:
            return "Error: No frames processed", None, None
        height, width, _ = frames[0].shape
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        out = cv2.VideoWriter(tmp.name, fourcc, fps, (width, height))
        for frame in frames:
            out.write(frame)
        out.release()
        # For video, return a placeholder for the number of detections;
        # per-frame counts could be aggregated here if needed.
        num_detections = "See individual frames"
        return tmp.name, process_time, num_detections
    except Exception as e:
        return f"Error during video inference: {e}", None, None

# Main inference function.
# Returns a tuple: (annotated_image, annotated_video, metadata).
# For image inputs, the video output is None; for video inputs, the image output is None.
def inference(model_file, input_media, media_type, conf, frame_step):
    model = load_model(model_file)
    if isinstance(model, str):  # A string indicates an error during model loading.
        return model, None, {"processing_time": None, "detections": None}
    if media_type == "Image":
        out_img, process_time, detections = predict_image(model, input_media, conf)
        metadata = {"processing_time": process_time, "detections": detections}
        return out_img, None, metadata
    elif media_type == "Video":
        out_vid, process_time, detections = predict_video(model, input_media, conf, frame_step)
        metadata = {"processing_time": process_time, "detections": detections}
        return None, out_vid, metadata
    else:
        return "Unsupported media type", None, {"processing_time": None, "detections": None}

# Define Gradio interface components.
# Component for uploading a custom YOLO model (.pt file).
model_file_input = gr.File(label="Upload Custom YOLO Model (.pt file)")
# Component for uploading an image or video.
media_file_input = gr.File(label="Upload Image/Video File")
# Radio button to choose media type.
media_type_radio = gr.Radio(choices=["Image", "Video"], label="Select Media Type", value="Image")
# Detection confidence slider.
confidence_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.5, label="Detection Confidence Threshold")
# Frame step slider (for video processing).
frame_step_slider = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Frame Step (for Video Processing)")
# For display on the site:
# - Use gr.Image to display the processed image.
# - Use gr.Video to display the processed video.
# - Use gr.JSON to display the metadata.
output_image = gr.Image(label="Annotated Image")
output_video = gr.Video(label="Annotated Video")
output_metadata = gr.JSON(label="Metadata")
# Create the Gradio interface.
# Note: The function returns a triple: (processed image, processed video, metadata).
iface = gr.Interface(
    fn=inference,
    inputs=[model_file_input, media_file_input, media_type_radio, confidence_slider, frame_step_slider],
    outputs=[output_image, output_video, output_metadata],
    title="Custom YOLO Model Inference for Real-Time Detection",
    description=(
        "Upload your custom YOLO model (detection, segmentation, or OBB) along with an image or video file "
        "to run inference. Adjust the detection confidence and frame step (for video) as needed. "
        "The app shows the processed image/video and returns metadata for real-time API integration. "
        "This is optimized for users who wish to host a YOLO model on Hugging Face and use it for real-time "
        "object detection via the Gradio API."
    ),
)
if __name__ == "__main__":
    iface.launch()
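
# A minimal client-side sketch (hedged, not executed by this app): calling the hosted
# Space through the Gradio API with the gradio_client package. The Space id
# "user/yolo-space" and the local files "best.pt" / "sample.jpg" are hypothetical,
# and the exact upload helper (handle_file vs. file) depends on the gradio_client version:
#   from gradio_client import Client, handle_file
#   client = Client("user/yolo-space")
#   image_out, video_out, metadata = client.predict(
#       handle_file("best.pt"),     # custom YOLO model
#       handle_file("sample.jpg"),  # media file
#       "Image",                    # media type
#       0.5,                        # confidence threshold
#       1,                          # frame step (ignored for images)
#   )
#   print(metadata)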