wuhp committed (verified)
Commit cb2f7e3 · 1 Parent(s): a32514b

Update app.py

Files changed (1):
  1. app.py +43 -43
app.py CHANGED
@@ -5,7 +5,7 @@ import tempfile
 import time
 import numpy as np
 
-# Load the custom YOLO model from the uploaded file.
+# Load a custom YOLO model from the uploaded file.
 def load_model(model_file):
     try:
         model = YOLO(model_file.name)
@@ -13,27 +13,26 @@ def load_model(model_file):
     except Exception as e:
         return f"Error loading model: {e}"
 
-# Run inference on an image, apply the confidence threshold, and save the result.
+# Run inference on an image and return a processed image as an np.ndarray.
 def predict_image(model, image, conf):
     try:
         start_time = time.time()
-        # Pass the confidence threshold to the model (Ultralytics models accept this as a keyword argument).
+        # Run inference with confidence threshold.
         results = model(image, conf=conf)
         process_time = time.time() - start_time
 
-        # Use the model's built-in plot() method to overlay detections.
+        # Get the annotated image using the model's built-in plotting.
         annotated_frame = results[0].plot()
-        # Count detections if available (assumes results[0] contains a 'boxes' attribute).
-        num_detections = len(results[0].boxes) if hasattr(results[0], "boxes") else "N/A"
+        # Optional: Convert BGR (OpenCV default) to RGB if needed.
+        annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
 
-        # Write the annotated image to a temporary PNG file.
-        tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
-        cv2.imwrite(tmp.name, annotated_frame)
-        return tmp.name, process_time, num_detections
+        # Count detections if available (assumes results[0].boxes exists).
+        num_detections = len(results[0].boxes) if hasattr(results[0], "boxes") else "N/A"
+        return annotated_frame, process_time, num_detections
     except Exception as e:
         return f"Error during image inference: {e}", None, None
 
-# Run inference on a video by processing frames with a given frame step and saving the output.
+# Run inference on a video by processing selected frames and return a processed video file.
 def predict_video(model, video_file, conf, frame_step):
     try:
         cap = cv2.VideoCapture(video_file.name)
@@ -46,13 +45,13 @@ def predict_video(model, video_file, conf, frame_step):
             if not success:
                 break
 
-            # Process only every nth frame (frame_step controls this).
+            # Only process every nth frame determined by frame_step.
            if frame_count % frame_step == 0:
                 results = model(frame, conf=conf)
                 annotated_frame = results[0].plot()
                 frames.append(annotated_frame)
             else:
-                # If skipping, add the original frame (or you could choose not to add anything).
+                # Optionally, append the original frame, or skip entirely.
                 frames.append(frame)
             frame_count += 1
 
@@ -60,7 +59,7 @@ def predict_video(model, video_file, conf, frame_step):
         cap.release()
 
         if not frames:
-            return f"Error: No frames processed", None, None
+            return "Error: No frames processed", None, None
 
         height, width, _ = frames[0].shape
         fourcc = cv2.VideoWriter_fourcc(*"mp4v")
@@ -70,68 +69,69 @@ def predict_video(model, video_file, conf, frame_step):
             out.write(frame)
         out.release()
 
-        # For the detection summary, aggregate the number of detections from the processed frames.
-        # (Note: For simplicity, this uses the detections from the first processed frame if available.)
-        num_detections = "See individual frames"  # More elaborate aggregation logic can be added.
+        # For video, we return a placeholder for number of detections. (More logic can be added to aggregate detections.)
+        num_detections = "See individual frames"
         return tmp.name, process_time, num_detections
     except Exception as e:
         return f"Error during video inference: {e}", None, None
 
 # Main inference function.
-# It now accepts additional parameters: confidence threshold and frame step (for videos).
-# Returns a tuple with an output file path and a JSON-like dictionary with metadata.
+# Returns a tuple: (annotated_image, annotated_video, metadata)
+# For image inputs, the video output is None; for video inputs, the image output is None.
 def inference(model_file, input_media, media_type, conf, frame_step):
     model = load_model(model_file)
-    if isinstance(model, str):  # An error occurred during model loading.
-        return model, {"processing_time": None, "detections": None}
+    if isinstance(model, str):  # This indicates an error during model loading.
+        return model, None, {"processing_time": None, "detections": None}
 
-    # Process according to media type.
     if media_type == "Image":
-        out_file, process_time, detections = predict_image(model, input_media, conf)
-        # For API users, return both the output file path and a dictionary with metadata.
+        out_img, process_time, detections = predict_image(model, input_media, conf)
         metadata = {"processing_time": process_time, "detections": detections}
-        return out_file, metadata
+        return out_img, None, metadata
 
     elif media_type == "Video":
-        out_file, process_time, detections = predict_video(model, input_media, conf, frame_step)
+        out_vid, process_time, detections = predict_video(model, input_media, conf, frame_step)
         metadata = {"processing_time": process_time, "detections": detections}
-        return out_file, metadata
+        return None, out_vid, metadata
     else:
-        return "Unsupported media type", {"processing_time": None, "detections": None}
+        return "Unsupported media type", None, {"processing_time": None, "detections": None}
 
 # Define Gradio interface components.
-# File upload for the custom YOLO model (a .pt file).
+# Component for uploading a custom YOLO model (.pt file).
 model_file_input = gr.File(label="Upload Custom YOLO Model (.pt file)")
 
-# File upload for the image or video.
+# Component for uploading an image or video.
 media_file_input = gr.File(label="Upload Image/Video File")
 
-# Radio button for selecting media type.
+# Radio button to choose media type.
 media_type_dropdown = gr.Radio(choices=["Image", "Video"], label="Select Media Type", value="Image")
 
-# Confidence slider (minimum detection confidence).
+# Detection confidence slider.
 confidence_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.5, label="Detection Confidence Threshold")
 
-# Frame step slider for video (how many frames to skip between processing).
+# Frame step slider (for video processing).
 frame_step_slider = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Frame Step (for Video Processing)")
 
-# We define two outputs:
-# 1. A File output that will show the annotated image or video.
-# 2. A JSON/Text output that reports processing time and detections.
-output_file = gr.File(label="Processed Output")
+# For display on the site:
+# - Use gr.Image to display the processed image.
+# - Use gr.Video to display the processed video.
+# - Use gr.JSON to display the metadata.
+output_image = gr.Image(label="Annotated Image")
+output_video = gr.Video(label="Annotated Video")
 output_metadata = gr.JSON(label="Metadata")
 
 # Create the Gradio interface.
-# Note: For API clients, the JSON output (metadata) gives additional info on processing.
+# Note: The function returns a triple: (processed image, processed video, metadata).
 iface = gr.Interface(
     fn=inference,
     inputs=[model_file_input, media_file_input, media_type_dropdown, confidence_slider, frame_step_slider],
-    outputs=[output_file, output_metadata],
-    title="Enhanced Custom YOLO Model Inference",
+    outputs=[output_image, output_video, output_metadata],
+    title="Custom YOLO Model Inference for Real-Time Detection",
     description=(
-        "Upload your custom YOLO model (supports detection, segmentation, or OBB), along with an image or video file. "
-        "Use the sliders to adjust the detection confidence and (for videos) the frame step for real-time performance. "
-        "The app returns an annotated output file and metadata (processing time and detection summary) for API use."
+        "Upload your custom YOLO model (detection, segmentation, or OBB) along with an image or video file "
+        "to run inference. Adjust the detection confidence and frame step (for video) as needed. "
+        "The app shows the processed image/video and returns metadata for real-time API integration. "
+        "This is optimized for users who wish to host a YOLO model on Hugging Face and use it for real-time "
+        "object detection via the Gradio API."
    )
 )
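The new video path still returns the placeholder "See individual frames" for detections. One possible way to fill that in, sketched below and not part of this commit: collect results[0] for each processed frame inside the loop in predict_video and then count detections per class. The helper name summarize_detections is hypothetical; the sketch assumes the Ultralytics Results API, where result.boxes.cls holds class indices and result.names maps them to labels (OBB or segmentation models may expose result.obb instead of result.boxes).

from collections import Counter

def summarize_detections(per_frame_results):
    # Hypothetical helper: count detections per class name across all processed frames.
    counts = Counter()
    for result in per_frame_results:
        boxes = getattr(result, "boxes", None)  # may be None for non-detection tasks
        if boxes is None:
            continue
        for cls_id in boxes.cls.tolist():
            counts[result.names[int(cls_id)]] += 1
    return dict(counts)

With per_frame_results appended inside the frame loop, num_detections = summarize_detections(per_frame_results) could replace the placeholder string.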
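Since the description targets calling the app "via the Gradio API", a minimal client-side sketch follows. It assumes the app is hosted as a Hugging Face Space and that gradio_client >= 1.0 is installed (where handle_file exists); the Space id, file names, and return handling are placeholders, not part of this commit.

from gradio_client import Client, handle_file

# Placeholder Space id; substitute the repo where this app is actually hosted.
client = Client("your-username/your-yolo-space")

image_out, video_out, metadata = client.predict(
    handle_file("best.pt"),     # custom YOLO weights (.pt)
    handle_file("sample.jpg"),  # image or video to process
    "Image",                    # media type: "Image" or "Video"
    0.5,                        # detection confidence threshold
    1,                          # frame step (only used for videos)
    api_name="/predict",
)
print(metadata)  # e.g. {"processing_time": ..., "detections": ...}

For video input, pass a video file, select "Video", and read video_out instead; older gradio_client releases use a different file-upload helper, so check the installed version.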