goryhon committed on
Commit 2ccc12e · verified · 1 Parent(s): e15cd4d

Update web-demos/hugging_face/app.py

Files changed (1)
  1. web-demos/hugging_face/app.py  +16 -2
web-demos/hugging_face/app.py CHANGED
@@ -13,6 +13,7 @@ import torch
 import torchvision
 import numpy as np
 import gradio as gr
+from PIL import Image
 
 from tools.painter import mask_painter
 from track_anything import TrackingAnything
@@ -252,10 +253,14 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
         template_mask[0][0]=1
         operation_log = [("Please add at least one mask to track by clicking the image in step2.","Error"), ("","")]
         # return video_output, video_state, interactive_state, operation_error
-    masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
+    masks, logits, painted_images, alpha_visuals = model.generator(images=following_frames, template_mask=template_mask)
+
     # clear GPU memory
     model.cutie.clear_memory()
 
+    # store the alpha-channel masks in the state (for display or for saving a video)
+    video_state["alpha_visuals"] = alpha_visuals
+
     if interactive_state["track_end_number"]:
         video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
         video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits
@@ -267,6 +272,10 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
 
     video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/track/{}".format(video_state["video_name"]), fps=float(fps)) # import video_input to name the output video
     interactive_state["inference_times"] += 1
+    # additionally: write an alpha-mask video
+    if "alpha_visuals" in video_state:
+        generate_video_from_frames(video_state["alpha_visuals"], output_path="./result/track/{}_alpha.mp4".format(video_state["video_name"].split('.')[0]), fps=float(fps), is_rgba=True)
+
 
     print("Tracking resolution:", following_frames[0].shape)
 
@@ -325,7 +334,7 @@ def inpaint_video(video_state, resize_ratio_number, dilate_radius_number, raft_i
 
 
 # generate video after vos inference
-def generate_video_from_frames(frames, output_path, fps=30):
+def generate_video_from_frames(frames, output_path, fps=30, is_rgba=False):
     """
     Generates a video from a list of frames.
 
@@ -356,6 +365,11 @@ def generate_video_from_frames(frames, output_path, fps=30):
     if not os.path.exists(os.path.dirname(output_path)):
         os.makedirs(os.path.dirname(output_path))
 
+    if is_rgba:
+        frames = torch.from_numpy(np.asarray(frames).astype(np.uint8))
+    else:
+        frames = torch.from_numpy(np.asarray(frames))
+
     # Write the video
     torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
     return output_path
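
Note (not part of this commit): an MP4 encoded with libx264 does not carry an alpha channel, so the "{}_alpha.mp4" written above will most likely come out opaque even though the input frames are RGBA. If a reusable transparent matte is the goal, one option is to dump video_state["alpha_visuals"] as a PNG sequence using the PIL import this commit adds. The sketch below assumes each element of alpha_visuals is an H x W x 4 uint8 array; save_alpha_png_sequence is a hypothetical helper name, not something defined in app.py.

    # Sketch only: write RGBA alpha visuals as a PNG sequence (PNG preserves alpha).
    import os

    import numpy as np
    from PIL import Image

    def save_alpha_png_sequence(alpha_visuals, output_dir):
        os.makedirs(output_dir, exist_ok=True)
        for i, frame in enumerate(alpha_visuals):
            # Assumes frame is H x W x 4 uint8; Image.fromarray keeps the 4th channel as alpha.
            Image.fromarray(np.asarray(frame, dtype=np.uint8), mode="RGBA").save(
                os.path.join(output_dir, "{:05d}.png".format(i)))
        return output_dir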
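
Also not from the commit: torchvision.io.write_video expects a (T, H, W, C) uint8 tensor, and with video_codec="libx264" it is effectively limited to 3-channel RGB, so handing it the 4-channel is_rgba frames unchanged may error out or silently drop the transparency. A hedged alternative that stays within the .mp4 workflow is to encode only the alpha channel, replicated to three channels, as a grayscale matte video. write_alpha_matte_video below is a hypothetical helper under the same H x W x 4 uint8 assumption about alpha_visuals.

    # Sketch only: encode the alpha channel as a 3-channel grayscale matte that libx264 accepts.
    import numpy as np
    import torch
    import torchvision

    def write_alpha_matte_video(alpha_visuals, output_path, fps=30):
        frames = np.asarray(alpha_visuals, dtype=np.uint8)  # (T, H, W, 4)
        matte = np.repeat(frames[..., 3:4], 3, axis=-1)     # keep only alpha, replicated to 3 channels
        torchvision.io.write_video(output_path, torch.from_numpy(matte), fps=fps, video_codec="libx264")
        return output_path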