wuhp committed
Commit a32514b · verified · 1 Parent(s): 5454bc6

Update app.py

Files changed (1)
  1. app.py +88 -41
app.py CHANGED
@@ -2,8 +2,10 @@ import gradio as gr
 from ultralytics import YOLO
 import cv2
 import tempfile
+import time
+import numpy as np
 
-# Load a custom YOLO model from the uploaded file.
+# Load the custom YOLO model from the uploaded file.
 def load_model(model_file):
     try:
         model = YOLO(model_file.name)
@@ -11,33 +13,55 @@ def load_model(model_file):
     except Exception as e:
         return f"Error loading model: {e}"
 
-# Run inference on an image and write the output to a PNG file.
-def predict_image(model, image):
+# Run inference on an image, apply the confidence threshold, and save the result.
+def predict_image(model, image, conf):
     try:
-        results = model(image)
-        annotated_frame = results[0].plot()  # Works for detection, segmentation, and OBB models.
-        # Write annotated image to a temporary file.
+        start_time = time.time()
+        # Pass the confidence threshold to the model (Ultralytics models accept this as a keyword argument).
+        results = model(image, conf=conf)
+        process_time = time.time() - start_time
+
+        # Use the model's built-in plot() method to overlay detections.
+        annotated_frame = results[0].plot()
+        # Count detections if available (assumes results[0] contains a 'boxes' attribute).
+        num_detections = len(results[0].boxes) if hasattr(results[0], "boxes") else "N/A"
+
+        # Write the annotated image to a temporary PNG file.
         tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
         cv2.imwrite(tmp.name, annotated_frame)
-        return tmp.name
+        return tmp.name, process_time, num_detections
     except Exception as e:
-        return f"Error during image inference: {e}"
+        return f"Error during image inference: {e}", None, None
 
-# Run inference on a video by processing frame-by-frame,
-# and write the annotated video to an MP4 file.
-def predict_video(model, video_file):
+# Run inference on a video by processing frames with a given frame step and saving the output.
+def predict_video(model, video_file, conf, frame_step):
     try:
         cap = cv2.VideoCapture(video_file.name)
         frames = []
-        success, frame = cap.read()
-        while success:
-            results = model(frame)
-            annotated_frame = results[0].plot()
-            frames.append(annotated_frame)
+        frame_count = 0
+        start_time = time.time()
+
+        while True:
             success, frame = cap.read()
+            if not success:
+                break
+
+            # Process only every nth frame (frame_step controls this).
+            if frame_count % frame_step == 0:
+                results = model(frame, conf=conf)
+                annotated_frame = results[0].plot()
+                frames.append(annotated_frame)
+            else:
+                # If skipping, add the original frame (or you could choose not to add anything).
+                frames.append(frame)
+            frame_count += 1
+
+        process_time = time.time() - start_time
         cap.release()
+
         if not frames:
-            return f"Error: No frames processed"
+            return f"Error: No frames processed", None, None
+
         height, width, _ = frames[0].shape
         fourcc = cv2.VideoWriter_fourcc(*"mp4v")
         tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
@@ -45,46 +69,69 @@ def predict_video(model, video_file):
         for frame in frames:
             out.write(frame)
         out.release()
-        return tmp.name
+
+        # For the detection summary, aggregate the number of detections from the processed frames.
+        # (Note: For simplicity, this uses the detections from the first processed frame if available.)
+        num_detections = "See individual frames"  # More elaborate aggregation logic can be added.
+        return tmp.name, process_time, num_detections
     except Exception as e:
-        return f"Error during video inference: {e}"
+        return f"Error during video inference: {e}", None, None
 
-# Main inference function: loads the custom model and processes the input media.
-# Returns a tuple: (annotated_image, annotated_video).
-# One element will be a file path and the other None, based on the media type.
-def inference(model_file, input_media, media_type):
+# Main inference function.
+# It now accepts additional parameters: confidence threshold and frame step (for videos).
+# Returns a tuple with an output file path and a JSON-like dictionary with metadata.
+def inference(model_file, input_media, media_type, conf, frame_step):
     model = load_model(model_file)
-    if isinstance(model, str):
-        # An error occurred during model loading.
-        return (model, None)
-
+    if isinstance(model, str):  # An error occurred during model loading.
+        return model, {"processing_time": None, "detections": None}
+
+    # Process according to media type.
     if media_type == "Image":
-        out_image = predict_image(model, input_media)
-        return (out_image, None)
+        out_file, process_time, detections = predict_image(model, input_media, conf)
+        # For API users, return both the output file path and a dictionary with metadata.
+        metadata = {"processing_time": process_time, "detections": detections}
+        return out_file, metadata
+
     elif media_type == "Video":
-        out_video = predict_video(model, input_media)
-        return (None, out_video)
+        out_file, process_time, detections = predict_video(model, input_media, conf, frame_step)
+        metadata = {"processing_time": process_time, "detections": detections}
+        return out_file, metadata
     else:
-        return ("Unsupported media type", None)
+        return "Unsupported media type", {"processing_time": None, "detections": None}
 
 # Define Gradio interface components.
+# File upload for the custom YOLO model (a .pt file).
 model_file_input = gr.File(label="Upload Custom YOLO Model (.pt file)")
+
+# File upload for the image or video.
 media_file_input = gr.File(label="Upload Image/Video File")
+
+# Radio button for selecting media type.
 media_type_dropdown = gr.Radio(choices=["Image", "Video"], label="Select Media Type", value="Image")
 
-# Define two outputs: one for images and one for videos.
-output_image = gr.Image(label="Annotated Image")
-output_video = gr.Video(label="Annotated Video")
+# Confidence slider (minimum detection confidence).
+confidence_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.5, label="Detection Confidence Threshold")
+
+# Frame step slider for video (how many frames to skip between processing).
+frame_step_slider = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Frame Step (for Video Processing)")
+
+# We define two outputs:
+# 1. A File output that will show the annotated image or video.
+# 2. A JSON/Text output that reports processing time and detections.
+output_file = gr.File(label="Processed Output")
+output_metadata = gr.JSON(label="Metadata")
 
-# Create a Gradio interface that returns a tuple: (image, video).
+# Create the Gradio interface.
+# Note: For API clients, the JSON output (metadata) gives additional info on processing.
 iface = gr.Interface(
     fn=inference,
-    inputs=[model_file_input, media_file_input, media_type_dropdown],
-    outputs=[output_image, output_video],
-    title="Custom YOLO Model Inference",
+    inputs=[model_file_input, media_file_input, media_type_dropdown, confidence_slider, frame_step_slider],
+    outputs=[output_file, output_metadata],
+    title="Enhanced Custom YOLO Model Inference",
     description=(
-        "Upload your custom YOLO model (detection, segmentation, or OBB) along with an image or video file "
-        "to run inference. The system loads your model dynamically, processes the media, and displays the output."
+        "Upload your custom YOLO model (supports detection, segmentation, or OBB), along with an image or video file. "
+        "Use the sliders to adjust the detection confidence and (for videos) the frame step for real-time performance. "
+        "The app returns an annotated output file and metadata (processing time and detection summary) for API use."
    )
 )
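For reference, a minimal sketch of exercising the updated inference() signature directly, outside Gradio. The file paths and the SimpleNamespace stand-in for the gr.File upload are hypothetical, and it assumes importing app.py does not auto-launch the interface; Ultralytics models accept a plain image path, so one is passed here.

from types import SimpleNamespace

from app import inference  # assumes app.py does not call iface.launch() at import time

# gr.File hands the function an object with a .name attribute; mimic that here.
model_file = SimpleNamespace(name="models/custom_yolo.pt")  # hypothetical path
image_path = "samples/street.jpg"                           # hypothetical path

# conf and frame_step mirror the new sliders; frame_step is ignored for images.
out_path, metadata = inference(model_file, image_path, "Image", conf=0.5, frame_step=1)
print(out_path)   # path to the annotated PNG written by predict_image
print(metadata)   # {"processing_time": <seconds>, "detections": <count or "N/A">}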
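Since the new metadata output is aimed at API clients, here is a sketch of calling the running app with gradio_client. It assumes the app is served locally on Gradio's default port, a recent gradio_client that provides handle_file, and the default /predict endpoint that gr.Interface exposes; the file paths are hypothetical.

from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860/")  # hypothetical local URL
out_path, metadata = client.predict(
    handle_file("models/custom_yolo.pt"),  # model .pt upload (hypothetical path)
    handle_file("samples/street.jpg"),     # image/video upload (hypothetical path)
    "Image",                               # media type radio
    0.5,                                   # detection confidence slider
    1,                                     # frame step slider (ignored for images)
    api_name="/predict",
)
print(out_path, metadata)

Note that inference() reports errors as strings through the same file output, so a client should check that the first element is a real file path before using it.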