import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

# Load SAM model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(DEVICE)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
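# Note: sam-vit-huge is a multi-GB checkpoint; "facebook/sam-vit-base" is a
# drop-in, much smaller alternative if you are running on CPU.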

def segment_dress(image):
    """Segments the dress from an input image using SAM, prompted with a single
    point at the image center (assumes the dress is roughly centered)."""
    input_points = [[[image.size[0] // 2, image.size[1] // 2]]]
    inputs = processor(image, input_points=input_points, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
    )
    if not masks:
        return None
    # SAM returns three mask proposals per point; keep the highest-scoring one
    best = outputs.iou_scores[0, 0].argmax().item()
    return masks[0][0][best].numpy()

def warp_design(design, mask, warp_scale):
    """Masks the design to the segmented dress region, gated by warp_scale."""
    h, w = mask.shape[:2]
    design_resized = cv2.resize(design, (w, h))

    # Scale the binary mask to 0-255; cv2.bitwise_and treats any nonzero pixel
    # as opaque, so warp_scale effectively switches the mask on (>0) or off (0)
    scaled_mask = (mask * 255 * (warp_scale / 100)).astype(np.uint8)

    # Ensure the mask is single-channel and same size as design
    if len(scaled_mask.shape) == 3:
        scaled_mask = cv2.cvtColor(scaled_mask, cv2.COLOR_BGR2GRAY)

    # Resize the mask to match design_resized if needed
    if scaled_mask.shape != (h, w):
        scaled_mask = cv2.resize(scaled_mask, (w, h), interpolation=cv2.INTER_NEAREST)

    return cv2.bitwise_and(design_resized, design_resized, mask=scaled_mask)
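# A true thin-plate-spline warp that follows the garment's folds could be added
# here via cv2.createThinPlateSplineShapeTransformer() (opencv-contrib-python),
# estimating the transform from point pairs sampled along the mask contour.
# The masking above is the lightweight stand-in this demo actually uses.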

def blend_images(base, overlay, mask):
    """Blends the design onto the dress using seamless cloning."""
    h, w = base.shape[:2]
    # seamlessClone expects an (x, y) center point and an 8-bit single-channel mask
    blend_mask = (mask * 255).astype(np.uint8)
    center = (w // 2, h // 2)
    return cv2.seamlessClone(overlay, base, blend_mask, center, cv2.NORMAL_CLONE)
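# cv2.MIXED_CLONE is a possible alternative to NORMAL_CLONE: it mixes gradients
# from both images, which can let the dress's fold shadows show through the design.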

def apply_design(image_path, design_path, warp_scale):
    """Pipeline to segment, mask, and blend the design onto the dress."""
    image = Image.open(image_path).convert("RGB")
    # cv2.imread returns BGR; convert to RGB to match the PIL base image
    design = cv2.cvtColor(cv2.imread(design_path), cv2.COLOR_BGR2RGB)
    mask = segment_dress(image)

    if mask is None:
        # Raising gr.Error surfaces the message in the UI instead of
        # returning a string to an Image output, which would fail
        raise gr.Error("Segmentation failed!")

    warped_design = warp_design(design, mask, warp_scale)
    blended = blend_images(np.array(image), warped_design, mask)
    return Image.fromarray(blended)

def main(image, design, warp_scale):
    return apply_design(image, design, warp_scale)

# Gradio UI
demo = gr.Interface(
    fn=main,
    inputs=[
        gr.Image(type="filepath", label="Upload Dress Image"),
        gr.Image(type="filepath", label="Upload Design Image"),
        gr.Slider(0, 100, value=50, label="Warp Scale (%)")
    ],
    outputs=gr.Image(label="Warped Design on Dress"),
    title="AI-Powered Dress Designer",
    description="Upload a dress image and a design pattern. SAM segments the dress and the design is blended onto it with seamless cloning."
)
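
# demo.launch(share=True) would additionally expose a temporary public URL
# (a standard Gradio option, not enabled here).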

demo.launch()