File size: 3,749 Bytes
1312174
 
 
12eb9d7
1312174
12eb9d7
 
 
1312174
12eb9d7
 
 
 
 
 
 
 
 
1312174
12eb9d7
 
 
 
 
 
 
 
 
 
ef61ae7
6257f5f
 
14a9987
6257f5f
94ede5e
ef61ae7
 
6257f5f
14a9987
 
 
94ede5e
12eb9d7
 
 
 
1a439bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6257f5f
1a439bd
 
 
 
 
12eb9d7
 
 
 
 
 
 
1312174
12eb9d7
 
1312174
12eb9d7
 
 
1312174
12eb9d7
 
1312174
12eb9d7
 
1312174
12eb9d7
 
 
 
 
 
 
 
1312174
 
12eb9d7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import cv2
import numpy as np
import gradio as gr
from PIL import Image
from torchvision import transforms
from skimage.restoration import denoise_tv_chambolle
from transformers import SamModel, SamProcessor

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# SAM weights and the matching pre/post-processor, loaded once at import time.
model = SamModel.from_pretrained("facebook/sam-vit-huge")
model = model.to(DEVICE)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

def segment_dress(image):
    """Segment the object under the image centre using SAM.

    A single point prompt is placed at the centre of ``image`` and the
    candidate mask that SAM itself scores highest is returned.

    Args:
        image: Image object exposing ``size`` as ``(width, height)``
            (the caller passes a PIL image).

    Returns:
        A 2-D NumPy mask for the prompted object, or ``None`` when the
        post-processing step yields no masks.
    """
    width, height = image.size[0], image.size[1]
    # Nesting is [image][point][x, y]: one image, one prompt point.
    input_points = [[[width // 2, height // 2]]]
    inputs = processor(image, input_points=input_points, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
    )
    if not masks:
        return None

    # masks[0][0] holds every candidate mask for the single prompt, stacked
    # channels-first. Returning the whole stack made downstream code read the
    # candidate count as the image height, so pick exactly one 2-D mask here.
    candidates = masks[0][0]
    if candidates.ndim == 2:
        return candidates.numpy()

    # NOTE(review): iou_scores is expected to carry one score per candidate
    # for this single image/prompt; fall back to the first candidate if the
    # counts do not line up.
    scores = outputs.iou_scores.detach().cpu().reshape(-1)
    best = int(scores.argmax()) if scores.numel() == candidates.shape[0] else 0
    return candidates[best].numpy()

def warp_design(design, mask, warp_scale):
    """Fit the design to the mask's size and keep only the masked pixels.

    The design is resized to the mask's height/width and every pixel outside
    the mask is zeroed. No geometric warping is performed, and ``warp_scale``
    is accepted but not used by this implementation.

    Args:
        design: Design image array accepted by ``cv2.resize``.
        mask: Array whose first two dimensions are (height, width); it is
            multiplied by 255 and cast to uint8 before use.
        warp_scale: Unused.

    Returns:
        The resized design with pixels outside the mask set to zero.
    """
    h, w = mask.shape[:2]
    target_size = (w, h)  # OpenCV takes sizes as (width, height)
    design_resized = cv2.resize(design, target_size)

    # Bring the mask into the 0-255 uint8 range OpenCV expects for masks.
    scaled_mask = (mask * 255).astype(np.uint8)

    # A three-channel mask is collapsed to a single grey channel.
    is_three_channel = scaled_mask.ndim == 3 and scaled_mask.shape[2] == 3
    if is_three_channel:
        scaled_mask = cv2.cvtColor(scaled_mask, cv2.COLOR_BGR2GRAY)

    # Anything still not (h, w) is resized without interpolating mask values.
    needs_resize = scaled_mask.shape != (h, w)
    if needs_resize:
        scaled_mask = cv2.resize(scaled_mask, target_size, interpolation=cv2.INTER_NEAREST)

    print(f"Design Resized Shape: {design_resized.shape}, Mask Shape: {scaled_mask.shape}")
    masked_design = cv2.bitwise_and(design_resized, design_resized, mask=scaled_mask)
    return masked_design

def blend_images(base, overlay, mask):
    """Blend the overlay into the base image with OpenCV seamless cloning.

    Args:
        base: Base image (array-like); grayscale input is expanded to three
            channels.
        overlay: Image to blend in (array-like), expected to be the same size
            as ``base``; grayscale input is expanded to three channels.
        mask: Region of ``overlay`` to clone. Non-uint8 masks are scaled by
            255 and cast to uint8; a 3-D mask is reduced to one channel.

    Returns:
        The blended image as a NumPy array. When the mask selects no pixels
        the base image is returned unchanged.

    Raises:
        ValueError: If ``overlay`` or ``mask`` is ``None``.
    """
    if overlay is None or mask is None:
        raise ValueError("Overlay or mask is None, check segmentation and warping.")

    base = np.array(base)  # Ensure base is a NumPy array
    overlay = np.array(overlay)
    mask = np.asarray(mask)

    # Ensure overlay and base are 3-channel images
    if overlay.ndim == 2:
        overlay = cv2.cvtColor(overlay, cv2.COLOR_GRAY2BGR)
    if base.ndim == 2:
        base = cv2.cvtColor(base, cv2.COLOR_GRAY2BGR)

    # Convert to uint8 first: cvtColor cannot handle bool masks, so the
    # dtype conversion has to happen before the channel reduction.
    if mask.dtype != np.uint8:
        mask = (mask * 255).astype(np.uint8)

    # Ensure mask is single-channel grayscale
    if mask.ndim == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

    # Bounding box of the selected region, from the occupied columns/rows.
    cols = np.flatnonzero(mask.any(axis=0))
    rows = np.flatnonzero(mask.any(axis=1))
    if cols.size == 0 or rows.size == 0:
        # Empty mask: there is nothing to clone.
        return base

    left, right = int(cols[0]), int(cols[-1])
    top, bottom = int(rows[0]), int(rows[-1])

    # seamlessClone takes the target point as (x, y) and places the centre of
    # the mask's bounding box there. Using the box's own centre keeps the
    # overlay where it already sits instead of shifting it to the image
    # centre. Plain ints avoid passing NumPy scalars to OpenCV.
    center = (
        left + (right - left + 1) // 2,
        top + (bottom - top + 1) // 2,
    )

    return cv2.seamlessClone(overlay, base, mask, center, cv2.NORMAL_CLONE)

def apply_design(image_path, design_path, warp_scale):
    """Run the segment -> fit -> blend pipeline for one dress/design pair.

    Args:
        image_path: Path of the dress photo.
        design_path: Path of the design/pattern image.
        warp_scale: Forwarded unchanged to ``warp_design``.

    Returns:
        A PIL image with the design blended onto the segmented region, or the
        string ``"Segmentation Failed!"`` when no mask could be produced.

    Raises:
        ValueError: If the design image cannot be read from ``design_path``.
    """
    # Close the file handle promptly; convert() returns an independent image.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    design = cv2.imread(design_path)
    if design is None:
        # cv2.imread reports unreadable files by returning None, not raising.
        raise ValueError(f"Could not read design image: {design_path}")
    # cv2.imread yields BGR while the base image is RGB; align the channel
    # order so the design's colours are not swapped in the result.
    design = cv2.cvtColor(design, cv2.COLOR_BGR2RGB)

    mask = segment_dress(image)
    if mask is None:
        return "Segmentation Failed!"

    warped_design = warp_design(design, mask, warp_scale)
    blended = blend_images(np.array(image), warped_design, mask)
    return Image.fromarray(blended)

def main(image, design, warp_scale):
    """Gradio callback: apply the design to the dress image.

    Args:
        image: File path of the uploaded dress image, or ``None`` when
            nothing was uploaded.
        design: File path of the uploaded design image, or ``None`` when
            nothing was uploaded.
        warp_scale: Slider value forwarded to ``apply_design``.

    Returns:
        The blended image produced by ``apply_design``.

    Raises:
        gr.Error: If an input is missing or ``apply_design`` reports a
            failure message.
    """
    if image is None or design is None:
        raise gr.Error("Please upload both a dress image and a design image.")

    result = apply_design(image, design, warp_scale)
    if isinstance(result, str):
        # apply_design signals failure with a message string. The output is
        # an image component, which would treat a str as a file path, so
        # surface the message as a UI error instead.
        raise gr.Error(result)
    return result

# Gradio UI
# Input widgets, in the order `main` receives them.
_dress_input = gr.Image(type="filepath", label="Upload Dress Image")
_design_input = gr.Image(type="filepath", label="Upload Design Image")
_warp_slider = gr.Slider(0, 100, value=50, label="Warp Scale (%)")

demo = gr.Interface(
    fn=main,
    inputs=[_dress_input, _design_input, _warp_slider],
    outputs=gr.Image(label="Warped Design on Dress"),
    title="AI-Powered Dress Designer",
    description="Upload a dress image and a design pattern. The AI will warp and blend the design onto the dress while preserving natural folds!",
)

# Start the web server as soon as the module is executed.
demo.launch()