Update app.py
Browse files
app.py
CHANGED
@@ -3,25 +3,28 @@ from ultralytics import YOLO
|
|
3 |
import cv2
|
4 |
import tempfile
|
5 |
|
6 |
-
#
|
7 |
def load_model(model_file):
|
8 |
try:
|
9 |
-
# model_file is a TemporaryFile object. Use .name to get its path.
|
10 |
model = YOLO(model_file.name)
|
11 |
return model
|
12 |
except Exception as e:
|
13 |
return f"Error loading model: {e}"
|
14 |
|
15 |
-
#
|
16 |
def predict_image(model, image):
|
17 |
try:
|
18 |
results = model(image)
|
19 |
-
annotated_frame = results[0].plot() #
|
20 |
-
|
|
|
|
|
|
|
21 |
except Exception as e:
|
22 |
return f"Error during image inference: {e}"
|
23 |
|
24 |
-
#
|
|
|
25 |
def predict_video(model, video_file):
|
26 |
try:
|
27 |
cap = cv2.VideoCapture(video_file.name)
|
@@ -33,53 +36,55 @@ def predict_video(model, video_file):
|
|
33 |
frames.append(annotated_frame)
|
34 |
success, frame = cap.read()
|
35 |
cap.release()
|
36 |
-
|
37 |
if not frames:
|
38 |
-
return "Error: No frames processed
|
39 |
-
|
40 |
height, width, _ = frames[0].shape
|
41 |
-
fourcc = cv2.VideoWriter_fourcc(*
|
42 |
-
|
43 |
-
out = cv2.VideoWriter(
|
44 |
for frame in frames:
|
45 |
out.write(frame)
|
46 |
out.release()
|
47 |
-
return
|
48 |
except Exception as e:
|
49 |
return f"Error during video inference: {e}"
|
50 |
|
51 |
-
#
|
|
|
|
|
52 |
def inference(model_file, input_media, media_type):
|
53 |
model = load_model(model_file)
|
54 |
-
# Check if model loading resulted in an error message.
|
55 |
if isinstance(model, str):
|
56 |
-
|
57 |
-
|
|
|
58 |
if media_type == "Image":
|
59 |
-
|
|
|
60 |
elif media_type == "Video":
|
61 |
-
|
|
|
62 |
else:
|
63 |
-
return "Unsupported media type
|
64 |
|
65 |
-
#
|
66 |
-
# - A file input for the custom YOLO model (.pt file)
|
67 |
-
# - A file input for the image or video to process
|
68 |
-
# - A radio button for selecting between image and video processing.
|
69 |
model_file_input = gr.File(label="Upload Custom YOLO Model (.pt file)")
|
70 |
media_file_input = gr.File(label="Upload Image/Video File")
|
71 |
media_type_dropdown = gr.Radio(choices=["Image", "Video"], label="Select Media Type", value="Image")
|
72 |
-
output_component = gr.File(label="Processed Output")
|
73 |
|
74 |
-
#
|
|
|
|
|
|
|
|
|
75 |
iface = gr.Interface(
|
76 |
fn=inference,
|
77 |
inputs=[model_file_input, media_file_input, media_type_dropdown],
|
78 |
-
outputs=
|
79 |
title="Custom YOLO Model Inference",
|
80 |
description=(
|
81 |
-
"Upload your custom YOLO model (
|
82 |
-
"to run inference. The system
|
83 |
)
|
84 |
)
|
85 |
|
|
|
3 |
import cv2
|
4 |
import tempfile
|
5 |
|
6 |
# Load a custom YOLO model from the uploaded file.
def load_model(model_file):
    """Instantiate a YOLO model from an uploaded .pt file.

    Args:
        model_file: An uploaded file object; its ``.name`` attribute is the
            path on disk that YOLO reads from.

    Returns:
        The loaded model on success, or an error string on failure
        (callers detect failure with ``isinstance(result, str)``).
    """
    try:
        return YOLO(model_file.name)
    except Exception as e:
        return f"Error loading model: {e}"
|
13 |
|
14 |
# Run inference on an image and write the output to a PNG file.
def predict_image(model, image):
    """Run YOLO inference on a single image and save the annotated result.

    Args:
        model: A loaded YOLO model.
        image: Image input accepted by the model (file path or array).

    Returns:
        Path to the annotated PNG on success, or an error string on failure.
    """
    try:
        results = model(image)
        annotated_frame = results[0].plot()  # Works for detection, segmentation, and OBB models.
        # Reserve a temporary path, then close the handle immediately:
        # cv2.imwrite must be able to (re)open the path itself, and an
        # open NamedTemporaryFile blocks that on Windows. The original
        # code leaked this handle.
        tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        tmp.close()
        # cv2.imwrite signals failure via its return value instead of
        # raising — check it so we never hand back a path to a file that
        # was not actually written.
        if not cv2.imwrite(tmp.name, annotated_frame):
            return f"Error during image inference: failed to write {tmp.name}"
        return tmp.name
    except Exception as e:
        return f"Error during image inference: {e}"
|
25 |
|
26 |
+
# Run inference on a video by processing frame-by-frame,
|
27 |
+
# and write the annotated video to an MP4 file.
|
28 |
def predict_video(model, video_file):
|
29 |
try:
|
30 |
cap = cv2.VideoCapture(video_file.name)
|
|
|
36 |
frames.append(annotated_frame)
|
37 |
success, frame = cap.read()
|
38 |
cap.release()
|
|
|
39 |
if not frames:
|
40 |
+
return f"Error: No frames processed"
|
|
|
41 |
height, width, _ = frames[0].shape
|
42 |
+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
43 |
+
tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
|
44 |
+
out = cv2.VideoWriter(tmp.name, fourcc, 20.0, (width, height))
|
45 |
for frame in frames:
|
46 |
out.write(frame)
|
47 |
out.release()
|
48 |
+
return tmp.name
|
49 |
except Exception as e:
|
50 |
return f"Error during video inference: {e}"
|
51 |
|
52 |
# Main inference function: loads the custom model and processes the input media.
# Returns a tuple: (annotated_image, annotated_video). The element that does
# not apply to the selected media type is None.
def inference(model_file, input_media, media_type):
    """Load the uploaded model and run it on the given media file.

    Args:
        model_file: Uploaded .pt model file object.
        input_media: Uploaded image or video file.
        media_type: Either "Image" or "Video".

    Returns:
        A ``(image_result, video_result)`` pair; exactly one slot is filled.
    """
    model = load_model(model_file)
    if isinstance(model, str):
        # load_model signals failure by returning an error string.
        return (model, None)

    if media_type == "Image":
        return (predict_image(model, input_media), None)
    if media_type == "Video":
        return (None, predict_video(model, input_media))
    return ("Unsupported media type", None)
|
69 |
|
70 |
# --- Gradio UI definition ---

# Inputs: the custom model file, the media file, and a media-type selector.
model_file_input = gr.File(label="Upload Custom YOLO Model (.pt file)")
media_file_input = gr.File(label="Upload Image/Video File")
media_type_dropdown = gr.Radio(choices=["Image", "Video"], label="Select Media Type", value="Image")

# Outputs: one slot for annotated images, one for annotated videos —
# `inference` fills whichever matches the selected media type.
output_image = gr.Image(label="Annotated Image")
output_video = gr.Video(label="Annotated Video")

# Wire everything into a single Gradio interface.
iface = gr.Interface(
    fn=inference,
    inputs=[model_file_input, media_file_input, media_type_dropdown],
    outputs=[output_image, output_video],
    title="Custom YOLO Model Inference",
    description=(
        "Upload your custom YOLO model (detection, segmentation, or OBB) along with an image or video file "
        "to run inference. The system loads your model dynamically, processes the media, and displays the output."
    ),
)
|
90 |
|