wuhp committed
Commit 5454bc6 · verified · 1 Parent(s): 651f077

Update app.py

Files changed (1)
  1. app.py +34 -29
app.py CHANGED
@@ -3,25 +3,28 @@ from ultralytics import YOLO
 import cv2
 import tempfile
 
-# Function to load a custom YOLO model from an uploaded file.
+# Load a custom YOLO model from the uploaded file.
 def load_model(model_file):
     try:
-        # model_file is a TemporaryFile object. Use .name to get its path.
         model = YOLO(model_file.name)
         return model
     except Exception as e:
         return f"Error loading model: {e}"
 
-# Function to perform inference on an image.
+# Run inference on an image and write the output to a PNG file.
 def predict_image(model, image):
     try:
         results = model(image)
-        annotated_frame = results[0].plot()  # This should work across detection, segmentation, or OBB models.
-        return annotated_frame
+        annotated_frame = results[0].plot()  # Works for detection, segmentation, and OBB models.
+        # Write annotated image to a temporary file.
+        tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
+        cv2.imwrite(tmp.name, annotated_frame)
+        return tmp.name
     except Exception as e:
         return f"Error during image inference: {e}"
 
-# Function to perform inference on a video.
+# Run inference on a video by processing frame-by-frame,
+# and write the annotated video to an MP4 file.
 def predict_video(model, video_file):
     try:
         cap = cv2.VideoCapture(video_file.name)
@@ -33,53 +36,55 @@ def predict_video(model, video_file):
             frames.append(annotated_frame)
             success, frame = cap.read()
         cap.release()
-
         if not frames:
-            return "Error: No frames processed from video."
-
+            return f"Error: No frames processed"
         height, width, _ = frames[0].shape
-        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-        temp_video_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
-        out = cv2.VideoWriter(temp_video_file.name, fourcc, 20.0, (width, height))
+        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+        tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
+        out = cv2.VideoWriter(tmp.name, fourcc, 20.0, (width, height))
         for frame in frames:
             out.write(frame)
         out.release()
-        return temp_video_file.name
+        return tmp.name
     except Exception as e:
         return f"Error during video inference: {e}"
 
-# Unified inference function that takes an uploaded model file, an input media file, and the selected media type.
+# Main inference function: loads the custom model and processes the input media.
+# Returns a tuple: (annotated_image, annotated_video).
+# One element will be a file path and the other None, based on the media type.
 def inference(model_file, input_media, media_type):
     model = load_model(model_file)
-    # Check if model loading resulted in an error message.
     if isinstance(model, str):
-        return model
-
+        # An error occurred during model loading.
+        return (model, None)
+
     if media_type == "Image":
-        return predict_image(model, input_media)
+        out_image = predict_image(model, input_media)
+        return (out_image, None)
     elif media_type == "Video":
-        return predict_video(model, input_media)
+        out_video = predict_video(model, input_media)
+        return (None, out_video)
     else:
-        return "Unsupported media type."
+        return ("Unsupported media type", None)
 
-# Updated Gradio components:
-# - A file input for the custom YOLO model (.pt file)
-# - A file input for the image or video to process
-# - A radio button for selecting between image and video processing.
+# Define Gradio interface components.
 model_file_input = gr.File(label="Upload Custom YOLO Model (.pt file)")
 media_file_input = gr.File(label="Upload Image/Video File")
 media_type_dropdown = gr.Radio(choices=["Image", "Video"], label="Select Media Type", value="Image")
-output_component = gr.File(label="Processed Output")
 
-# Create the Gradio interface.
+# Define two outputs: one for images and one for videos.
+output_image = gr.Image(label="Annotated Image")
+output_video = gr.Video(label="Annotated Video")
+
+# Create a Gradio interface that returns a tuple: (image, video).
 iface = gr.Interface(
     fn=inference,
     inputs=[model_file_input, media_file_input, media_type_dropdown],
-    outputs=output_component,
+    outputs=[output_image, output_video],
     title="Custom YOLO Model Inference",
     description=(
-        "Upload your custom YOLO model (for detection, segmentation, or OBB) along with an image or video file "
-        "to run inference. The system dynamically loads your model and processes the media accordingly."
+        "Upload your custom YOLO model (detection, segmentation, or OBB) along with an image or video file "
+        "to run inference. The system loads your model dynamically, processes the media, and displays the output."
     )
 )
 
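With this change, inference returns an (image, video) pair instead of a single file path, and the interface declares outputs=[output_image, output_video], so Gradio fills whichever slot holds a file path and leaves the other empty. The hunks stop at the gr.Interface(...) call; if app.py does not already end with a launch call below the lines shown here, a minimal sketch of that tail (not part of this commit) would be:

    # Sketch only, assuming nothing else follows iface in the file:
    # start the Gradio app.
    if __name__ == "__main__":
        iface.launch()

The __main__ guard is optional; it simply keeps the module import-safe if the interface is ever reused from another script.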