gaur3009 commited on
Commit
3361c9c
·
verified ·
1 Parent(s): 1f04d42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -9
app.py CHANGED
@@ -16,26 +16,37 @@ def segment_dress(image):
16
  """Segments the dress from an input image using SAM."""
17
  input_points = [[[image.size[0] // 2, image.size[1] // 2]]]
18
  inputs = processor(image, input_points=input_points, return_tensors="pt").to(DEVICE)
 
19
  with torch.no_grad():
20
  outputs = model(**inputs)
 
21
  masks = processor.image_processor.post_process_masks(
22
- outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
 
 
23
  )
24
- return masks[0][0].numpy() if masks else None
 
 
 
 
 
 
25
 
26
  def warp_design(design, mask, warp_scale):
27
  """Warp the design using TPS and scale control."""
28
  h, w = mask.shape[:2]
29
-
30
  # Resize design to match mask dimensions
31
  design_resized = cv2.resize(design, (w, h))
32
 
33
- # Convert boolean mask to uint8 (0-255)
34
- mask = (mask * 255).astype(np.uint8)
35
-
36
  # Ensure mask is single-channel grayscale
37
  if len(mask.shape) == 3:
38
  mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
 
 
 
 
39
 
40
  # Convert design_resized to 3-channel if it's grayscale
41
  if len(design_resized.shape) == 2:
@@ -61,7 +72,6 @@ def blend_images(base, overlay, mask):
61
  if mask.dtype == np.bool_:
62
  mask = mask.astype(np.uint8) * 255
63
 
64
- # Ensure mask is single-channel grayscale
65
  # Ensure mask is single-channel grayscale
66
  if len(mask.shape) == 3 and mask.shape[2] == 3:
67
  mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
@@ -71,7 +81,7 @@ def blend_images(base, overlay, mask):
71
  mask = (mask * 255).astype(np.uint8)
72
 
73
  # Compute center of the mask for seamless cloning
74
- center = tuple(np.array(base.shape[:2]) // 2)
75
 
76
  return cv2.seamlessClone(overlay, base, mask, center, cv2.NORMAL_CLONE)
77
 
@@ -79,13 +89,14 @@ def apply_design(image_path, design_path, warp_scale):
79
  """Pipeline to segment, warp, and blend design onto dress."""
80
  image = Image.open(image_path).convert("RGB")
81
  design = cv2.imread(design_path)
 
82
  mask = segment_dress(image)
83
-
84
  if mask is None:
85
  return "Segmentation Failed!"
86
 
87
  warped_design = warp_design(design, mask, warp_scale)
88
  blended = blend_images(np.array(image), warped_design, mask)
 
89
  return Image.fromarray(blended)
90
 
91
  def main(image, design, warp_scale):
 
16
  """Segments the dress from an input image using SAM."""
17
  input_points = [[[image.size[0] // 2, image.size[1] // 2]]]
18
  inputs = processor(image, input_points=input_points, return_tensors="pt").to(DEVICE)
19
+
20
  with torch.no_grad():
21
  outputs = model(**inputs)
22
+
23
  masks = processor.image_processor.post_process_masks(
24
+ outputs.pred_masks.cpu(),
25
+ inputs["original_sizes"].cpu(),
26
+ inputs["reshaped_input_sizes"].cpu()
27
  )
28
+
29
+ if masks:
30
+ mask = masks[0][0].numpy()
31
+ # Convert boolean mask to uint8 (0-255)
32
+ mask = (mask * 255).astype(np.uint8)
33
+ return mask
34
+ return None
35
 
36
  def warp_design(design, mask, warp_scale):
37
  """Warp the design using TPS and scale control."""
38
  h, w = mask.shape[:2]
39
+
40
  # Resize design to match mask dimensions
41
  design_resized = cv2.resize(design, (w, h))
42
 
 
 
 
43
  # Ensure mask is single-channel grayscale
44
  if len(mask.shape) == 3:
45
  mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
46
+
47
+ # Ensure mask is uint8
48
+ if mask.dtype != np.uint8:
49
+ mask = (mask * 255).astype(np.uint8)
50
 
51
  # Convert design_resized to 3-channel if it's grayscale
52
  if len(design_resized.shape) == 2:
 
72
  if mask.dtype == np.bool_:
73
  mask = mask.astype(np.uint8) * 255
74
 
 
75
  # Ensure mask is single-channel grayscale
76
  if len(mask.shape) == 3 and mask.shape[2] == 3:
77
  mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
 
81
  mask = (mask * 255).astype(np.uint8)
82
 
83
  # Compute center of the mask for seamless cloning
84
+ center = (base.shape[1] // 2, base.shape[0] // 2)
85
 
86
  return cv2.seamlessClone(overlay, base, mask, center, cv2.NORMAL_CLONE)
87
 
 
89
  """Pipeline to segment, warp, and blend design onto dress."""
90
  image = Image.open(image_path).convert("RGB")
91
  design = cv2.imread(design_path)
92
+
93
  mask = segment_dress(image)
 
94
  if mask is None:
95
  return "Segmentation Failed!"
96
 
97
  warped_design = warp_design(design, mask, warp_scale)
98
  blended = blend_images(np.array(image), warped_design, mask)
99
+
100
  return Image.fromarray(blended)
101
 
102
  def main(image, design, warp_scale):