Spaces:
bla
/
Runtime error

bla committed on
Commit
628bfb2
·
verified ·
1 Parent(s): 5dc8194

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -96
app.py CHANGED
@@ -126,7 +126,7 @@ def preprocess_video_in(video_path, session_state):
126
  session_state,
127
  )
128
 
129
- # Read the first frame
130
  cap = cv2.VideoCapture(video_path)
131
  if not cap.isOpened():
132
  print(f"Error: Could not open video at {video_path}.")
@@ -139,65 +139,61 @@ def preprocess_video_in(video_path, session_state):
139
  session_state,
140
  )
141
 
142
- # For CPU optimization - determine video properties
143
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
144
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
145
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
146
  fps = cap.get(cv2.CAP_PROP_FPS)
147
-
148
  print(f"Video info: {frame_width}x{frame_height}, {total_frames} frames, {fps} FPS")
149
-
150
- # Determine if we need to resize for CPU performance
151
  target_width = 640 # Target width for processing on CPU
152
  scale_factor = 1.0
153
-
154
  if frame_width > target_width:
155
  scale_factor = target_width / frame_width
156
  new_width = int(frame_width * scale_factor)
157
  new_height = int(frame_height * scale_factor)
158
  print(f"Resizing video for CPU processing: {frame_width}x{frame_height} -> {new_width}x{new_height}")
159
-
160
- # Read frames - for CPU we'll be more selective about which frames to keep
 
 
 
 
 
 
161
  frame_number = 0
162
  first_frame = None
163
  all_frames = []
164
-
165
- # For CPU optimization, skip frames if video is too long
166
- frame_stride = 1
167
- if total_frames > 300: # If more than 300 frames
168
- frame_stride = max(1, int(total_frames / 300)) # Process at most ~300 frames
169
- print(f"Video has {total_frames} frames, using stride of {frame_stride} to reduce processing load")
170
-
171
  while True:
172
  ret, frame = cap.read()
173
  if not ret:
174
  break
175
-
176
- if frame_number % frame_stride == 0: # Process every frame_stride frames
177
  try:
178
  # Resize the frame if needed
179
  if scale_factor != 1.0:
180
  frame = cv2.resize(
181
- frame,
182
- (int(frame_width * scale_factor), int(frame_height * scale_factor)),
183
  interpolation=cv2.INTER_AREA
184
  )
185
-
186
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
187
  frame = np.array(frame)
188
-
189
- # Store the first frame
190
  if first_frame is None:
191
  first_frame = frame
192
  all_frames.append(frame)
193
  except Exception as e:
194
  print(f"Error processing frame {frame_number}: {e}")
195
-
196
  frame_number += 1
197
 
198
  cap.release()
199
-
200
- # Ensure we have at least one frame
201
  if first_frame is None or len(all_frames) == 0:
202
  print("Error: No frames could be extracted from the video.")
203
  return (
@@ -208,9 +204,9 @@ def preprocess_video_in(video_path, session_state):
208
  gr.update(value=0, visible=False), # progress_bar
209
  session_state,
210
  )
211
-
212
  print(f"Successfully extracted {len(all_frames)} frames from video")
213
-
214
  session_state["first_frame"] = copy.deepcopy(first_frame)
215
  session_state["all_frames"] = all_frames
216
  session_state["frame_stride"] = frame_stride
@@ -227,7 +223,7 @@ def preprocess_video_in(video_path, session_state):
227
  import traceback
228
  traceback.print_exc()
229
  session_state["inference_state"] = None
