typo
Browse files
app.py
CHANGED
@@ -120,7 +120,11 @@ def reset(
|
|
120 |
)
|
121 |
|
122 |
|
123 |
-
def clear_points(
|
|
|
|
|
|
|
|
|
124 |
session_id = request.session_id
|
125 |
predictor.to("cpu")
|
126 |
session_input_points = []
|
@@ -237,9 +241,7 @@ def segment_with_points(
|
|
237 |
print(f"TRACKING INPUT LABEL: {session_input_labels}")
|
238 |
|
239 |
# Open the image and get its dimensions
|
240 |
-
transparent_background Image.fromarray(session_first_frame).convert(
|
241 |
-
"RGBA"
|
242 |
-
)
|
243 |
w, h = transparent_background.size
|
244 |
|
245 |
# Define the circle radius as a fraction of the smaller dimension
|
@@ -277,7 +279,12 @@ def segment_with_points(
|
|
277 |
first_frame_output = Image.alpha_composite(transparent_background, mask_image)
|
278 |
|
279 |
torch.cuda.empty_cache()
|
280 |
-
return
|
|
|
|
|
|
|
|
|
|
|
281 |
|
282 |
|
283 |
def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
|
@@ -308,7 +315,7 @@ def propagate_to_all(
|
|
308 |
torch.backends.cudnn.allow_tf32 = True
|
309 |
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
310 |
if (
|
311 |
-
len
|
312 |
or video_in is None
|
313 |
or global_inference_states[session_id] is None
|
314 |
):
|
|
|
120 |
)
|
121 |
|
122 |
|
123 |
+
def clear_points(
|
124 |
+
session_input_points,
|
125 |
+
session_input_labels,
|
126 |
+
request: gr.Request,
|
127 |
+
):
|
128 |
session_id = request.session_id
|
129 |
predictor.to("cpu")
|
130 |
session_input_points = []
|
|
|
241 |
print(f"TRACKING INPUT LABEL: {session_input_labels}")
|
242 |
|
243 |
# Open the image and get its dimensions
|
244 |
+
transparent_background = Image.fromarray(session_first_frame).convert("RGBA")
|
|
|
|
|
245 |
w, h = transparent_background.size
|
246 |
|
247 |
# Define the circle radius as a fraction of the smaller dimension
|
|
|
279 |
first_frame_output = Image.alpha_composite(transparent_background, mask_image)
|
280 |
|
281 |
torch.cuda.empty_cache()
|
282 |
+
return (
|
283 |
+
selected_point_map,
|
284 |
+
first_frame_output,
|
285 |
+
session_input_points,
|
286 |
+
session_input_labels,
|
287 |
+
)
|
288 |
|
289 |
|
290 |
def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
|
|
|
315 |
torch.backends.cudnn.allow_tf32 = True
|
316 |
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
317 |
if (
|
318 |
+
len(session_input_points) == 0
|
319 |
or video_in is None
|
320 |
or global_inference_states[session_id] is None
|
321 |
):
|