goryhon committed on
Commit
e44c7d0
·
verified ·
1 Parent(s): ec03567

Update web-demos/hugging_face/app.py

Browse files
Files changed (1) hide show
  1. web-demos/hugging_face/app.py +2 -16
web-demos/hugging_face/app.py CHANGED
@@ -13,7 +13,6 @@ import torch
13
  import torchvision
14
  import numpy as np
15
  import gradio as gr
16
- from PIL import Image
17
 
18
  from tools.painter import mask_painter
19
  from track_anything import TrackingAnything
@@ -253,14 +252,10 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
253
  template_mask[0][0]=1
254
  operation_log = [("Please add at least one mask to track by clicking the image in step2.","Error"), ("","")]
255
  # return video_output, video_state, interactive_state, operation_error
256
- masks, logits, painted_images, alpha_visuals = model.generator(images=following_frames, template_mask=template_mask)
257
-
258
  # clear GPU memory
259
  model.cutie.clear_memory()
260
 
261
- # сохранить альфа-канальные маски в состояние (для отображения или сохранения видео)
262
- video_state["alpha_visuals"] = alpha_visuals
263
-
264
  if interactive_state["track_end_number"]:
265
  video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
266
  video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits
@@ -272,10 +267,6 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
272
 
273
  video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/track/{}".format(video_state["video_name"]), fps=float(fps)) # import video_input to name the output video
274
  interactive_state["inference_times"] += 1
275
- # Дополнительно: альфа-маска-видео
276
- if "alpha_visuals" in video_state:
277
- generate_video_from_frames(video_state["alpha_visuals"], output_path="./result/track/{}_alpha.mp4".format(video_state["video_name"].split('.')[0]), fps=float(fps), is_rgba=True)
278
-
279
 
280
  print("Tracking resolution:", following_frames[0].shape)
281
 
@@ -334,7 +325,7 @@ def inpaint_video(video_state, resize_ratio_number, dilate_radius_number, raft_i
334
 
335
 
336
  # generate video after vos inference
337
- def generate_video_from_frames(frames, output_path, fps=30,is_rgba=False):
338
  """
339
  Generates a video from a list of frames.
340
 
@@ -365,11 +356,6 @@ def generate_video_from_frames(frames, output_path, fps=30,is_rgba=False):
365
  if not os.path.exists(os.path.dirname(output_path)):
366
  os.makedirs(os.path.dirname(output_path))
367
 
368
- if is_rgba:
369
- frames = torch.from_numpy(np.asarray(frames).astype(np.uint8))
370
- else:
371
- frames = torch.from_numpy(np.asarray(frames))
372
-
373
  # Write the video
374
  torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
375
  return output_path
 
13
  import torchvision
14
  import numpy as np
15
  import gradio as gr
 
16
 
17
  from tools.painter import mask_painter
18
  from track_anything import TrackingAnything
 
252
  template_mask[0][0]=1
253
  operation_log = [("Please add at least one mask to track by clicking the image in step2.","Error"), ("","")]
254
  # return video_output, video_state, interactive_state, operation_error
255
+ masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
 
256
  # clear GPU memory
257
  model.cutie.clear_memory()
258
 
 
 
 
259
  if interactive_state["track_end_number"]:
260
  video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
261
  video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits
 
267
 
268
  video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/track/{}".format(video_state["video_name"]), fps=float(fps)) # import video_input to name the output video
269
  interactive_state["inference_times"] += 1
 
 
 
 
270
 
271
  print("Tracking resolution:", following_frames[0].shape)
272
 
 
325
 
326
 
327
  # generate video after vos inference
328
+ def generate_video_from_frames(frames, output_path, fps=30):
329
  """
330
  Generates a video from a list of frames.
331
 
 
356
  if not os.path.exists(os.path.dirname(output_path)):
357
  os.makedirs(os.path.dirname(output_path))
358
 
 
 
 
 
 
359
  # Write the video
360
  torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
361
  return output_path