xmrt commited on
Commit
d9df1e2
·
1 Parent(s): eef5e41
Files changed (1) hide show
  1. app.py +26 -7
app.py CHANGED
@@ -30,6 +30,30 @@ print("[INFO]: Imported modules!")
30
  track_model = YOLO('yolov8n.pt') # Load an official Detect model
31
  print("[INFO]: Downloaded models!")
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def tracking(video, model, boxes=True):
34
  print("[INFO] Is cuda available? ", torch.cuda.is_available())
35
  print(device)
@@ -45,17 +69,12 @@ def tracking(video, model, boxes=True):
45
  return annotated_frame
46
 
47
  def show_tracking(video_content):
 
 
48
 
49
  # https://docs.ultralytics.com/datasets/detect/coco/
50
  video = cv2.VideoCapture(video_content)
51
 
52
- fps = video.get(cv2.CAP_PROP_FPS) # OpenCV v2.x used "CV_CAP_PROP_FPS"
53
- frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
54
- duration = frame_count/fps
55
-
56
- if duration > 10:
57
- raise gr.Error("Please provide or record a video shorter than 10 seconds...")
58
-
59
  # Track
60
  video_track = tracking(video_content, track_model.track)
61
 
 
30
  track_model = YOLO('yolov8n.pt') # Load an official Detect model
31
  print("[INFO]: Downloaded models!")
32
 
33
+
34
+
35
+ def check_extension(video):
36
+
37
+ clip = moviepy.VideoFileClip(video)
38
+
39
+ if clip.duration > 10:
40
+ raise gr.Error("Please provide or record a video shorter than 10 seconds...")
41
+
42
+ split_tup = os.path.splitext(video)
43
+
44
+ # extract the file name and extension
45
+ file_name = split_tup[0]
46
+ file_extension = split_tup[1]
47
+
48
+ if file_extension != ".mp4":
49
+ print("Converting to mp4")
50
+
51
+ video = file_name+".mp4"
52
+ clip.write_videofile(video, threads = 8)
53
+
54
+ return video
55
+
56
+
57
  def tracking(video, model, boxes=True):
58
  print("[INFO] Is cuda available? ", torch.cuda.is_available())
59
  print(device)
 
69
  return annotated_frame
70
 
71
  def show_tracking(video_content):
72
+
73
+ video = check_extension(video_content)
74
 
75
  # https://docs.ultralytics.com/datasets/detect/coco/
76
  video = cv2.VideoCapture(video_content)
77
 
 
 
 
 
 
 
 
78
  # Track
79
  video_track = tracking(video_content, track_model.track)
80