samarth-ht committed on
Commit
7143bfc
·
1 Parent(s): 10ac76e

bug fixing

Browse files
scripts/inference.py CHANGED
@@ -86,6 +86,7 @@ def main(config, args):
86
  height=config.data.resolution,
87
  mask_path=args.mask_path,
88
  )
 
89
 
90
 
91
  if __name__ == "__main__":
 
86
  height=config.data.resolution,
87
  mask_path=args.mask_path,
88
  )
89
+ print("Inference completed successfully.", args.mask_path)
90
 
91
 
92
  if __name__ == "__main__":
soundimage/pipelines/lipsync_pipeline.py CHANGED
@@ -318,6 +318,7 @@ class LipsyncPipeline(DiffusionPipeline):
318
  # 0. Define call parameters
319
  batch_size = 1
320
  device = self._execution_device
 
321
  self.image_processor = ImageProcessor(height, mask=mask, device="cuda", mask_path=mask_path)
322
  self.set_progress_bar_config(desc=f"Sample frames: {num_frames}")
323
 
 
318
  # 0. Define call parameters
319
  batch_size = 1
320
  device = self._execution_device
321
+ print(f"Loading fixed mask from {mask_path}")
322
  self.image_processor = ImageProcessor(height, mask=mask, device="cuda", mask_path=mask_path)
323
  self.set_progress_bar_config(desc=f"Sample frames: {num_frames}")
324
 
soundimage/utils/image_processor.py CHANGED
@@ -28,12 +28,7 @@ https://stackoverflow.com/questions/23853632/which-kind-of-interpolation-best-fo
28
  """
29
 
30
 
31
- def load_fixed_mask(resolution: int, mask_path: str) -> torch.Tensor:
32
- mask_image = cv2.imread(mask_path)
33
- mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
34
- mask_image = cv2.resize(mask_image, (resolution, resolution), interpolation=cv2.INTER_AREA) / 255.0
35
- mask_image = rearrange(torch.from_numpy(mask_image), "h w c -> c h w")
36
- return mask_image
37
 
38
 
39
  class ImageProcessor:
@@ -53,6 +48,7 @@ class ImageProcessor:
53
  self.restorer = AlignRestore()
54
 
55
  if mask_image is None:
 
56
  self.mask_image = self.load_fixed_mask(resolution, mask_path)
57
  else:
58
  self.mask_image = mask_image
@@ -66,8 +62,14 @@ class ImageProcessor:
66
  # self.face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True) # Process single image
67
  self.face_mesh = None
68
  self.fa = None
69
-
70
-
 
 
 
 
 
 
71
 
72
  def detect_facial_landmarks(self, image: np.ndarray):
73
  height, width, _ = image.shape
 
28
  """
29
 
30
 
31
+
 
 
 
 
 
32
 
33
 
34
  class ImageProcessor:
 
48
  self.restorer = AlignRestore()
49
 
50
  if mask_image is None:
51
+ print(f"Loading fixed mask from {mask_path}")
52
  self.mask_image = self.load_fixed_mask(resolution, mask_path)
53
  else:
54
  self.mask_image = mask_image
 
62
  # self.face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True) # Process single image
63
  self.face_mesh = None
64
  self.fa = None
65
+
66
+ def load_fixed_mask(resolution: int, mask_path: str) -> torch.Tensor:
67
+ print(f"Loading fixed mask from {mask_path}")
68
+ mask_image = cv2.imread(mask_path)
69
+ mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
70
+ mask_image = cv2.resize(mask_image, (resolution, resolution), interpolation=cv2.INTER_AREA) / 255.0
71
+ mask_image = rearrange(torch.from_numpy(mask_image), "h w c -> c h w")
72
+ return mask_image
73
 
74
  def detect_facial_landmarks(self, image: np.ndarray):
75
  height, width, _ = image.shape