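"""Gradio app for running a custom Ultralytics YOLO model on an uploaded image or video.

The interface accepts a .pt model file plus an image or video file, runs inference at a
user-chosen confidence threshold (and frame step for videos), and returns the annotated
media together with simple metadata (processing time and detection count).
"""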
import gradio as gr
from ultralytics import YOLO
import cv2
import tempfile
import time
import numpy as np

# Load a custom YOLO model from the uploaded file.
def load_model(model_file):
    try:
        model = YOLO(model_file.name)
        return model
    except Exception as e:
        return f"Error loading model: {e}"

# Run inference on an uploaded image and return the annotated image as an RGB np.ndarray.
def predict_image(model, image, conf):
    try:
        start_time = time.time()
        # The media input comes from gr.File, so pass the underlying file path to the model.
        source = image.name if hasattr(image, "name") else image
        # Run inference with the chosen confidence threshold.
        results = model(source, conf=conf)
        process_time = time.time() - start_time

        # Get the annotated image using the model's built-in plotting.
        annotated_frame = results[0].plot()
        # plot() returns BGR (OpenCV convention); convert to RGB for display in gr.Image.
        annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)

        # Count detections when axis-aligned boxes are available; otherwise report "N/A".
        num_detections = len(results[0].boxes) if results[0].boxes is not None else "N/A"
        return annotated_frame, process_time, num_detections
    except Exception as e:
        return f"Error during image inference: {e}", None, None

# Run inference on a video by processing selected frames and return a processed video file.
def predict_video(model, video_file, conf, frame_step):
    try:
        cap = cv2.VideoCapture(video_file.name)
        # Preserve the source frame rate for the output video; fall back to 20 FPS if unknown.
        fps = cap.get(cv2.CAP_PROP_FPS) or 20.0
        frames = []
        frame_count = 0
        start_time = time.time()

        while True:
            success, frame = cap.read()
            if not success:
                break

            # Only process every nth frame determined by frame_step.
            if frame_count % frame_step == 0:
                results = model(frame, conf=conf)
                annotated_frame = results[0].plot()
                frames.append(annotated_frame)
            else:
                # Optionally, append the original frame, or skip entirely.
                frames.append(frame)
            frame_count += 1

        process_time = time.time() - start_time
        cap.release()

        if not frames:
            return "Error: No frames processed", None, None

        height, width, _ = frames[0].shape
        # Write the annotated frames to a temporary MP4 file (mp4v codec; some browsers may need H.264 for playback).
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        out = cv2.VideoWriter(tmp.name, fourcc, fps, (width, height))
        for frame in frames:
            out.write(frame)
        out.release()

        # For video, we return a placeholder for number of detections. (More logic can be added to aggregate detections.)
        num_detections = "See individual frames"
        return tmp.name, process_time, num_detections
    except Exception as e:
        return f"Error during video inference: {e}", None, None

# Main inference function.
# Returns a tuple: (annotated_image, annotated_video, metadata).
# For image inputs the video output is None; for video inputs the image output is None.
# Errors are reported in the metadata dict rather than in the media outputs.
def inference(model_file, input_media, media_type, conf, frame_step):
    model = load_model(model_file)
    if isinstance(model, str):  # load_model returned an error message instead of a model.
        return None, None, {"error": model, "processing_time": None, "detections": None}

    if media_type == "Image":
        out_img, process_time, detections = predict_image(model, input_media, conf)
        if process_time is None:  # predict_image returned an error message.
            return None, None, {"error": out_img, "processing_time": None, "detections": None}
        return out_img, None, {"processing_time": process_time, "detections": detections}

    elif media_type == "Video":
        out_vid, process_time, detections = predict_video(model, input_media, conf, frame_step)
        if process_time is None:  # predict_video returned an error message.
            return None, None, {"error": out_vid, "processing_time": None, "detections": None}
        return None, out_vid, {"processing_time": process_time, "detections": detections}

    else:
        return None, None, {"error": "Unsupported media type", "processing_time": None, "detections": None}

# Define Gradio interface components.
# Component for uploading a custom YOLO model (.pt file).
model_file_input = gr.File(label="Upload Custom YOLO Model (.pt file)")

# Component for uploading an image or video.
media_file_input = gr.File(label="Upload Image/Video File")

# Radio button to choose media type.
media_type_dropdown = gr.Radio(choices=["Image", "Video"], label="Select Media Type", value="Image")

# Detection confidence slider.
confidence_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.5, label="Detection Confidence Threshold")

# Frame step slider (for video processing).
frame_step_slider = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Frame Step (for Video Processing)")

# For display on the site:
# - Use gr.Image to display the processed image.
# - Use gr.Video to display the processed video.
# - Use gr.JSON to display the metadata.
output_image = gr.Image(label="Annotated Image")
output_video = gr.Video(label="Annotated Video")
output_metadata = gr.JSON(label="Metadata")

# Create the Gradio interface.
# Note: The function returns a triple: (processed image, processed video, metadata).
iface = gr.Interface(
    fn=inference,
    inputs=[model_file_input, media_file_input, media_type_dropdown, confidence_slider, frame_step_slider],
    outputs=[output_image, output_video, output_metadata],
    title="Custom YOLO Model Inference for Real-Time Detection",
    description=(
        "Upload your custom YOLO model (detection, segmentation, or OBB) along with an image or video file "
        "to run inference. Adjust the detection confidence and frame step (for video) as needed. "
        "The app shows the processed image/video and returns metadata for real-time API integration. "
        "This is optimized for users who wish to host a YOLO model on Hugging Face and use it for real-time "
        "object detection via the Gradio API."
    )
)

if __name__ == "__main__":
    iface.launch()
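
# ---------------------------------------------------------------------------
# Example of calling this app programmatically (a minimal sketch, kept as a
# comment so it never runs inside the app itself). It assumes the app has been
# launched locally at the default URL and uses hypothetical file names; the
# exact gradio_client call depends on the installed Gradio version (newer
# versions wrap file arguments with gradio_client.handle_file).
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")
#   image_out, video_out, metadata = client.predict(
#       "my_model.pt",      # hypothetical custom YOLO weights
#       "sample.jpg",       # hypothetical input image
#       "Image",            # media type
#       0.5,                # confidence threshold
#       1,                  # frame step (ignored for images)
#       api_name="/predict",
#   )
#   print(metadata)
# ---------------------------------------------------------------------------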