# ClothQuill — AI clothing inpainting (Gradio app for a Hugging Face Space)
import torch | |
from torch import autocast | |
from diffusers import StableDiffusionInpaintPipeline | |
import gradio as gr | |
import traceback | |
import base64 | |
from io import BytesIO | |
import os | |
# import sys | |
import PIL | |
import json | |
import requests | |
import logging | |
import time | |
import warnings | |
import numpy as np | |
from PIL import Image, ImageDraw | |
import cv2 | |
warnings.filterwarnings("ignore") | |
# sys.path.insert(1, './parser') | |
# from parser.schp_masker import * | |
from parser.segformer_parser import SegformerParser | |
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('clothquill')

# Model paths (Hugging Face Hub repository ids)
SEGFORMER_MODEL = "mattmdjaga/segformer_b2_clothes"
STABLE_DIFFUSION_MODEL = "stabilityai/stable-diffusion-2-inpainting"

# Global variables for models, populated once by init()
parser = None  # SegformerParser used for clothing segmentation
model = None  # StableDiffusionInpaintPipeline used for inpainting
inpainter = None  # ClothingInpainter wrapping pipeline + parser
original_image = None  # Store the original uploaded image

# RGBA overlay color for each clothing part (alpha 128 = semi-transparent).
# NOTE(review): keys appear to be parser part names normalized to
# Title-Case-With-Dashes (see visualize_segmentation) — confirm against parser labels.
CLOTHING_COLORS = {
    'Background': (0, 0, 0, 0),  # Transparent
    'Hat': (255, 0, 0, 128),  # Red
    'Hair': (0, 255, 0, 128),  # Green
    'Glove': (0, 0, 255, 128),  # Blue
    'Sunglasses': (255, 255, 0, 128),  # Yellow
    'Upper-clothes': (255, 0, 255, 128),  # Magenta
    'Dress': (0, 255, 255, 128),  # Cyan
    'Coat': (128, 0, 0, 128),  # Dark Red
    'Socks': (0, 128, 0, 128),  # Dark Green
    'Pants': (0, 0, 128, 128),  # Dark Blue
    'Jumpsuits': (128, 128, 0, 128),  # Dark Yellow
    'Scarf': (128, 0, 128, 128),  # Dark Magenta
    'Skirt': (0, 128, 128, 128),  # Dark Cyan
    'Face': (192, 192, 192, 128),  # Light Gray
    'Left-arm': (64, 64, 64, 128),  # Dark Gray
    'Right-arm': (64, 64, 64, 128),  # Dark Gray
    'Left-leg': (32, 32, 32, 128),  # Very Dark Gray
    'Right-leg': (32, 32, 32, 128),  # Very Dark Gray
    'Left-shoe': (16, 16, 16, 128),  # Almost Black
    'Right-shoe': (16, 16, 16, 128),  # Almost Black
}
def get_device():
    """Return the torch device string to use: "cuda" when a GPU is available,
    otherwise "cpu". Logs the choice at INFO level."""
    if torch.cuda.is_available():
        logger.info("Using GPU")
        return "cuda"
    logger.info("Using CPU")
    return "cpu"
def init():
    """Load all models into the module-level globals ``parser``, ``model``
    and ``inpainter``.

    Downloads model weights first (via ``download_models``) when the local
    "models" directory is missing. Re-raises any initialization failure
    after logging it.
    """
    global parser, model, inpainter

    start_time = time.time()
    logger.info("Starting application initialization")
    try:
        device = get_device()

        # First run: fetch model weights before loading anything.
        if not os.path.exists("models"):
            logger.info("Creating models directory...")
            from download_models import download_models
            download_models()

        logger.info("Initializing Segformer parser...")
        parser = SegformerParser(SEGFORMER_MODEL)

        logger.info("Initializing Stable Diffusion model...")
        # fp16 weights only make sense on GPU; CPU stays in float32.
        model = StableDiffusionInpaintPipeline.from_pretrained(
            STABLE_DIFFUSION_MODEL,
            safety_checker=None,
            revision="fp16" if device == "cuda" else None,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        ).to(device)

        logger.info("Initializing inpainter...")
        inpainter = ClothingInpainter(model=model, parser=parser)

        logger.info(f"Application initialized in {time.time() - start_time:.2f} seconds")
    except Exception as e:
        logger.error(f"Error initializing application: {str(e)}")
        raise e
