import gradio as gr
from ultralytics import YOLO
import cv2
import tempfile
import time
import numpy as np


# Load a custom YOLO model from the uploaded file.
def load_model(model_file):
    try:
        model = YOLO(model_file.name)
        return model
    except Exception as e:
        return f"Error loading model: {e}"


# Run inference on an image and return a processed image as an np.ndarray.
def predict_image(model, image, conf):
    try:
        start_time = time.time()
        # Run inference with confidence threshold.
        results = model(image, conf=conf)
        process_time = time.time() - start_time
        # Get the annotated image using the model's built-in plotting.
        annotated_frame = results[0].plot()
        # Optional: Convert BGR (OpenCV default) to RGB if needed.
        annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
        # Count detections if available (assumes results[0].boxes exists).
        num_detections = len(results[0].boxes) if hasattr(results[0], "boxes") else "N/A"
        return annotated_frame, process_time, num_detections
    except Exception as e:
        return f"Error during image inference: {e}", None, None


# Run inference on a video by processing selected frames and return a processed video file.
def predict_video(model, video_file, conf, frame_step):
    try:
        cap = cv2.VideoCapture(video_file.name)
        frames = []
        frame_count = 0
        start_time = time.time()
        while True:
            success, frame = cap.read()
            if not success:
                break
            # Only process every nth frame determined by frame_step.
            if frame_count % frame_step == 0:
                results = model(frame, conf=conf)
                annotated_frame = results[0].plot()
                frames.append(annotated_frame)
            else:
                # Optionally, append the original frame, or skip entirely.
                frames.append(frame)
            frame_count += 1
        process_time = time.time() - start_time
        cap.release()
        if not frames:
            return "Error: No frames processed", None, None
        height, width, _ = frames[0].shape
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        out = cv2.VideoWriter(tmp.name, fourcc, 20.0, (width, height))
        for frame in frames:
            out.write(frame)
        out.release()
        # For video, we return a placeholder for number of detections.
        # (More logic can be added to aggregate detections.)
        num_detections = "See individual frames"
        return tmp.name, process_time, num_detections
    except Exception as e:
        return f"Error during video inference: {e}", None, None


# Main inference function.
# Returns a tuple: (annotated_image, annotated_video, metadata)
# For image inputs, the video output is None; for video inputs, the image output is None.
def inference(model_file, input_media, media_type, conf, frame_step):
    model = load_model(model_file)
    if isinstance(model, str):
        # This indicates an error during model loading.
        return model, None, {"processing_time": None, "detections": None}
    if media_type == "Image":
        out_img, process_time, detections = predict_image(model, input_media, conf)
        metadata = {"processing_time": process_time, "detections": detections}
        return out_img, None, metadata
    elif media_type == "Video":
        out_vid, process_time, detections = predict_video(model, input_media, conf, frame_step)
        metadata = {"processing_time": process_time, "detections": detections}
        return None, out_vid, metadata
    else:
        return "Unsupported media type", None, {"processing_time": None, "detections": None}


# Define Gradio interface components.
# Component for uploading a custom YOLO model (.pt file).
model_file_input = gr.File(label="Upload Custom YOLO Model (.pt file)")
# Component for uploading an image or video.
media_file_input = gr.File(label="Upload Image/Video File")
# Radio button to choose media type.
media_type_dropdown = gr.Radio(choices=["Image", "Video"], label="Select Media Type", value="Image")
# Detection confidence slider.
confidence_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.5, label="Detection Confidence Threshold")
# Frame step slider (for video processing).
frame_step_slider = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Frame Step (for Video Processing)")

# For display on the site:
# - Use gr.Image to display the processed image.
# - Use gr.Video to display the processed video.
# - Use gr.JSON to display the metadata.
output_image = gr.Image(label="Annotated Image")
output_video = gr.Video(label="Annotated Video")
output_metadata = gr.JSON(label="Metadata")

# Create the Gradio interface.
# Note: The function returns a triple: (processed image, processed video, metadata).
iface = gr.Interface(
    fn=inference,
    inputs=[model_file_input, media_file_input, media_type_dropdown, confidence_slider, frame_step_slider],
    outputs=[output_image, output_video, output_metadata],
    title="Custom YOLO Model Inference for Real-Time Detection",
    description=(
        "Upload your custom YOLO model (detection, segmentation, or OBB) along with an image or video file "
        "to run inference. Adjust the detection confidence and frame step (for video) as needed. "
        "The app shows the processed image/video and returns metadata for real-time API integration. "
        "This is optimized for users who wish to host a YOLO model on Hugging Face and use it for real-time "
        "object detection via the Gradio API."
    ),
)

if __name__ == "__main__":
    iface.launch()
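
# ---------------------------------------------------------------------------
# Example client call (a hedged sketch, not part of the app itself): once this
# script is hosted, e.g. as a Hugging Face Space, the Gradio API can be called
# from a separate Python process with the `gradio_client` package. The Space ID
# "your-username/your-space" and the file paths below are placeholders, and the
# snippet assumes a recent gradio_client release that exposes `handle_file`.
# It is kept commented out so it never runs inside the server process.
#
# from gradio_client import Client, handle_file
#
# client = Client("your-username/your-space")
# annotated_image, annotated_video, metadata = client.predict(
#     handle_file("custom_model.pt"),  # model_file_input
#     handle_file("sample.jpg"),       # media_file_input
#     "Image",                         # media type ("Image" or "Video")
#     0.5,                             # detection confidence threshold
#     1,                               # frame step (only used for videos)
#     api_name="/predict",             # default endpoint name for gr.Interface
# )
# print(metadata)  # e.g. {"processing_time": ..., "detections": ...}
# ---------------------------------------------------------------------------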