David Krajewski commited on
Commit
07f3992
·
1 Parent(s): 55d0417

Chnaged cuda

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -29,6 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
29
  from utils.flow_viz import flow_to_image
30
  from utils.utils import split_filename, image2arr, image2pil, ensure_dirname
31
  from huggingface_hub import login, hf_hub_download, snapshot_download
 
32
 
33
 
34
  output_dir_video = "./outputs/videos"
@@ -120,7 +121,7 @@ def init_models(pretrained_model_name_or_path, resume_from_checkpoint, weight_dt
120
  cmp = CMP_demo(
121
  './models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/config.yaml',
122
  42000
123
- ).to(device)
124
  cmp.requires_grad_(False)
125
 
126
  # Freeze vae and image_encoder
@@ -497,8 +498,11 @@ class Drag:
497
 
498
  return viz_esti_flows
499
 
 
500
  def run(self, first_frame_path, tracking_points, inference_batch_size, motion_brush_mask, motion_brush_viz, ctrl_scale):
501
-
 
 
502
  original_width, original_height = self.width, self.height
503
 
504
  input_all_points = tracking_points.constructor_args['value']
 
29
  from utils.flow_viz import flow_to_image
30
  from utils.utils import split_filename, image2arr, image2pil, ensure_dirname
31
  from huggingface_hub import login, hf_hub_download, snapshot_download
32
+ import spaces
33
 
34
 
35
  output_dir_video = "./outputs/videos"
 
121
  cmp = CMP_demo(
122
  './models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/config.yaml',
123
  42000
124
+ )
125
  cmp.requires_grad_(False)
126
 
127
  # Freeze vae and image_encoder
 
498
 
499
  return viz_esti_flows
500
 
501
+ @spaces.GPU(enable_queue=True, duration=240)
502
  def run(self, first_frame_path, tracking_points, inference_batch_size, motion_brush_mask, motion_brush_viz, ctrl_scale):
503
+ self.pipeline = self.pipeline.to("cuda:0")
504
+ self.cmp = self.cmp.to("cuda:0")
505
+
506
  original_width, original_height = self.width, self.height
507
 
508
  input_all_points = tracking_points.constructor_args['value']