class ClothingInpainter:
    """Clothing inpainting helper combining a Stable Diffusion inpainting
    pipeline with a clothing-segmentation parser.

    Also holds UI state used by the Gradio handlers:
    ``last_mask`` (the most recent per-part masks from the parser) and
    ``original_image`` (the untouched upload, used for visualization).
    """

    # Parts inpainted by default when the caller selects none.
    DEFAULT_PARTS = ['upper-clothes', 'dress', 'coat', 'pants', 'skirt']

    def __init__(self, model_path=None, model=None, parser=None):
        """Build from either a model path (loaded here) or an already-loaded
        pipeline. Raises ValueError when neither is given.

        Args:
            model_path: Hugging Face repo id / local path of an inpainting
                pipeline to load, or None.
            model: an already-constructed StableDiffusionInpaintPipeline.
            parser: segmentation parser exposing ``get_all_masks(image)``.
        """
        self.device = get_device()
        self.last_mask = None  # Most recent masks from the parser (dict of PIL 'L' images)
        self.original_image = None  # Original uploaded PIL image
        if model_path is None and model is None:
            raise ValueError('No model provided!')
        if model_path is not None:
            self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                model_path,
                safety_checker=None,
                revision="fp16" if self.device == "cuda" else None,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
            ).to(self.device)
        else:
            self.pipe = model
        self.parser = parser

    def make_square(self, im, min_size=256, fill_color=(0, 0, 0, 0)):
        """Pad *im* to a centered square of side max(min_size, w, h) and
        return it converted to RGB."""
        x, y = im.size
        size = max(min_size, x, y)
        new_im = PIL.Image.new('RGBA', (size, size), fill_color)
        new_im.paste(im, (int((size - x) / 2), int((size - y) / 2)))
        return new_im.convert('RGB')

    def unmake_square(self, init_im, op_im, min_size=256, rs_size=512):
        """Crop the model output *op_im* (rendered at rs_size x rs_size) back
        to the aspect ratio of the original *init_im*, undoing make_square."""
        x, y = init_im.size
        size = max(min_size, x, y)
        factor = rs_size / size  # scale between padded square and model resolution
        return op_im.crop((int((size - x) * factor / 2), int((size - y) * factor / 2),
                           int((size + x) * factor / 2), int((size + y) * factor / 2)))

    def _dilate(self, mask, dilation_iterations):
        """Dilate a PIL mask with a 5x5 kernel; returns a PIL image."""
        kernel = np.ones((5, 5), np.uint8)
        dilated = cv2.dilate(np.array(mask), kernel, iterations=dilation_iterations)
        return Image.fromarray(dilated)

    def _merge_masks(self, masks, parts, dilation_iterations):
        """OR the dilated masks of *parts* (skipping missing ones) into a
        single 512x512 'L' mask: 255 where any part is present, else 0.
        Assumes the parser masks are binary 0/255 — TODO confirm."""
        combined = Image.new('L', (512, 512), 0)
        white = Image.new('L', (512, 512), 255)
        for part in parts:
            if part in masks:
                combined = Image.composite(
                    white, combined, self._dilate(masks[part], dilation_iterations))
        return combined

    def visualize_segmentation(self, image, masks, selected_parts=None):
        """Visualize segmentation with colored overlays for selected parts and gray for unselected.

        *masks* are expected at 512x512 (the processing resolution); the
        overlay is resized to the original image size before compositing.
        """
        # Always use original image if available
        image_to_use = self.original_image if self.original_image is not None else image
        original_size = image_to_use.size
        vis_image = image_to_use.copy().convert('RGBA')

        # Build the 512x512 RGBA overlay in one numpy pass per mask
        # (the previous per-pixel draw.point loop was O(512^2) in Python).
        overlay_arr = np.zeros((512, 512, 4), dtype=np.uint8)
        for part_name, mask in masks.items():
            # Parser part names are lowercase-with-dashes; CLOTHING_COLORS
            # keys are Title-Case-With-Dashes.
            color_key = part_name.replace('-', ' ').title().replace(' ', '-')
            is_selected = selected_parts and part_name in selected_parts
            if is_selected:
                color = CLOTHING_COLORS.get(color_key, (255, 0, 255, 128))  # fallback: magenta
            else:
                color = (180, 180, 180, 80)  # Faint gray for unselected
            overlay_arr[np.array(mask) > 0] = color
        overlay = Image.fromarray(overlay_arr, mode='RGBA')

        # Resize overlay to match original image size and composite it on top.
        overlay = overlay.resize(original_size, Image.Resampling.LANCZOS)
        return Image.alpha_composite(vis_image, overlay)

    def inpaint(self, prompt, init_image, selected_parts=None, dilation_iterations=2) -> list:
        """Inpaint the selected clothing parts of *init_image*.

        Args:
            prompt: dict with key 'pos' holding the positive text prompt.
            init_image: original PIL image (any size).
            selected_parts: parser part names to modify; falls back to
                DEFAULT_PARTS when empty/None.
            dilation_iterations: 5x5 dilation passes applied to the mask.

        Returns:
            list of PIL images (3 samples), cropped back to the original
            aspect ratio.

        Raises:
            ValueError: when no parser was supplied.
        """
        image = self.make_square(init_image).resize((512, 512))
        if self.parser is None:
            raise ValueError('Image Parser is Missing')
        masks = self.parser.get_all_masks(image)
        masks = {k: v.resize((512, 512)) for k, v in masks.items()}
        logger.info(f'[generated required mask(s) at {time.time()}]')

        # One combined mask for either the chosen parts or the default set.
        parts = selected_parts if selected_parts else self.DEFAULT_PARTS
        combined_mask = self._merge_masks(masks, parts, dilation_iterations)

        # Run the model
        guidance_scale = 7.5
        num_samples = 3
        # BUG FIX: autocast was hard-coded to "cuda"; disable it on CPU so the
        # pipeline also runs on CPU-only hosts.
        with autocast(self.device, enabled=self.device == "cuda"), torch.inference_mode():
            images = self.pipe(
                num_inference_steps=50,
                prompt=prompt['pos'],
                image=image,
                mask_image=combined_mask,
                guidance_scale=guidance_scale,
                num_images_per_prompt=num_samples,
            ).images

        # Paste generated pixels only inside the mask, then undo the padding.
        images_output = []
        for img in images:
            pasted = PIL.Image.composite(img, image, combined_mask)
            images_output.append(self.unmake_square(init_image, pasted))
        return images_output
