Medha Sawhney commited on
Commit
091eba9
·
1 Parent(s): 53aaa4a

support to show all frames

Browse files
MEMTrack/src/GenerateVideo.py CHANGED
@@ -3,32 +3,49 @@ import os
3
  import argparse
4
  from natsort import natsorted
5
 
6
- def create_video(data_dir, image_dir, video_name,fps):
7
  # choose codec according to format needed
8
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
9
  #print(data_dir)
10
- img_sample = cv2.imread(os.path.join(image_dir,"0.png"))
11
- #print(img_sample.shape)
 
 
 
 
 
 
 
12
  height, width, channels = img_sample.shape
13
 
14
  video = cv2.VideoWriter(data_dir + video_name + ".mp4", fourcc, fps, (width, height))
15
  print("Image dir: ", image_dir)
16
- print("Number of frames: ", len(os.listdir(image_dir)))
17
  print("Video output fps: ", fps)
18
- for frame in natsorted(os.listdir(image_dir)):
19
- #print(frame)
20
- img = cv2.imread(os.path.join(image_dir, frame))
21
- video.write(img)
 
 
 
 
 
22
  video.release()
23
 
24
- def gen_tracking_video(video_num, fps=60, data_path=None, custom_test_dir=None):
25
  if custom_test_dir:
26
  data_dir = custom_test_dir
27
  else:
28
  data_dir = data_path + f"/data_video{video_num}_feature_optical_flow_median_back_2pyr_18win_background_img/"
29
- image_dir = data_dir + "/test/tracklets-filtered/"
30
- video_name = f'video{video_num}-tracklets-filtered-{fps}'
31
- create_video(data_dir, image_dir, video_name,fps)
 
 
 
 
 
32
  return os.path.join(data_dir, video_name)+ ".mp4"
33
 
34
  if __name__ == "__main__":
 
3
  import argparse
4
  from natsort import natsorted
5
 
6
+ def create_video(data_dir, image_dir, video_name,fps, all_image_path=None, all_images=False):
7
  # choose codec according to format needed
8
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
9
  #print(data_dir)
10
+ plotted_images = natsorted(os.listdir(image_dir))
11
+ images = plotted_images
12
+ if all_images:
13
+ all_images_list = natsorted(os.listdir(all_image_path))
14
+ images = all_images_list
15
+
16
+ img_sample = cv2.imread(os.path.join(image_dir,images[0]))
17
+ print(images[0])
18
+ print(img_sample.shape)
19
  height, width, channels = img_sample.shape
20
 
21
  video = cv2.VideoWriter(data_dir + video_name + ".mp4", fourcc, fps, (width, height))
22
  print("Image dir: ", image_dir)
23
+ print("Number of frames: ", len(images))
24
  print("Video output fps: ", fps)
25
+
26
+ for frame in natsorted(images):
27
+ if frame in plotted_images:
28
+ img = cv2.imread(os.path.join(image_dir, frame))
29
+ video.write(img)
30
+ else:
31
+ img = cv2.imread(os.path.join(all_image_path, frame))
32
+ video.write(img)
33
+
34
  video.release()
35
 
36
+ def gen_tracking_video(video_num, fps=60, data_path=None, custom_test_dir=None, all_images=False):
37
  if custom_test_dir:
38
  data_dir = custom_test_dir
39
  else:
40
  data_dir = data_path + f"/data_video{video_num}_feature_optical_flow_median_back_2pyr_18win_background_img/"
41
+ image_dir_plotted = data_dir + "/test/tracklets-filtered/"
42
+ if all_images:
43
+ video_name = f'video{video_num}-tracklets-filtered-{fps}-all-frames'
44
+ else:
45
+ video_name = f'video{video_num}-tracklets-filtered-{fps}'
46
+ all_img_path = data_dir + f"/test/images/"
47
+
48
+ create_video(data_dir, image_dir_plotted, video_name,fps, all_image_path=all_img_path, all_images=all_images)
49
  return os.path.join(data_dir, video_name)+ ".mp4"
50
 
51
  if __name__ == "__main__":
MEMTrack/src/TrackingAnalysis.py CHANGED
@@ -298,7 +298,7 @@ def analyse_tracking(video_num, min_track_length=60, custom_test_dir=None, data_
298
  if plot :
299
  for image_id in image_id_filtered:
300
  #print(image_id)
301
- newname = save_path + str(image_id) + '.png'
302
  det_img = cv2.imread(os.path.join(img_path,str(image_id))+".tif")
303
  det_img_gt_only = det_img.copy()
304
  det_img_p_only = det_img.copy()
 
298
  if plot :
299
  for image_id in image_id_filtered:
300
  #print(image_id)
301
+ newname = save_path + str(image_id) + '.tif'
302
  det_img = cv2.imread(os.path.join(img_path,str(image_id))+".tif")
303
  det_img_gt_only = det_img.copy()
304
  det_img_p_only = det_img.copy()