230
-
231
  return [
232
  gr.update(open=False), # video_in_drawer
233
  first_frame, # points_map
@@ -320,6 +316,11 @@ def segment_with_points(
320
  print(f"Resizing mask from {out_mask.shape[:2]} to {h}x{w}")
321
  # Use numpy/PIL for resizing to avoid OpenCV issues
322
  from PIL import Image
 
 
 
 
 
323
  mask_img = Image.fromarray(out_mask.astype(np.uint8) * 255)
324
  mask_img = mask_img.resize((w, h), Image.NEAREST)
325
  out_mask = np.array(mask_img) > 0
@@ -449,7 +450,7 @@ def propagate_to_all(
449
  print("Starting propagate_in_video on CPU")
450
 
451
  # Get the count for progress reporting (estimate)
452
- all_frames_count = 300 # Reasonable estimate
453
 
454
  # Now do the actual processing with progress updates
455
  current_frame = 0
@@ -494,7 +495,7 @@ def propagate_to_all(
494
  progress(0.5, desc="Rendering video")
495
 
496
  # Limit to max 50 frames for CPU processing
497
- max_output_frames = 50
498
  vis_frame_stride = max(1, total_frames // max_output_frames)
499
  print(f"Using stride of {vis_frame_stride} for output video generation")
500
 
@@ -543,6 +544,10 @@ def propagate_to_all(
543
  if mask_h != frame_h or mask_w != frame_w:
544
  print(f"Resizing mask from {mask_h}x{mask_w} to {frame_h}x{frame_w}")
545
  try:
 
 
 
 
546
  mask_img = Image.fromarray(out_mask.astype(np.uint8) * 255)
547
  mask_img = mask_img.resize((frame_w, frame_h), Image.NEAREST)
548
  out_mask = np.array(mask_img) > 0
@@ -758,69 +763,4 @@ with gr.Blocks() as demo:
758
  queue=False,
759
  )
760
 
761
- # triggered when we click on image to add new points
762
- points_map.select(
763
- fn=segment_with_points,
764
- inputs=[
765
- point_type, # "include" or "exclude"
766
- session_state,
767
- ],
768
- outputs=[
769
- points_map, # updated image with points
770
- output_image,
771
- session_state,
772
- ],
773
- queue=False,
774
- )
775
-
776
- # Clear every points clicked and added to the map
777
- clear_points_btn.click(
778
- fn=clear_points,
779
- inputs=session_state,
780
- outputs=[
781
- points_map,
782
- output_image,
783
- output_video,
784
- progress_bar,
785
- session_state,
786
- ],
787
- queue=False,
788
- )
789
-
790
- reset_btn.click(
791
- fn=reset,
792
- inputs=session_state,
793
- outputs=[
794
- video_in,
795
- video_in_drawer,
796
- points_map,
797
- output_image,
798
- output_video,
799
- progress_bar,
800
- session_state,
801
- ],
802
- queue=False,
803
- )
804
-
805
- propagate_btn.click(
806
- fn=update_ui,
807
- inputs=[],
808
- outputs=[output_video, progress_bar],
809
- queue=False,
810
- ).then(
811
- fn=propagate_to_all,
812
- inputs=[
813
- video_in,
814
- session_state,
815
- ],
816
- outputs=[
817
- output_video,
818
- progress_bar,
819
- session_state,
820
- ],
821
- queue=True, # Use queue for CPU processing
822
- )
823
-
824
-
825
- demo.queue()
826
- demo.launch()
 
126
  session_state,
127
  )
128
 
129
+ # Read the video
130
  cap = cv2.VideoCapture(video_path)
131
  if not cap.isOpened():
132
  print(f"Error: Could not open video at {video_path}.")
 
139
  session_state,
140
  )
141
 
 
142
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
143
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
144
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
145
  fps = cap.get(cv2.CAP_PROP_FPS)
146
+
147
  print(f"Video info: {frame_width}x{frame_height}, {total_frames} frames, {fps} FPS")
148
+
 
149
  target_width = 640 # Target width for processing on CPU
150
  scale_factor = 1.0
151
+
152
  if frame_width > target_width:
153
  scale_factor = target_width / frame_width
154
  new_width = int(frame_width * scale_factor)
155
  new_height = int(frame_height * scale_factor)
156
  print(f"Resizing video for CPU processing: {frame_width}x{frame_height} -> {new_width}x{new_height}")
157
+
158
+ # Even more aggressive frame skipping for very long videos on CPU
159
+ frame_stride = 1
160
+ max_frames = 150 # Maximum number of frames to process
161
+ if total_frames > max_frames:
162
+ frame_stride = max(1, int(total_frames / max_frames))
163
+ print(f"Video has {total_frames} frames, using stride of {frame_stride} to limit to {max_frames}")
164
+
165
  frame_number = 0
166
  first_frame = None
167
  all_frames = []
168
+
 
 
 
 
 
 
169
  while True:
170
  ret, frame = cap.read()
171
  if not ret:
172
  break
173
+
174
+ if frame_number % frame_stride == 0:
175
  try:
176
  # Resize the frame if needed
177
  if scale_factor != 1.0:
178
  frame = cv2.resize(
179
+ frame,
180
+ (int(frame_width * scale_factor), int(frame_height * scale_factor)),
181
  interpolation=cv2.INTER_AREA
182
  )
183
+
184
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
185
  frame = np.array(frame)
186
+
 
187
  if first_frame is None:
188
  first_frame = frame
189
  all_frames.append(frame)
190
  except Exception as e:
191
  print(f"Error processing frame {frame_number}: {e}")
192
+
193
  frame_number += 1
194
 
195
  cap.release()
196
+
 
197
  if first_frame is None or len(all_frames) == 0:
198
  print("Error: No frames could be extracted from the video.")
199
  return (
 
204
  gr.update(value=0, visible=False), # progress_bar
205
  session_state,
206
  )
207
+
208
  print(f"Successfully extracted {len(all_frames)} frames from video")
209
+
210
  session_state["first_frame"] = copy.deepcopy(first_frame)
211
  session_state["all_frames"] = all_frames
212
  session_state["frame_stride"] = frame_stride
 
223
  import traceback
224
  traceback.print_exc()
225
  session_state["inference_state"] = None
226
+
227
  return [
228
  gr.update(open=False), # video_in_drawer
229
  first_frame, # points_map
 
316
  print(f"Resizing mask from {out_mask.shape[:2]} to {h}x{w}")
317
  # Use numpy/PIL for resizing to avoid OpenCV issues
318
  from PIL import Image
319
+
320
+ # Ensure mask is boolean type
321
+ if out_mask.dtype != np.bool_:
322
+ out_mask = out_mask > 0
323
+
324
  mask_img = Image.fromarray(out_mask.astype(np.uint8) * 255)
325
  mask_img = mask_img.resize((w, h), Image.NEAREST)
326
  out_mask = np.array(mask_img) > 0
 
450
  print("Starting propagate_in_video on CPU")
451
 
452
  # Get the count for progress reporting (estimate)
453
+ all_frames_count = 100 # Reasonable estimate
454
 
455
  # Now do the actual processing with progress updates
456
  current_frame = 0
 
495
  progress(0.5, desc="Rendering video")
496
 
497
  # Limit to max 30 frames for CPU processing
498
+ max_output_frames = 30
499
  vis_frame_stride = max(1, total_frames // max_output_frames)
500
  print(f"Using stride of {vis_frame_stride} for output video generation")
501
 
 
544
  if mask_h != frame_h or mask_w != frame_w:
545
  print(f"Resizing mask from {mask_h}x{mask_w} to {frame_h}x{frame_w}")
546
  try:
547
+ # Ensure mask is boolean type
548
+ if out_mask.dtype != np.bool_:
549
+ out_mask = out_mask > 0
550
+
551
  mask_img = Image.fromarray(out_mask.astype(np.uint8) * 255)
552
  mask_img = mask_img.resize((frame_w, frame_h), Image.NEAREST)
553
  out_mask = np.array(mask_img) > 0
 
763
  queue=False,
764
  )
765
 
766
+ # triggered when we click