import torch
import cv2
import numpy as np
import gradio as gr
from PIL import Image
from transformers import SamModel, SamProcessor

# Load SAM model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(DEVICE)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
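# Inference-only usage, so put the model in eval mode (a cheap, safe default).
model.eval()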

def segment_dress(image):
    """Segments the dress from an input image using SAM."""
    input_points = [[[image.size[0] // 2, image.size[1] // 2]]]
    inputs = processor(image, input_points=input_points, return_tensors="pt").to(DEVICE)
    
    with torch.no_grad():
        outputs = model(**inputs)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), 
        inputs["original_sizes"].cpu(), 
        inputs["reshaped_input_sizes"].cpu()
    )

    if masks:
        # post_process_masks returns three candidate masks per point prompt,
        # so masks[0][0] has shape (3, H, W); keep the one with the highest
        # predicted IoU rather than treating the stack as a single mask.
        best_idx = outputs.iou_scores[0, 0].argmax().item()
        mask = masks[0][0][best_idx].numpy()
        # Convert boolean mask to uint8 (0-255)
        mask = (mask * 255).astype(np.uint8)
        return mask
    return None
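
# Illustrative alternative (assumption: the box coordinates come from
# elsewhere, e.g. a detector or user selection): SAM also accepts
# bounding-box prompts via `input_boxes`, which tends to be more robust than
# a single center point when the dress is off-center. A minimal sketch:
def segment_dress_with_box(image, box):
    """Segment using a box prompt [x0, y0, x1, y1] in pixel coordinates."""
    inputs = processor(image, input_boxes=[[box]], return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )
    if not masks:
        return None
    best_idx = outputs.iou_scores[0, 0].argmax().item()
    return (masks[0][0][best_idx].numpy() * 255).astype(np.uint8)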

def warp_design(design, mask, warp_scale):
    """Warp the design using TPS and scale control."""
    h, w = mask.shape[:2]
    
    # Resize design to match mask dimensions
    design_resized = cv2.resize(design, (w, h))

    # Convert mask to grayscale if it's not already
    if len(mask.shape) == 3 and mask.shape[2] == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

    # Ensure mask is uint8 (values between 0-255)
    if mask.dtype != np.uint8:
        mask = (mask * 255).astype(np.uint8)

    # Ensure design_resized is 3-channel
    if len(design_resized.shape) == 2:
        design_resized = cv2.cvtColor(design_resized, cv2.COLOR_GRAY2BGR)

    # Debugging output
    print(f"Design shape: {design_resized.shape}, Mask shape: {mask.shape}, Mask dtype: {mask.dtype}")

    return cv2.bitwise_and(design_resized, design_resized, mask=mask)
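
# Hypothetical warp sketch: as noted in the docstring, warp_design() does not
# actually warp. One simple way to make warp_scale do something is a
# sinusoidal cv2.remap displacement whose amplitude follows the slider.
# This is an illustrative stand-in for fold-following warping, not true TPS.
def warp_design_sinusoidal(design, mask, warp_scale):
    h, w = mask.shape[:2]
    design_resized = cv2.resize(design, (w, h))
    amp = (warp_scale / 100.0) * 10.0  # max ~10 px of displacement
    xs, ys = np.meshgrid(np.arange(w, dtype=np.float32),
                         np.arange(h, dtype=np.float32))
    map_x = xs + amp * np.sin(ys / 30.0)  # horizontal ripple down the rows
    warped = cv2.remap(design_resized, map_x, ys, cv2.INTER_LINEAR,
                       borderMode=cv2.BORDER_REFLECT)
    return cv2.bitwise_and(warped, warped, mask=mask)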

def blend_images(base, overlay, mask):
    """Blends the design onto the dress using seamless cloning."""
    if overlay is None or mask is None:
        raise ValueError("Overlay or mask is None, check segmentation and warping.")

    base = np.array(base)  # Ensure base is a NumPy array
    overlay = np.array(overlay)

    # Ensure overlay and base are 3-channel images
    if len(overlay.shape) == 2:
        overlay = cv2.cvtColor(overlay, cv2.COLOR_GRAY2BGR)
    if len(base.shape) == 2:
        base = cv2.cvtColor(base, cv2.COLOR_GRAY2BGR)

    # Convert mask to uint8 (0-255) if it's a boolean array
    if mask.dtype == np.bool_:
        mask = mask.astype(np.uint8) * 255

    # Ensure mask is single-channel grayscale
    if len(mask.shape) == 3 and mask.shape[2] == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    
    # Convert mask to uint8 if necessary
    if mask.dtype != np.uint8:
        mask = (mask * 255).astype(np.uint8)

    # seamlessClone places the bounding rectangle of the mask centered at
    # `center`, so anchor on the mask region rather than the image center.
    x, y, bw, bh = cv2.boundingRect(mask)
    center = (x + bw // 2, y + bh // 2)

    return cv2.seamlessClone(overlay, base, mask, center, cv2.NORMAL_CLONE)
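
# Hypothetical fallback (assumption: base/overlay are same-size RGB uint8
# arrays): cv2.seamlessClone can fail on empty masks or masks touching the
# image border, so a plain masked alpha blend is a predictable alternative.
def blend_images_alpha(base, overlay, mask, alpha=0.9):
    mask3 = cv2.merge([mask, mask, mask]).astype(np.float32) / 255.0
    out = base.astype(np.float32) * (1.0 - alpha * mask3) \
        + overlay.astype(np.float32) * (alpha * mask3)
    return out.astype(np.uint8)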

def apply_design(image_path, design_path, warp_scale):
    """Pipeline to segment, warp, and blend design onto dress."""
    image = Image.open(image_path).convert("RGB")
    # cv2.imread returns BGR; convert so the design matches the RGB base image.
    design = cv2.cvtColor(cv2.imread(design_path), cv2.COLOR_BGR2RGB)

    mask = segment_dress(image)
    if mask is None:
        # Raise a UI-visible error rather than returning a string to an Image output.
        raise gr.Error("Segmentation failed!")
    
    warped_design = warp_design(design, mask, warp_scale)
    blended = blend_images(np.array(image), warped_design, mask)
    
    return Image.fromarray(blended)

def main(image, design, warp_scale):
    return apply_design(image, design, warp_scale)

# Gradio UI
demo = gr.Interface(
    fn=main,
    inputs=[
        gr.Image(type="filepath", label="Upload Dress Image"),
        gr.Image(type="filepath", label="Upload Design Image"),
        gr.Slider(0, 100, value=50, label="Warp Scale (%)")
    ],
    outputs=gr.Image(label="Warped Design on Dress"),
    title="AI-Powered Dress Designer",
    description="Upload a dress image and a design pattern. The AI will warp and blend the design onto the dress while preserving natural folds!"
)

if __name__ == "__main__":
    demo.launch()