David Krajewski commited on
Commit
762836f
·
1 Parent(s): 27f9f79
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -236,6 +236,7 @@ class Drag:
236
  self.width = width
237
  self.model_length = model_length
238
 
 
239
  def get_cmp_flow(self, frames, sparse_optical_flow, mask, brush_mask=None):
240
 
241
  '''
@@ -243,6 +244,8 @@ class Drag:
243
  sparse_optical_flow: [b, 13, 2, 384, 384] (-384, 384) tensor
244
  mask: [b, 13, 2, 384, 384] {0, 1} tensor
245
  '''
 
 
246
 
247
  b, t, c, h, w = frames.shape
248
  assert h == 384 and w == 384
@@ -486,9 +489,8 @@ class Drag:
486
  viz_esti_flows = flow_to_image(controlnet_flow) # [h, w, c]
487
 
488
  return viz_esti_flows
489
-
490
- @spaces.GPU(enable_queue=True, duration=240)
491
- def run(self, first_frame_path, tracking_points, inference_batch_size, motion_brush_mask, motion_brush_viz, ctrl_scale):
492
  svd_ckpt = "./ckpts/stable-video-diffusion-img2vid-xt-1-1"
493
  mofa_ckpt = "./ckpts/controlnet/ckpts/controlnet"
494
 
@@ -502,6 +504,11 @@ class Drag:
502
  self.pipeline = self.pipeline.to("cuda:0")
503
  self.cmp = self.cmp.to("cuda:0")
504
 
 
 
 
 
 
505
  original_width, original_height = self.width, self.height
506
 
507
  input_all_points = tracking_points.constructor_args['value']
 
236
  self.width = width
237
  self.model_length = model_length
238
 
239
+ @spaces.GPU(enable_queue=True, duration=240)
240
  def get_cmp_flow(self, frames, sparse_optical_flow, mask, brush_mask=None):
241
 
242
  '''
 
244
  sparse_optical_flow: [b, 13, 2, 384, 384] (-384, 384) tensor
245
  mask: [b, 13, 2, 384, 384] {0, 1} tensor
246
  '''
247
+ if not self.cmp or not self.pipeline:
248
+ self.run_model_init()
249
 
250
  b, t, c, h, w = frames.shape
251
  assert h == 384 and w == 384
 
489
  viz_esti_flows = flow_to_image(controlnet_flow) # [h, w, c]
490
 
491
  return viz_esti_flows
492
+
493
+ def run_model_init(self):
 
494
  svd_ckpt = "./ckpts/stable-video-diffusion-img2vid-xt-1-1"
495
  mofa_ckpt = "./ckpts/controlnet/ckpts/controlnet"
496
 
 
504
  self.pipeline = self.pipeline.to("cuda:0")
505
  self.cmp = self.cmp.to("cuda:0")
506
 
507
+ @spaces.GPU(enable_queue=True, duration=240)
508
+ def run(self, first_frame_path, tracking_points, inference_batch_size, motion_brush_mask, motion_brush_viz, ctrl_scale):
509
+ if not self.cmp or not self.pipeline:
510
+ self.run_model_init()
511
+
512
  original_width, original_height = self.width, self.height
513
 
514
  input_all_points = tracking_points.constructor_args['value']