import torch
import cv2
import numpy as np
import gradio as gr
from PIL import Image
from transformers import SamModel, SamProcessor

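# Load SAM (ViT-Huge) and its processor once at startup; the checkpoint is
# large, so the first call may block while it downloads.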
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(DEVICE)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

def segment_dress(image):
    """Segments the dress from an input image using SAM.

    Uses the image center as the prompt point, assuming the dress is roughly
    centered in the photo.
    """
    input_points = [[[image.size[0] // 2, image.size[1] // 2]]]
    inputs = processor(image, input_points=input_points, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        outputs = model(**inputs)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    if masks:
        # masks[0] has shape (num_prompt_points, num_candidate_masks, H, W);
        # keep the candidate with the highest predicted IoU so we return a
        # single 2-D mask rather than a stack of three.
        best = outputs.iou_scores.cpu()[0, 0].argmax().item()
        mask = masks[0][0][best].numpy()
        mask = (mask * 255).astype(np.uint8)
        return mask
    return None

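# Example usage (hypothetical path):
#   mask = segment_dress(Image.open("dress.jpg").convert("RGB"))
#   # -> uint8 array of 0/255, shaped like the input image
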
def warp_design(design, mask, warp_scale):
    """Resize the design to the mask size and clip it to the masked region.

    Note: despite the slider in the UI, `warp_scale` is not used yet; no
    geometric warp (e.g. TPS) is applied. A sketch of one possible warp is
    given after this function.
    """
    h, w = mask.shape[:2]
    design_resized = cv2.resize(design, (w, h))

    # Normalize the mask to single-channel uint8, as cv2.bitwise_and requires.
    if len(mask.shape) == 3 and mask.shape[2] == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    if mask.dtype != np.uint8:
        mask = (mask * 255).astype(np.uint8)

    # Ensure the design has three channels so the masked copy keeps its color.
    if len(design_resized.shape) == 2:
        design_resized = cv2.cvtColor(design_resized, cv2.COLOR_GRAY2BGR)

    # Zero out everything outside the segmented dress region.
    return cv2.bitwise_and(design_resized, design_resized, mask=mask)

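# A minimal sketch (hypothetical helper, not wired into the pipeline) of how
# `warp_scale` could drive a fold-like distortion via a sinusoidal
# displacement field; the 30 px period and 10 px maximum amplitude are
# arbitrary choices. A thin-plate-spline transform would be a more faithful
# reading of the "TPS" the UI promises.
def _fold_warp(design, warp_scale):
    h, w = design.shape[:2]
    amp = (warp_scale / 100.0) * 10.0  # max displacement in pixels
    xs, ys = np.meshgrid(np.arange(w, dtype=np.float32),
                         np.arange(h, dtype=np.float32))
    map_x = xs + amp * np.sin(ys / 30.0)
    map_y = ys + amp * np.cos(xs / 30.0)
    return cv2.remap(design, map_x, map_y, interpolation=cv2.INTER_LINEAR,
                     borderMode=cv2.BORDER_REFLECT)
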
def blend_images(base, overlay, mask):
    """Blends the design onto the dress using seamless cloning."""
    if overlay is None or mask is None:
        raise ValueError("Overlay or mask is None; check segmentation and warping.")

    base = np.array(base)
    overlay = np.array(overlay)

    # cv2.seamlessClone expects 8-bit, 3-channel source and destination...
    if len(overlay.shape) == 2:
        overlay = cv2.cvtColor(overlay, cv2.COLOR_GRAY2BGR)
    if len(base.shape) == 2:
        base = cv2.cvtColor(base, cv2.COLOR_GRAY2BGR)

    # ...and a single-channel uint8 mask.
    if mask.dtype == np.bool_:
        mask = mask.astype(np.uint8) * 255
    if len(mask.shape) == 3 and mask.shape[2] == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    if mask.dtype != np.uint8:
        mask = (mask * 255).astype(np.uint8)

    # Clone around the image center (where the segmentation was prompted).
    center = (base.shape[1] // 2, base.shape[0] // 2)
    return cv2.seamlessClone(overlay, base, mask, center, cv2.NORMAL_CLONE)

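# A minimal alpha-compositing fallback (hypothetical helper, not used by the
# pipeline), for cases where seamless cloning shifts colors too aggressively:
def _alpha_blend(base, overlay, mask, alpha=0.8):
    m = (mask.astype(np.float32) / 255.0)[..., None]  # (H, W, 1) in [0, 1]
    out = base * (1.0 - alpha * m) + overlay * (alpha * m)
    return out.astype(np.uint8)
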
def apply_design(image_path, design_path, warp_scale):
    """Pipeline to segment, warp, and blend design onto dress."""
    image = Image.open(image_path).convert("RGB")
    design = cv2.imread(design_path)
    if design is None:
        raise gr.Error("Could not read the design image.")
    # cv2.imread returns BGR; convert so both images share RGB channel order.
    design = cv2.cvtColor(design, cv2.COLOR_BGR2RGB)

    mask = segment_dress(image)
    if mask is None:
        raise gr.Error("Segmentation failed!")

    warped_design = warp_design(design, mask, warp_scale)
    blended = blend_images(np.array(image), warped_design, mask)
    return Image.fromarray(blended)

def main(image, design, warp_scale):
    return apply_design(image, design, warp_scale)

demo = gr.Interface(
    fn=main,
    inputs=[
        gr.Image(type="filepath", label="Upload Dress Image"),
        gr.Image(type="filepath", label="Upload Design Image"),
        gr.Slider(0, 100, value=50, label="Warp Scale (%)"),
    ],
    outputs=gr.Image(label="Warped Design on Dress"),
    title="AI-Powered Dress Designer",
    description="Upload a dress image and a design pattern. The app segments the dress with SAM and blends the design onto it with seamless cloning.",
)
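
# Pass share=True to demo.launch() to expose a temporary public URL.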
demo.launch()