Update app.py
Browse files
app.py
CHANGED
@@ -126,7 +126,7 @@ def preprocess_video_in(video_path, session_state):
|
|
126 |
session_state,
|
127 |
)
|
128 |
|
129 |
-
# Read the
|
130 |
cap = cv2.VideoCapture(video_path)
|
131 |
if not cap.isOpened():
|
132 |
print(f"Error: Could not open video at {video_path}.")
|
@@ -139,65 +139,61 @@ def preprocess_video_in(video_path, session_state):
|
|
139 |
session_state,
|
140 |
)
|
141 |
|
142 |
-
# For CPU optimization - determine video properties
|
143 |
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
144 |
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
145 |
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
146 |
fps = cap.get(cv2.CAP_PROP_FPS)
|
147 |
-
|
148 |
print(f"Video info: {frame_width}x{frame_height}, {total_frames} frames, {fps} FPS")
|
149 |
-
|
150 |
-
# Determine if we need to resize for CPU performance
|
151 |
target_width = 640 # Target width for processing on CPU
|
152 |
scale_factor = 1.0
|
153 |
-
|
154 |
if frame_width > target_width:
|
155 |
scale_factor = target_width / frame_width
|
156 |
new_width = int(frame_width * scale_factor)
|
157 |
new_height = int(frame_height * scale_factor)
|
158 |
print(f"Resizing video for CPU processing: {frame_width}x{frame_height} -> {new_width}x{new_height}")
|
159 |
-
|
160 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
frame_number = 0
|
162 |
first_frame = None
|
163 |
all_frames = []
|
164 |
-
|
165 |
-
# For CPU optimization, skip frames if video is too long
|
166 |
-
frame_stride = 1
|
167 |
-
if total_frames > 300: # If more than 300 frames
|
168 |
-
frame_stride = max(1, int(total_frames / 300)) # Process at most ~300 frames
|
169 |
-
print(f"Video has {total_frames} frames, using stride of {frame_stride} to reduce processing load")
|
170 |
-
|
171 |
while True:
|
172 |
ret, frame = cap.read()
|
173 |
if not ret:
|
174 |
break
|
175 |
-
|
176 |
-
if frame_number % frame_stride == 0:
|
177 |
try:
|
178 |
# Resize the frame if needed
|
179 |
if scale_factor != 1.0:
|
180 |
frame = cv2.resize(
|
181 |
-
frame,
|
182 |
-
(int(frame_width * scale_factor), int(frame_height * scale_factor)),
|
183 |
interpolation=cv2.INTER_AREA
|
184 |
)
|
185 |
-
|
186 |
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
187 |
frame = np.array(frame)
|
188 |
-
|
189 |
-
# Store the first frame
|
190 |
if first_frame is None:
|
191 |
first_frame = frame
|
192 |
all_frames.append(frame)
|
193 |
except Exception as e:
|
194 |
print(f"Error processing frame {frame_number}: {e}")
|
195 |
-
|
196 |
frame_number += 1
|
197 |
|
198 |
cap.release()
|
199 |
-
|
200 |
-
# Ensure we have at least one frame
|
201 |
if first_frame is None or len(all_frames) == 0:
|
202 |
print("Error: No frames could be extracted from the video.")
|
203 |
return (
|
@@ -208,9 +204,9 @@ def preprocess_video_in(video_path, session_state):
|
|
208 |
gr.update(value=0, visible=False), # progress_bar
|
209 |
session_state,
|
210 |
)
|
211 |
-
|
212 |
print(f"Successfully extracted {len(all_frames)} frames from video")
|
213 |
-
|
214 |
session_state["first_frame"] = copy.deepcopy(first_frame)
|
215 |
session_state["all_frames"] = all_frames
|
216 |
session_state["frame_stride"] = frame_stride
|
@@ -227,7 +223,7 @@ def preprocess_video_in(video_path, session_state):
|
|
227 |
import traceback
|
228 |
traceback.print_exc()
|
229 |
session_state["inference_state"] = None
|
230 |
-
|
231 |
return [
|
232 |
gr.update(open=False), # video_in_drawer
|
233 |
first_frame, # points_map
|
@@ -320,6 +316,11 @@ def segment_with_points(
|
|
320 |
print(f"Resizing mask from {out_mask.shape[:2]} to {h}x{w}")
|
321 |
# Use numpy/PIL for resizing to avoid OpenCV issues
|
322 |
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
323 |
mask_img = Image.fromarray(out_mask.astype(np.uint8) * 255)
|
324 |
mask_img = mask_img.resize((w, h), Image.NEAREST)
|
325 |
out_mask = np.array(mask_img) > 0
|
@@ -449,7 +450,7 @@ def propagate_to_all(
|
|
449 |
print("Starting propagate_in_video on CPU")
|
450 |
|
451 |
# Get the count for progress reporting (estimate)
|
452 |
-
all_frames_count =
|
453 |
|
454 |
# Now do the actual processing with progress updates
|
455 |
current_frame = 0
|
@@ -494,7 +495,7 @@ def propagate_to_all(
|
|
494 |
progress(0.5, desc="Rendering video")
|
495 |
|
496 |
# Limit to max 50 frames for CPU processing
|
497 |
-
max_output_frames =
|
498 |
vis_frame_stride = max(1, total_frames // max_output_frames)
|
499 |
print(f"Using stride of {vis_frame_stride} for output video generation")
|
500 |
|
@@ -543,6 +544,10 @@ def propagate_to_all(
|
|
543 |
if mask_h != frame_h or mask_w != frame_w:
|
544 |
print(f"Resizing mask from {mask_h}x{mask_w} to {frame_h}x{frame_w}")
|
545 |
try:
|
|
|
|
|
|
|
|
|
546 |
mask_img = Image.fromarray(out_mask.astype(np.uint8) * 255)
|
547 |
mask_img = mask_img.resize((frame_w, frame_h), Image.NEAREST)
|
548 |
out_mask = np.array(mask_img) > 0
|
@@ -758,69 +763,4 @@ with gr.Blocks() as demo:
|
|
758 |
queue=False,
|
759 |
)
|
760 |
|
761 |
-
# triggered when we click
|
762 |
-
points_map.select(
|
763 |
-
fn=segment_with_points,
|
764 |
-
inputs=[
|
765 |
-
point_type, # "include" or "exclude"
|
766 |
-
session_state,
|
767 |
-
],
|
768 |
-
outputs=[
|
769 |
-
points_map, # updated image with points
|
770 |
-
output_image,
|
771 |
-
session_state,
|
772 |
-
],
|
773 |
-
queue=False,
|
774 |
-
)
|
775 |
-
|
776 |
-
# Clear every points clicked and added to the map
|
777 |
-
clear_points_btn.click(
|
778 |
-
fn=clear_points,
|
779 |
-
inputs=session_state,
|
780 |
-
outputs=[
|
781 |
-
points_map,
|
782 |
-
output_image,
|
783 |
-
output_video,
|
784 |
-
progress_bar,
|
785 |
-
session_state,
|
786 |
-
],
|
787 |
-
queue=False,
|
788 |
-
)
|
789 |
-
|
790 |
-
reset_btn.click(
|
791 |
-
fn=reset,
|
792 |
-
inputs=session_state,
|
793 |
-
outputs=[
|
794 |
-
video_in,
|
795 |
-
video_in_drawer,
|
796 |
-
points_map,
|
797 |
-
output_image,
|
798 |
-
output_video,
|
799 |
-
progress_bar,
|
800 |
-
session_state,
|
801 |
-
],
|
802 |
-
queue=False,
|
803 |
-
)
|
804 |
-
|
805 |
-
propagate_btn.click(
|
806 |
-
fn=update_ui,
|
807 |
-
inputs=[],
|
808 |
-
outputs=[output_video, progress_bar],
|
809 |
-
queue=False,
|
810 |
-
).then(
|
811 |
-
fn=propagate_to_all,
|
812 |
-
inputs=[
|
813 |
-
video_in,
|
814 |
-
session_state,
|
815 |
-
],
|
816 |
-
outputs=[
|
817 |
-
output_video,
|
818 |
-
progress_bar,
|
819 |
-
session_state,
|
820 |
-
],
|
821 |
-
queue=True, # Use queue for CPU processing
|
822 |
-
)
|
823 |
-
|
824 |
-
|
825 |
-
demo.queue()
|
826 |
-
demo.launch()
|
|
|
126 |
session_state,
|
127 |
)
|
128 |
|
129 |
+
# Read the video
|
130 |
cap = cv2.VideoCapture(video_path)
|
131 |
if not cap.isOpened():
|
132 |
print(f"Error: Could not open video at {video_path}.")
|
|
|
139 |
session_state,
|
140 |
)
|
141 |
|
|
|
142 |
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
143 |
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
144 |
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
145 |
fps = cap.get(cv2.CAP_PROP_FPS)
|
146 |
+
|
147 |
print(f"Video info: {frame_width}x{frame_height}, {total_frames} frames, {fps} FPS")
|
148 |
+
|
|
|
149 |
target_width = 640 # Target width for processing on CPU
|
150 |
scale_factor = 1.0
|
151 |
+
|
152 |
if frame_width > target_width:
|
153 |
scale_factor = target_width / frame_width
|
154 |
new_width = int(frame_width * scale_factor)
|
155 |
new_height = int(frame_height * scale_factor)
|
156 |
print(f"Resizing video for CPU processing: {frame_width}x{frame_height} -> {new_width}x{new_height}")
|
157 |
+
|
158 |
+
# Skip frames so long videos stay tractable on CPU (keeps roughly max_frames frames)
|
159 |
+
frame_stride = 1
|
160 |
+
max_frames = 150 # Maximum number of frames to process
|
161 |
+
if total_frames > max_frames:
|
162 |
+
frame_stride = max(1, int(total_frames / max_frames))
|
163 |
+
print(f"Video has {total_frames} frames, using stride of {frame_stride} to limit to {max_frames}")
|
164 |
+
|
165 |
frame_number = 0
|
166 |
first_frame = None
|
167 |
all_frames = []
|
168 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
while True:
|
170 |
ret, frame = cap.read()
|
171 |
if not ret:
|
172 |
break
|
173 |
+
|
174 |
+
if frame_number % frame_stride == 0:
|
175 |
try:
|
176 |
# Resize the frame if needed
|
177 |
if scale_factor != 1.0:
|
178 |
frame = cv2.resize(
|
179 |
+
frame,
|
180 |
+
(int(frame_width * scale_factor), int(frame_height * scale_factor)),
|
181 |
interpolation=cv2.INTER_AREA
|
182 |
)
|
183 |
+
|
184 |
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
185 |
frame = np.array(frame)
|
186 |
+
|
|
|
187 |
if first_frame is None:
|
188 |
first_frame = frame
|
189 |
all_frames.append(frame)
|
190 |
except Exception as e:
|
191 |
print(f"Error processing frame {frame_number}: {e}")
|
192 |
+
|
193 |
frame_number += 1
|
194 |
|
195 |
cap.release()
|
196 |
+
|
|
|
197 |
if first_frame is None or len(all_frames) == 0:
|
198 |
print("Error: No frames could be extracted from the video.")
|
199 |
return (
|
|
|
204 |
gr.update(value=0, visible=False), # progress_bar
|
205 |
session_state,
|
206 |
)
|
207 |
+
|
208 |
print(f"Successfully extracted {len(all_frames)} frames from video")
|
209 |
+
|
210 |
session_state["first_frame"] = copy.deepcopy(first_frame)
|
211 |
session_state["all_frames"] = all_frames
|
212 |
session_state["frame_stride"] = frame_stride
|
|
|
223 |
import traceback
|
224 |
traceback.print_exc()
|
225 |
session_state["inference_state"] = None
|
226 |
+
|
227 |
return [
|
228 |
gr.update(open=False), # video_in_drawer
|
229 |
first_frame, # points_map
|
|
|
316 |
print(f"Resizing mask from {out_mask.shape[:2]} to {h}x{w}")
|
317 |
# Use numpy/PIL for resizing to avoid OpenCV issues
|
318 |
from PIL import Image
|
319 |
+
|
320 |
+
# Ensure mask is boolean type
|
321 |
+
if out_mask.dtype != np.bool_:
|
322 |
+
out_mask = out_mask > 0
|
323 |
+
|
324 |
mask_img = Image.fromarray(out_mask.astype(np.uint8) * 255)
|
325 |
mask_img = mask_img.resize((w, h), Image.NEAREST)
|
326 |
out_mask = np.array(mask_img) > 0
|
|
|
450 |
print("Starting propagate_in_video on CPU")
|
451 |
|
452 |
# Frame count used for progress reporting (hard-coded estimate, not measured)
|
453 |
+
all_frames_count = 100 # Reasonable estimate
|
454 |
|
455 |
# Now do the actual processing with progress updates
|
456 |
current_frame = 0
|
|
|
495 |
progress(0.5, desc="Rendering video")
|
496 |
|
497 |
# Limit output to roughly max_output_frames frames for CPU processing (floor division makes the cap approximate)
|
498 |
+
max_output_frames = 30
|
499 |
vis_frame_stride = max(1, total_frames // max_output_frames)
|
500 |
print(f"Using stride of {vis_frame_stride} for output video generation")
|
501 |
|
|
|
544 |
if mask_h != frame_h or mask_w != frame_w:
|
545 |
print(f"Resizing mask from {mask_h}x{mask_w} to {frame_h}x{frame_w}")
|
546 |
try:
|
547 |
+
# Ensure mask is boolean type
|
548 |
+
if out_mask.dtype != np.bool_:
|
549 |
+
out_mask = out_mask > 0
|
550 |
+
|
551 |
mask_img = Image.fromarray(out_mask.astype(np.uint8) * 255)
|
552 |
mask_img = mask_img.resize((frame_w, frame_h), Image.NEAREST)
|
553 |
out_mask = np.array(mask_img) > 0
|
|
|
763 |
queue=False,
|
764 |
)
|
765 |
|
766 |
+
# triggered when we click
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|