chaerinmin committed on
Commit 973c0b5 · 1 Parent(s): 139a6a5

hf cuda issue

Files changed (1):
  1. app.py +4 -4
app.py CHANGED
@@ -663,7 +663,7 @@ def ready_sample(img_cropped, img_original, ex_mask, inpaint_mask, keypts, keypt
   img = cv2.resize(img_cropped["background"][..., :3], opts.image_size, interpolation=cv2.INTER_AREA)
  else:
   img = cv2.resize(img_original[..., :3], opts.image_size, interpolation=cv2.INTER_AREA)
- sam_predictor.set_image(img)
+ sam_predictor.to("cuda").set_image(img)
  if keypts is None and keypts_np is not None:
   keypts = keypts_np
  else:
@@ -724,7 +724,7 @@ def ready_sample(img_cropped, img_original, ex_mask, inpaint_mask, keypts, keypt
  img,
  keypts,
  hand_mask,
- device="cuda",
+ device,
  target_size=(256, 256),
  latent_size=(32, 32),
  ):
@@ -760,11 +760,11 @@ def ready_sample(img_cropped, img_original, ex_mask, inpaint_mask, keypts, keypt
  img,
  keypts,
  hand_mask * (1 - inpaint_mask),
- # device=pre_device,
+ device=pre_device,
  target_size=opts.image_size,
  latent_size=opts.latent_size,
  )
- latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()
+ latent = opts.latent_scaling_factor * autoencoder.encode(image.cuda()).sample()
  target_cond = torch.cat([heatmaps, torch.zeros_like(mask)], 1)
  ref_cond = torch.cat([latent, heatmaps, mask], 1)
  ref_cond = torch.zeros_like(ref_cond)
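
For context, all three edits apply the same device-placement pattern: move the module onto the GPU before it consumes data, thread the target device through as an argument instead of hard-coding "cuda" as a default, and move the input tensor onto that device before encoding. The sketch below is a minimal illustration of that pattern with generic torch modules only; DummyEncoder, prepare_latent, and the device handling are hypothetical stand-ins, not the actual SAM predictor or autoencoder from app.py.

import torch
import torch.nn as nn

# Hypothetical stand-in for the autoencoder in app.py; the real model and
# its .encode(...).sample() API are not reproduced here.
class DummyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)

    def encode(self, x):
        return self.conv(x)

def prepare_latent(image, encoder, device, scaling_factor=0.18215):
    # Pattern from the commit: accept the device explicitly rather than
    # defaulting to "cuda", then place both the module weights and the
    # input tensor on that device before encoding.
    encoder = encoder.to(device)
    image = image.to(device)
    return scaling_factor * encoder.encode(image)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    img = torch.rand(1, 3, 256, 256)  # stand-in for the cropped hand image
    latent = prepare_latent(img, DummyEncoder(), device)
    print(latent.shape, latent.device)

Passing the device through rather than baking "cuda" into defaults keeps the same code path usable on CPU-only hardware while still fixing the mismatch the commit targets on GPU Spaces.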