import torch
import cv2
import numpy as np
import gradio as gr
from PIL import Image
from transformers import SamModel, SamProcessor
# Load SAM model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(DEVICE)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
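# ViT-H is the largest SAM checkpoint; if memory is tight, "facebook/sam-vit-base"
# is a lighter drop-in replacement. On CUDA, half precision roughly halves memory
# (assuming fp16-capable hardware):
# model = model.half()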
def segment_dress(image):
    """Segments the dress from an input image using SAM, prompted by the image centre."""
    # Use the centre of the image as a single point prompt.
    input_points = [[[image.size[0] // 2, image.size[1] // 2]]]
    inputs = processor(image, input_points=input_points, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )
    if masks:
        # SAM returns three candidate masks per point prompt; keep the one
        # with the highest predicted IoU score.
        best = outputs.iou_scores[0, 0].argmax().item()
        mask = masks[0][0, best].numpy()
        # Convert boolean mask to uint8 (0-255) for OpenCV
        mask = (mask * 255).astype(np.uint8)
        return mask
    return None
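# Note: the single centre-point prompt assumes the dress dominates the middle of
# the frame. For off-centre garments, passing additional points in `input_points`
# (the processor accepts several per image) usually produces a better mask.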
def warp_design(design, mask, warp_scale):
    """Fit the design to the dress region.

    Currently a straight masked overlay: the design is resized to the mask and
    clipped to the segmented region. TPS warping driven by `warp_scale` is not
    implemented yet, so the parameter is accepted but unused.
    """
    h, w = mask.shape[:2]
    # Resize design to match mask dimensions
    design_resized = cv2.resize(design, (w, h))
    # Collapse a 3-channel mask to grayscale
    if len(mask.shape) == 3 and mask.shape[2] == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    # Ensure mask is uint8 (values between 0-255)
    if mask.dtype != np.uint8:
        mask = (mask * 255).astype(np.uint8)
    # Nearest-neighbour interpolation keeps the mask edges hard
    mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
    # Ensure design_resized is 3-channel
    if len(design_resized.shape) == 2:
        design_resized = cv2.cvtColor(design_resized, cv2.COLOR_GRAY2BGR)
    # Keep only the design pixels that fall inside the dress mask
    return cv2.bitwise_and(design_resized, design_resized, mask=mask)
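# Note: a genuine thin-plate-spline warp could come from opencv-contrib's shape
# module. A rough, untested sketch (control-point names are placeholders):
#     tps = cv2.createThinPlateSplineShapeTransformer()
#     tps.estimateTransformation(source_pts, target_pts, matches)
#     warped = tps.warpImage(design_resized)
# where source_pts/target_pts are (1, N, 2) float32 arrays along the dress
# contour and matches is a list of cv2.DMatch pairing them; warp_scale could
# then scale the control-point displacement.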
def blend_images(base, overlay, mask):
    """Blends the design onto the dress using seamless cloning."""
    if overlay is None or mask is None:
        raise ValueError("Overlay or mask is None; check segmentation and warping.")
    base = np.array(base)  # Ensure base is a NumPy array
    overlay = np.array(overlay)
    # Ensure overlay and base are 3-channel images
    if len(overlay.shape) == 2:
        overlay = cv2.cvtColor(overlay, cv2.COLOR_GRAY2BGR)
    if len(base.shape) == 2:
        base = cv2.cvtColor(base, cv2.COLOR_GRAY2BGR)
    # Convert a boolean mask to uint8 (0-255)
    if mask.dtype == np.bool_:
        mask = mask.astype(np.uint8) * 255
    # seamlessClone expects a single-channel 8-bit mask
    if len(mask.shape) == 3 and mask.shape[2] == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    if mask.dtype != np.uint8:
        mask = (mask * 255).astype(np.uint8)
    # Clone around the centre of the base image (the mask sits there by construction)
    center = (base.shape[1] // 2, base.shape[0] // 2)
    return cv2.seamlessClone(overlay, base, mask, center, cv2.NORMAL_CLONE)
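# Tip: cv2.MIXED_CLONE mixes the gradients of both images instead of copying the
# overlay's, which tends to let the fabric's folds and shading show through the
# pattern; worth trying in place of cv2.NORMAL_CLONE.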
def apply_design(image_path, design_path, warp_scale):
    """Pipeline to segment, fit, and blend the design onto the dress."""
    image = Image.open(image_path).convert("RGB")
    design = cv2.imread(design_path)
    if design is None:
        raise gr.Error("Could not read the design image.")
    # OpenCV loads BGR; convert to RGB so colours match the PIL base image
    design = cv2.cvtColor(design, cv2.COLOR_BGR2RGB)
    mask = segment_dress(image)
    if mask is None:
        raise gr.Error("Segmentation failed; try a different image.")
    warped_design = warp_design(design, mask, warp_scale)
    blended = blend_images(np.array(image), warped_design, mask)
    return Image.fromarray(blended)
def main(image, design, warp_scale):
    return apply_design(image, design, warp_scale)
# Gradio UI
demo = gr.Interface(
    fn=main,
    inputs=[
        gr.Image(type="filepath", label="Upload Dress Image"),
        gr.Image(type="filepath", label="Upload Design Image"),
        gr.Slider(0, 100, value=50, label="Warp Scale (%)")
    ],
    outputs=gr.Image(label="Warped Design on Dress"),
    title="AI-Powered Dress Designer",
    description="Upload a dress image and a design pattern. The AI segments the dress and blends the design onto it!"
)
demo.launch()
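# demo.launch(share=True) would also expose a temporary public URL.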