import gradio as gr
from ultralytics import YOLO
import cv2
import tempfile
import time
import numpy as np
# Load a custom YOLO model from the uploaded file.
def load_model(model_file):
    try:
        # gr.File may hand back a path string or a tempfile-like object with a
        # .name attribute depending on the Gradio version; handle both.
        model = YOLO(getattr(model_file, "name", model_file))
        return model
    except Exception as e:
        return f"Error loading model: {e}"

# Run inference on an image and return a processed image as an np.ndarray.
def predict_image(model, image, conf):
    try:
        # Resolve the uploaded file to a path Ultralytics can read directly.
        source = getattr(image, "name", image)
        start_time = time.time()
        # Run inference with the chosen confidence threshold.
        results = model(source, conf=conf)
        process_time = time.time() - start_time
        # Get the annotated image using the model's built-in plotting.
        annotated_frame = results[0].plot()
        # Convert BGR (OpenCV default) to RGB for display in gr.Image.
        annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
        # Count detections when boxes are available (may be None for some task types).
        boxes = results[0].boxes
        num_detections = len(boxes) if boxes is not None else "N/A"
        return annotated_frame, process_time, num_detections
    except Exception as e:
        return f"Error during image inference: {e}", None, None

# Run inference on a video by processing selected frames and return a processed video file.
def predict_video(model, video_file, conf, frame_step):
    try:
        video_path = getattr(video_file, "name", video_file)
        cap = cv2.VideoCapture(video_path)
        # Preserve the source frame rate so the output plays at the right speed.
        fps = cap.get(cv2.CAP_PROP_FPS) or 20.0
        frames = []
        frame_count = 0
        start_time = time.time()
        while True:
            success, frame = cap.read()
            if not success:
                break
            # Only run inference on every nth frame, as set by frame_step.
            if frame_count % frame_step == 0:
                results = model(frame, conf=conf)
                annotated_frame = results[0].plot()
                frames.append(annotated_frame)
            else:
                # Keep the original frame so the output video has no gaps.
                frames.append(frame)
            frame_count += 1
        process_time = time.time() - start_time
        cap.release()
        if not frames:
            return "Error: No frames processed", None, None
        height, width, _ = frames[0].shape
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        out = cv2.VideoWriter(tmp.name, fourcc, fps, (width, height))
        for frame in frames:
            out.write(frame)
        out.release()
        # For video, return a placeholder for the number of detections.
        # (See the aggregation sketch below for a per-frame tally.)
        num_detections = "See individual frames"
        return tmp.name, process_time, num_detections
    except Exception as e:
        return f"Error during video inference: {e}", None, None

# Main inference function.
# Returns a tuple: (annotated_image, annotated_video, metadata)
# For image inputs, the video output is None; for video inputs, the image output is None.
def inference(model_file, input_media, media_type, conf, frame_step):
    model = load_model(model_file)
    if isinstance(model, str):  # This indicates an error during model loading.
        return model, None, {"processing_time": None, "detections": None}
    if media_type == "Image":
        out_img, process_time, detections = predict_image(model, input_media, conf)
        metadata = {"processing_time": process_time, "detections": detections}
        return out_img, None, metadata
    elif media_type == "Video":
        out_vid, process_time, detections = predict_video(model, input_media, conf, frame_step)
        metadata = {"processing_time": process_time, "detections": detections}
        return None, out_vid, metadata
    else:
        return "Unsupported media type", None, {"processing_time": None, "detections": None}

# Define Gradio interface components.
# Component for uploading a custom YOLO model (.pt file).
model_file_input = gr.File(label="Upload Custom YOLO Model (.pt file)")
# Component for uploading an image or video.
media_file_input = gr.File(label="Upload Image/Video File")
# Radio button to choose media type.
media_type_dropdown = gr.Radio(choices=["Image", "Video"], label="Select Media Type", value="Image")
# Detection confidence slider.
confidence_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.5, label="Detection Confidence Threshold")
# Frame step slider (for video processing).
frame_step_slider = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Frame Step (for Video Processing)")
# For display on the site:
# - Use gr.Image to display the processed image.
# - Use gr.Video to display the processed video.
# - Use gr.JSON to display the metadata.
output_image = gr.Image(label="Annotated Image")
output_video = gr.Video(label="Annotated Video")
output_metadata = gr.JSON(label="Metadata")
# Create the Gradio interface.
# Note: The function returns a triple: (processed image, processed video, metadata).
iface = gr.Interface(
    fn=inference,
    inputs=[model_file_input, media_file_input, media_type_dropdown, confidence_slider, frame_step_slider],
    outputs=[output_image, output_video, output_metadata],
    title="Custom YOLO Model Inference for Real-Time Detection",
    description=(
        "Upload your custom YOLO model (detection, segmentation, or OBB) along with an image or video file "
        "to run inference. Adjust the detection confidence and frame step (for video) as needed. "
        "The app shows the processed image/video and returns metadata for real-time API integration. "
        "This is optimized for users who wish to host a YOLO model on Hugging Face and use it for real-time "
        "object detection via the Gradio API."
    ),
)

if __name__ == "__main__":
    iface.launch()
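
# --- Hedged usage sketch: querying this app over the Gradio API ---
# Illustrative only; the exact gradio_client call signature depends on the installed
# client version, and "wuhp/yolohost" is an assumed Space name.
#
#     from gradio_client import Client, handle_file
#
#     client = Client("wuhp/yolohost")
#     annotated_image, annotated_video, metadata = client.predict(
#         handle_file("best.pt"),     # custom YOLO weights (.pt)
#         handle_file("sample.jpg"),  # image or video to run inference on
#         "Image",                    # media type
#         0.5,                        # confidence threshold
#         1,                          # frame step (used for video only)
#         api_name="/predict",
#     )
#     print(metadata)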