def process_segmentation(image, dilation_iterations=2):
    """Gradio handler for image upload.

    Segments the uploaded image, stores the raw masks and the original image
    on the global ``inpainter``, and returns the overlay visualization plus
    an update for the part-selection checkboxes.

    Raises gr.Error with a user-facing message on any failure.
    """
    try:
        if image is None:
            raise gr.Error("Please upload an image")
        # Store original image for later visualizations
        inpainter.original_image = image.copy()
        # Create a processing copy at 512x512
        proc_image = image.resize((512, 512), Image.Resampling.LANCZOS)
        all_masks = inpainter.parser.get_all_masks(proc_image)
        if not all_masks:
            logger.error("No clothing detected in the image")
            raise gr.Error("No clothing detected in the image. Please try a different image.")
        # Keep the raw (undilated) masks; dilation is applied per-view below
        # and by the slider/selection handlers.
        inpainter.last_mask = all_masks
        # Only show main clothing parts for selection
        main_parts = ['upper-clothes', 'dress', 'coat', 'pants', 'skirt']
        # BUG FIX: dilation_iterations was accepted but ignored, so the initial
        # overlay never matched the slider value applied by update_dilation.
        kernel = np.ones((5, 5), np.uint8)
        masks = {}
        for part, mask in all_masks.items():
            if part in main_parts:
                dilated = cv2.dilate(np.array(mask), kernel, iterations=dilation_iterations)
                masks[part] = Image.fromarray(dilated)
        vis_image = inpainter.visualize_segmentation(image, masks, selected_parts=None)
        detected_parts = list(masks.keys())
        return vis_image, gr.update(choices=detected_parts, value=[])
    except gr.Error as e:
        raise e
    except Exception as e:
        logger.error(f"Error processing segmentation: {str(e)}")
        raise gr.Error("Error processing the image. Please try a different image.")
def update_dilation(image, selected_parts, dilation_iterations):
    """Gradio handler: re-render the segmentation overlay when the dilation
    slider moves. Returns the input image unchanged when there is nothing
    to re-render or on error."""
    try:
        if image is None or inpainter.last_mask is None:
            return image
        # Re-dilate the stored raw masks for the main clothing parts.
        main_parts = ['upper-clothes', 'dress', 'coat', 'pants', 'skirt']
        kernel = np.ones((5, 5), np.uint8)
        masks = {
            part: Image.fromarray(
                cv2.dilate(np.array(inpainter.last_mask[part]), kernel,
                           iterations=dilation_iterations))
            for part in main_parts
            if part in inpainter.last_mask
        }
        # Visualize on the stored original image.
        return inpainter.visualize_segmentation(
            inpainter.original_image, masks, selected_parts=selected_parts)
    except Exception as e:
        logger.error(f"Error updating dilation: {str(e)}")
        return image
