import torch
import cv2
import numpy as np
import gradio as gr
from PIL import Image
from transformers import SamModel, SamProcessor

# Load the SAM (Segment Anything) model and processor once at startup.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(DEVICE)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")


def segment_dress(image):
    """Segment the dress from an input image using SAM, prompted with a
    single point at the image center."""
    input_points = [[[image.size[0] // 2, image.size[1] // 2]]]
    inputs = processor(image, input_points=input_points, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        outputs = model(**inputs)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    if masks:
        # SAM predicts several candidate masks per point prompt; keep the
        # one with the highest predicted IoU score.
        best = outputs.iou_scores[0, 0].argmax().item()
        mask = masks[0][0][best].numpy()
        # Convert boolean mask to uint8 (0-255)
        return (mask * 255).astype(np.uint8)
    return None


def warp_design(design, mask, warp_scale):
    """Resize the design to the mask and cut it out with the mask.

    Note: `warp_scale` is currently unused here; no TPS warping is
    performed yet. See the TPS sketch below for one way to add it.
    """
    h, w = mask.shape[:2]

    # Resize design to match mask dimensions
    design_resized = cv2.resize(design, (w, h))

    # Convert mask to single-channel grayscale if it isn't already
    if len(mask.shape) == 3 and mask.shape[2] == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

    # Ensure mask is uint8 (values between 0-255)
    if mask.dtype != np.uint8:
        mask = (mask * 255).astype(np.uint8)

    # Ensure mask and design are the same size
    mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)

    # Ensure design_resized is 3-channel
    if len(design_resized.shape) == 2:
        design_resized = cv2.cvtColor(design_resized, cv2.COLOR_GRAY2BGR)

    # Debugging output
    print(f"Design shape: {design_resized.shape}, Mask shape: {mask.shape}, Mask dtype: {mask.dtype}")

    return cv2.bitwise_and(design_resized, design_resized, mask=mask)


def blend_images(base, overlay, mask):
    """Blend the design onto the dress using OpenCV seamless cloning."""
    if overlay is None or mask is None:
        raise ValueError("Overlay or mask is None; check segmentation and warping.")

    base = np.array(base)  # Ensure base is a NumPy array
    overlay = np.array(overlay)

    # Ensure overlay and base are 3-channel images
    if len(overlay.shape) == 2:
        overlay = cv2.cvtColor(overlay, cv2.COLOR_GRAY2BGR)
    if len(base.shape) == 2:
        base = cv2.cvtColor(base, cv2.COLOR_GRAY2BGR)

    # Convert a boolean mask to uint8 (0-255)
    if mask.dtype == np.bool_:
        mask = mask.astype(np.uint8) * 255

    # Ensure mask is single-channel grayscale
    if len(mask.shape) == 3 and mask.shape[2] == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

    # Convert mask to uint8 if necessary
    if mask.dtype != np.uint8:
        mask = (mask * 255).astype(np.uint8)

    # Clone at the centroid of the mask (not the image center), so the
    # design lands on the segmented dress region.
    moments = cv2.moments(mask, binaryImage=True)
    if moments["m00"] > 0:
        center = (int(moments["m10"] / moments["m00"]), int(moments["m01"] / moments["m00"]))
    else:
        center = (base.shape[1] // 2, base.shape[0] // 2)

    return cv2.seamlessClone(overlay, base, mask, center, cv2.NORMAL_CLONE)


def apply_design(image_path, design_path, warp_scale):
    """Pipeline: segment the dress, mask/warp the design, and blend them."""
    image = Image.open(image_path).convert("RGB")
    # cv2.imread returns BGR; convert to RGB to match the PIL base image,
    # otherwise the blended colors come out channel-swapped.
    design = cv2.cvtColor(cv2.imread(design_path), cv2.COLOR_BGR2RGB)

    mask = segment_dress(image)
    if mask is None:
        # gr.Error surfaces the message in the Gradio UI instead of feeding
        # a string into an Image output component.
        raise gr.Error("Segmentation failed!")

    warped_design = warp_design(design, mask, warp_scale)
    blended = blend_images(np.array(image), warped_design, mask)
    return Image.fromarray(blended)
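
# The "Warp Scale" slider suggests thin-plate-spline (TPS) warping, but
# warp_design above only masks the design. The function below is a minimal
# sketch of how TPS warping could be added, assuming opencv-contrib-python
# is installed (cv2.createThinPlateSplineShapeTransformer lives in the
# contrib shape module). The 3x3 control-point grid and random jitter are
# illustrative assumptions, not part of the original pipeline.
def tps_warp_sketch(design, warp_scale, seed=0):
    """Sketch: jitter a grid of control points by up to `warp_scale` percent
    of a grid cell, then warp the design with a thin-plate spline."""
    h, w = design.shape[:2]
    # Regular 3x3 grid of source control points covering the design.
    xs = np.linspace(0, w - 1, 3, dtype=np.float32)
    ys = np.linspace(0, h - 1, 3, dtype=np.float32)
    src = np.array([[x, y] for y in ys for x in xs], dtype=np.float32)

    # Displace each control point by up to warp_scale% of a grid cell.
    rng = np.random.default_rng(seed)
    max_shift = (warp_scale / 100.0) * min(w, h) / 6.0
    dst = src + rng.uniform(-max_shift, max_shift, src.shape).astype(np.float32)

    tps = cv2.createThinPlateSplineShapeTransformer()
    matches = [cv2.DMatch(i, i, 0) for i in range(len(src))]
    # Target points go first: warpImage applies the inverse mapping, so the
    # transformation is estimated from destination back to source.
    tps.estimateTransformation(dst.reshape(1, -1, 2), src.reshape(1, -1, 2), matches)
    return tps.warpImage(design)
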
def main(image, design, warp_scale):
    return apply_design(image, design, warp_scale)


# Gradio UI
demo = gr.Interface(
    fn=main,
    inputs=[
        gr.Image(type="filepath", label="Upload Dress Image"),
        gr.Image(type="filepath", label="Upload Design Image"),
        gr.Slider(0, 100, value=50, label="Warp Scale (%)"),
    ],
    outputs=gr.Image(label="Warped Design on Dress"),
    title="AI-Powered Dress Designer",
    description=(
        "Upload a dress image and a design pattern. The app segments the "
        "dress with SAM, then masks and blends the design onto it."
    ),
)

demo.launch()
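
# Example: programmatic use without the Gradio UI. The file names below are
# placeholders, not files that ship with this script.
#
#   result = apply_design("dress.jpg", "pattern.png", warp_scale=50)
#   result.save("output.png")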