Update app.py
app.py CHANGED
@@ -16,26 +16,37 @@ def segment_dress(image):
     """Segments the dress from an input image using SAM."""
     input_points = [[[image.size[0] // 2, image.size[1] // 2]]]
     inputs = processor(image, input_points=input_points, return_tensors="pt").to(DEVICE)
+
     with torch.no_grad():
         outputs = model(**inputs)
+
     masks = processor.image_processor.post_process_masks(
-        outputs.pred_masks.cpu(),
+        outputs.pred_masks.cpu(),
+        inputs["original_sizes"].cpu(),
+        inputs["reshaped_input_sizes"].cpu()
     )
-
+
+    if masks:
+        mask = masks[0][0].numpy()
+        # Convert boolean mask to uint8 (0-255)
+        mask = (mask * 255).astype(np.uint8)
+        return mask
+    return None

 def warp_design(design, mask, warp_scale):
     """Warp the design using TPS and scale control."""
     h, w = mask.shape[:2]
-
+
     # Resize design to match mask dimensions
     design_resized = cv2.resize(design, (w, h))

-    # Convert boolean mask to uint8 (0-255)
-    mask = (mask * 255).astype(np.uint8)
-
     # Ensure mask is single-channel grayscale
     if len(mask.shape) == 3:
         mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
+
+    # Ensure mask is uint8
+    if mask.dtype != np.uint8:
+        mask = (mask * 255).astype(np.uint8)

     # Convert design_resized to 3-channel if it's grayscale
     if len(design_resized.shape) == 2:
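Note: processor, model, and DEVICE used by segment_dress are defined earlier in app.py, outside this diff. A minimal sketch of a compatible setup with the Hugging Face Transformers SAM classes follows; the facebook/sam-vit-base checkpoint name is an assumption, not something this commit specifies.

import torch
from transformers import SamModel, SamProcessor

# Assumed module-level setup (not part of this commit); any SAM checkpoint
# compatible with SamProcessor/SamModel would work here.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base").to(DEVICE)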
@@ -61,7 +72,6 @@ def blend_images(base, overlay, mask):
     if mask.dtype == np.bool_:
         mask = mask.astype(np.uint8) * 255

-    # Ensure mask is single-channel grayscale
     # Ensure mask is single-channel grayscale
     if len(mask.shape) == 3 and mask.shape[2] == 3:
         mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
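Note: after this change, warp_design and blend_images repeat the same mask clean-up (boolean to uint8, BGR to grayscale, non-uint8 to 0-255). A hypothetical helper that consolidates those checks is sketched below; the name normalize_mask does not appear in this commit.

import cv2
import numpy as np

def normalize_mask(mask):
    """Hypothetical helper: return the mask as single-channel uint8 in 0-255."""
    if mask.dtype == np.bool_:
        mask = mask.astype(np.uint8) * 255
    if len(mask.shape) == 3 and mask.shape[2] == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    if mask.dtype != np.uint8:
        mask = (mask * 255).astype(np.uint8)
    return mask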
@@ -71,7 +81,7 @@ def blend_images(base, overlay, mask):
         mask = (mask * 255).astype(np.uint8)

     # Compute center of the mask for seamless cloning
-    center =
+    center = (base.shape[1] // 2, base.shape[0] // 2)

     return cv2.seamlessClone(overlay, base, mask, center, cv2.NORMAL_CLONE)

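Note: cv2.seamlessClone expects overlay and base to be 8-bit images of the same size, an 8-bit mask, and center as an (x, y) point in base coordinates; this commit fills in the image center. A hypothetical alternative, not part of the commit, is to center the clone on the mask itself:

import numpy as np

def mask_center(mask):
    """Hypothetical helper: (x, y) centroid of the nonzero mask pixels, or None if the mask is empty."""
    ys, xs = np.nonzero(mask)
    if xs.size == 0:
        return None
    return (int(xs.mean()), int(ys.mean()))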
@@ -79,13 +89,14 @@ def apply_design(image_path, design_path, warp_scale):
     """Pipeline to segment, warp, and blend design onto dress."""
     image = Image.open(image_path).convert("RGB")
     design = cv2.imread(design_path)
+
     mask = segment_dress(image)
-
     if mask is None:
         return "Segmentation Failed!"

     warped_design = warp_design(design, mask, warp_scale)
     blended = blend_images(np.array(image), warped_design, mask)
+
     return Image.fromarray(blended)

 def main(image, design, warp_scale):
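Note: main (whose body is outside this diff) presumably wires these functions into the app's interface. For reference, a minimal sketch of exercising the updated pipeline directly; the file names below are assumptions.

# Sketch with assumed file names. apply_design returns a PIL image on success
# and the string "Segmentation Failed!" when no mask is found.
if __name__ == "__main__":
    result = apply_design("dress.jpg", "design.png", warp_scale=0.5)
    if isinstance(result, str):
        print(result)
    else:
        result.save("output.png")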