# new_mmm/app.py
import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
# Load SAM model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(DEVICE)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
def segment_dress(image):
    """Segments the dress from an input image using SAM."""
    # Prompt SAM with a single point at the image centre.
    input_points = [[[image.size[0] // 2, image.size[1] // 2]]]
    inputs = processor(image, input_points=input_points, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )
    if masks:
        # SAM returns three candidate masks per point; keep the one with the
        # highest predicted IoU rather than all three stacked channels.
        scores = outputs.iou_scores[0, 0].cpu()
        mask = masks[0][0][int(scores.argmax())].numpy()
        # Convert boolean mask to uint8 (0-255)
        return (mask * 255).astype(np.uint8)
    return None

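# A minimal sketch (not used by the pipeline below) of an alternative prompt:
# SamProcessor also accepts `input_boxes`, which is often more reliable than a
# single centre point when the dress is off-centre. Any box a caller passes in
# would be an assumption; nothing in this app computes one.
def segment_dress_with_box(image, box):
    """Segment using a bounding-box prompt, e.g. box = [x0, y0, x1, y1]."""
    inputs = processor(image, input_boxes=[[box]], return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )
    if masks:
        scores = outputs.iou_scores[0, 0].cpu()
        mask = masks[0][0][int(scores.argmax())].numpy()
        return (mask * 255).astype(np.uint8)
    return None
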
def warp_design(design, mask, warp_scale):
    """Masks the design to the dress region.

    Note: `warp_scale` is accepted for the UI slider, but no TPS (or other)
    warp is applied yet; the design is simply resized and masked.
    """
    h, w = mask.shape[:2]
    # Resize design to match mask dimensions
    design_resized = cv2.resize(design, (w, h))
    # Ensure design_resized is 3-channel
    if len(design_resized.shape) == 2:
        design_resized = cv2.cvtColor(design_resized, cv2.COLOR_GRAY2BGR)
    # Convert mask to single-channel grayscale if it is not already
    if len(mask.shape) == 3 and mask.shape[2] == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    # Ensure mask is uint8 (values between 0-255)
    if mask.dtype != np.uint8:
        mask = (mask * 255).astype(np.uint8)
    # Debugging output
    print(f"Design shape: {design_resized.shape}, Mask shape: {mask.shape}, Mask dtype: {mask.dtype}")
    # Keep the design only where the mask is set
    return cv2.bitwise_and(design_resized, design_resized, mask=mask)

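# A minimal sketch of what `warp_scale` could drive: a sinusoidal displacement
# field applied with cv2.remap to imitate fabric folds. This is an assumed
# illustration, not the TPS warp the UI alludes to; the 10 px amplitude and
# 60 px wavelength are arbitrary choices, and the pipeline does not call it.
def fold_warp(design, warp_scale):
    """Displace pixels horizontally by a vertical sine wave (warp_scale: 0-100)."""
    h, w = design.shape[:2]
    amplitude = (warp_scale / 100.0) * 10.0  # up to ~10 px of displacement
    ys, xs = np.mgrid[0:h, 0:w].astype(np.float32)
    xs = xs + amplitude * np.sin(2 * np.pi * ys / 60.0)
    return cv2.remap(design, xs, ys, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
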
def blend_images(base, overlay, mask):
    """Blends the design onto the dress using seamless cloning."""
    if overlay is None or mask is None:
        raise ValueError("Overlay or mask is None; check segmentation and warping.")
    base = np.array(base)  # Ensure base is a NumPy array
    overlay = np.array(overlay)
    # Ensure overlay and base are 3-channel images
    if len(overlay.shape) == 2:
        overlay = cv2.cvtColor(overlay, cv2.COLOR_GRAY2BGR)
    if len(base.shape) == 2:
        base = cv2.cvtColor(base, cv2.COLOR_GRAY2BGR)
    # seamlessClone needs a single-channel uint8 (0-255) mask
    if mask.dtype == np.bool_:
        mask = mask.astype(np.uint8) * 255
    if len(mask.shape) == 3 and mask.shape[2] == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    if mask.dtype != np.uint8:
        mask = (mask * 255).astype(np.uint8)
    # Clone around the mask centroid rather than the image centre, so the
    # blended region lines up with the segmented dress.
    ys, xs = np.where(mask > 0)
    if xs.size:
        center = (int(xs.mean()), int(ys.mean()))
    else:
        center = (base.shape[1] // 2, base.shape[0] // 2)
    return cv2.seamlessClone(overlay, base, mask, center, cv2.NORMAL_CLONE)

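# A minimal alternative sketch: plain alpha compositing with a feathered mask.
# cv2.seamlessClone can fail on masks that touch the image border, so a soft
# blend like this is a useful fallback. It is not wired into the pipeline, it
# expects uint8 arrays of matching size, and the 21 px blur kernel is an
# assumption.
def alpha_blend(base, overlay, mask):
    """Composite overlay onto base using the blurred mask as a soft alpha."""
    alpha = cv2.GaussianBlur(mask, (21, 21), 0).astype(np.float32) / 255.0
    alpha = alpha[:, :, None]  # broadcast over the colour channels
    out = base.astype(np.float32) * (1.0 - alpha) + overlay.astype(np.float32) * alpha
    return out.astype(np.uint8)
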
def apply_design(image_path, design_path, warp_scale):
    """Pipeline to segment, warp, and blend the design onto the dress."""
    image = Image.open(image_path).convert("RGB")
    design = cv2.imread(design_path)
    if design is None:
        raise gr.Error("Could not read the design image.")
    # cv2.imread returns BGR; convert so it matches the RGB base image
    design = cv2.cvtColor(design, cv2.COLOR_BGR2RGB)
    mask = segment_dress(image)
    if mask is None:
        raise gr.Error("Segmentation failed!")
    warped_design = warp_design(design, mask, warp_scale)
    blended = blend_images(np.array(image), warped_design, mask)
    return Image.fromarray(blended)

def main(image, design, warp_scale):
    return apply_design(image, design, warp_scale)
# Gradio UI
demo = gr.Interface(
    fn=main,
    inputs=[
        gr.Image(type="filepath", label="Upload Dress Image"),
        gr.Image(type="filepath", label="Upload Design Image"),
        gr.Slider(0, 100, value=50, label="Warp Scale (%)")
    ],
    outputs=gr.Image(label="Warped Design on Dress"),
    title="AI-Powered Dress Designer",
    description="Upload a dress image and a design pattern. The AI will warp and blend the design onto the dress while preserving natural folds!"
)

demo.launch()