def process_image(prompt, image, selected_parts, dilation_iterations):
    """Gradio handler for the Generate button.

    Validates the inputs, runs inpainting on the selected parts, and returns
    the list of generated PIL images for the gallery.

    Raises gr.Error with a user-facing message on validation failure or any
    processing error.
    """
    start_time = time.time()
    logger.info(f"Processing new request - Prompt: {prompt}, Image size: {image.size if image else 'None'}")
    try:
        if image is None:
            logger.error("No image provided")
            raise gr.Error("Please upload an image")
        if not prompt:
            logger.error("No prompt provided")
            raise gr.Error("Please enter a prompt")
        if not selected_parts:
            logger.error("No parts selected")
            raise gr.Error("Please select at least one clothing part to modify")
        prompt_dict = {'pos': prompt}
        logger.info("Starting inpainting process")
        # Parser mask keys are lowercase; normalize the checkbox labels to match.
        selected_parts = [p.lower() for p in selected_parts]
        images = inpainter.inpaint(prompt_dict, image, selected_parts, dilation_iterations)
        if not images:
            logger.error("Inpainting failed to produce results")
            raise gr.Error("Failed to generate images. Please try again.")
        logger.info(f"Request processed in {time.time() - start_time:.2f} seconds")
        return images
    except gr.Error:
        # CONSISTENCY FIX: let validation errors propagate unchanged (matches
        # process_segmentation) instead of re-wrapping them as
        # "Error processing image: ...".
        raise
    except Exception as e:
        logger.error(f"Error processing image: {str(e)}")
        raise gr.Error(f"Error processing image: {str(e)}")
def update_selected_parts(image, selected_parts, dilation_iterations):
    """Gradio handler: re-render the overlay when the part selection changes,
    highlighting selected parts in color and the rest in gray. Returns the
    input image unchanged when there is nothing to render or on error."""
    try:
        if image is None or inpainter.last_mask is None:
            return image
        # Re-dilate the stored raw masks for the main clothing parts.
        main_parts = ['upper-clothes', 'dress', 'coat', 'pants', 'skirt']
        kernel = np.ones((5, 5), np.uint8)
        masks = {
            part: Image.fromarray(
                cv2.dilate(np.array(inpainter.last_mask[part]), kernel,
                           iterations=dilation_iterations))
            for part in main_parts
            if part in inpainter.last_mask
        }
        # Checkbox labels may be capitalized; mask keys are lowercase.
        selected = [p.lower() for p in selected_parts] if selected_parts else []
        # Visualize on the stored original image.
        return inpainter.visualize_segmentation(
            inpainter.original_image, masks, selected_parts=selected)
    except Exception as e:
        logger.error(f"Error updating selected parts: {str(e)}")
        return image
# Initialize the models at import time so the UI is ready when it launches
init()

# Create Gradio interface
with gr.Blocks(title="ClothQuill - AI Clothing Inpainting") as demo:
    gr.Markdown("# ClothQuill - AI Clothing Inpainting")
    gr.Markdown("Upload an image to see segmented clothing parts, then select parts to modify and describe your changes")
    with gr.Row():
        # Left column: inputs and controls
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="Upload Image",
                scale=1,  # This ensures the image maintains its aspect ratio
                height=None  # Allow dynamic height based on content
            )
            dilation_slider = gr.Slider(
                minimum=0,
                maximum=5,
                value=2,
                step=1,
                label="Mask Dilation",
                info="Adjust the mask dilation to control the area of modification"
            )
            selected_parts = gr.CheckboxGroup(
                choices=[],  # populated by process_segmentation after upload
                label="Select parts to modify",
                value=[]
            )
            prompt = gr.Textbox(
                label="Describe the clothing you want to generate",
                placeholder="e.g., A stylish black leather jacket"
            )
            generate_btn = gr.Button("Generate")
        # Right column: generated results
        with gr.Column():
            gallery = gr.Gallery(
                label="Generated Results",
                show_label=False,
                columns=2,
                height=None,  # Allow dynamic height
                object_fit="contain"  # Maintain aspect ratio
            )
    # On upload: segment the image and populate the part checkboxes
    input_image.upload(
        fn=process_segmentation,
        inputs=[input_image, dilation_slider],
        outputs=[input_image, selected_parts]
    )
    # On slider move: re-render the overlay with the new dilation
    dilation_slider.change(
        fn=update_dilation,
        inputs=[input_image, selected_parts, dilation_slider],
        outputs=input_image
    )
    # On Generate: run inpainting and show the results in the gallery
    generate_btn.click(
        fn=process_image,
        inputs=[prompt, input_image, selected_parts, dilation_slider],
        outputs=gallery
    )
    # On selection change: re-color the overlay (selected parts highlighted)
    selected_parts.change(
        fn=update_selected_parts,
        inputs=[input_image, selected_parts, dilation_slider],
        outputs=input_image
    )

if __name__ == "__main__":
    demo.launch(share=True)  # share=True exposes a public Gradio link