import torch
from torch import autocast
from diffusers import StableDiffusionInpaintPipeline
import gradio as gr
import os
import logging
import time
import warnings
import numpy as np
from PIL import Image
import cv2

from parser.segformer_parser import SegformerParser

warnings.filterwarnings("ignore")

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('clothquill')

# Model paths
SEGFORMER_MODEL = "mattmdjaga/segformer_b2_clothes"
STABLE_DIFFUSION_MODEL = "stabilityai/stable-diffusion-2-inpainting"

# Global variables for models
parser = None
model = None
inpainter = None

# Color mapping for different clothing parts (RGBA; keys are the parser's labels, capitalized)
CLOTHING_COLORS = {
    'Background': (0, 0, 0, 0),           # Transparent
    'Hat': (255, 0, 0, 128),              # Red
    'Hair': (0, 255, 0, 128),             # Green
    'Glove': (0, 0, 255, 128),            # Blue
    'Sunglasses': (255, 255, 0, 128),     # Yellow
    'Upper-clothes': (255, 0, 255, 128),  # Magenta
    'Dress': (0, 255, 255, 128),          # Cyan
    'Coat': (128, 0, 0, 128),             # Dark Red
    'Socks': (0, 128, 0, 128),            # Dark Green
    'Pants': (0, 0, 128, 128),            # Dark Blue
    'Jumpsuits': (128, 128, 0, 128),      # Dark Yellow
    'Scarf': (128, 0, 128, 128),          # Dark Magenta
    'Skirt': (0, 128, 128, 128),          # Dark Cyan
    'Face': (192, 192, 192, 128),         # Light Gray
    'Left-arm': (64, 64, 64, 128),        # Dark Gray
    'Right-arm': (64, 64, 64, 128),       # Dark Gray
    'Left-leg': (32, 32, 32, 128),        # Very Dark Gray
    'Right-leg': (32, 32, 32, 128),       # Very Dark Gray
    'Left-shoe': (16, 16, 16, 128),       # Almost Black
    'Right-shoe': (16, 16, 16, 128),      # Almost Black
}


def get_device():
    if torch.cuda.is_available():
        device = "cuda"
        logger.info("Using GPU")
    else:
        device = "cpu"
        logger.info("Using CPU")
    return device


def init():
    global parser, model, inpainter
    start_time = time.time()
    logger.info("Starting application initialization")
    try:
        device = get_device()

        # Download model weights on first run
        if not os.path.exists("models"):
            logger.info("Creating models directory...")
            from download_models import download_models
            download_models()

        # Initialize Segformer parser
        logger.info("Initializing Segformer parser...")
        parser = SegformerParser(SEGFORMER_MODEL)

        # Initialize Stable Diffusion model
        logger.info("Initializing Stable Diffusion model...")
        model = StableDiffusionInpaintPipeline.from_pretrained(
            STABLE_DIFFUSION_MODEL,
            safety_checker=None,
            revision="fp16" if device == "cuda" else None,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        ).to(device)

        # Initialize inpainter
        logger.info("Initializing inpainter...")
        inpainter = ClothingInpainter(model=model, parser=parser)

        logger.info(f"Application initialized in {time.time() - start_time:.2f} seconds")
    except Exception as e:
        logger.error(f"Error initializing application: {e}")
        raise


class ClothingInpainter:
    def __init__(self, model_path=None, model=None, parser=None):
        self.device = get_device()
        self.last_mask = None        # Store the last generated masks
        self.original_image = None   # Store the original uploaded image
        if model_path is None and model is None:
            raise ValueError('No model provided!')
        if model_path is not None:
            self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                model_path,
                safety_checker=None,
                revision="fp16" if self.device == "cuda" else None,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            ).to(self.device)
        else:
            self.pipe = model
        self.parser = parser

    def make_square(self, im, min_size=256, fill_color=(0, 0, 0, 0)):
        """Pad an image to a centered square so it can be resized to 512x512 without distortion."""
        x, y = im.size
        size = max(min_size, x, y)
        new_im = Image.new('RGBA', (size, size), fill_color)
        new_im.paste(im, ((size - x) // 2, (size - y) // 2))
        return new_im.convert('RGB')

    def unmake_square(self, init_im, op_im, min_size=256, rs_size=512):
        """Crop the model output back to the original image's aspect ratio."""
        x, y = init_im.size
        size = max(min_size, x, y)
        factor = rs_size / size
        return op_im.crop((
            int((size - x) * factor / 2),
            int((size - y) * factor / 2),
            int((size + x) * factor / 2),
            int((size + y) * factor / 2),
        ))

    def visualize_segmentation(self, image, masks, selected_parts=None):
        """Visualize segmentation: colored overlays for selected parts, faint gray for the rest."""
        # Always use the original image if available
        image_to_use = self.original_image if self.original_image is not None else image
        original_size = image_to_use.size
        vis_image = image_to_use.copy().convert('RGBA')

        # Build the overlay at 512x512 in one numpy pass (much faster than per-pixel drawing)
        overlay_np = np.zeros((512, 512, 4), dtype=np.uint8)
        for part_name, mask in masks.items():
            # Parser labels are lowercase ('upper-clothes'); color keys are capitalized ('Upper-clothes')
            color_key = part_name.capitalize()
            is_selected = selected_parts and part_name in selected_parts
            if is_selected:
                color = CLOTHING_COLORS.get(color_key, (255, 0, 255, 128))  # Magenta fallback
            else:
                color = (180, 180, 180, 80)  # Faint gray for unselected
            overlay_np[np.array(mask) > 0] = color
        overlay = Image.fromarray(overlay_np, 'RGBA')

        # Resize the overlay to the original image size and composite it on top
        overlay = overlay.resize(original_size, Image.Resampling.LANCZOS)
        return Image.alpha_composite(vis_image, overlay)

    def inpaint(self, prompt, init_image, selected_parts=None, dilation_iterations=2) -> list:
        if self.parser is None:
            raise ValueError('Image Parser is Missing')

        image = self.make_square(init_image).resize((512, 512))
        masks = self.parser.get_all_masks(image)
        masks = {k: v.resize((512, 512)) for k, v in masks.items()}
        logger.info('[generated required mask(s)]')

        # Combine the selected parts (or, if none selected, all main clothing parts)
        # into a single inpainting mask
        parts_to_use = selected_parts or [
            p for p in masks if p in ['upper-clothes', 'dress', 'coat', 'pants', 'skirt']
        ]
        combined_mask = Image.new('L', (512, 512), 0)
        kernel = np.ones((5, 5), np.uint8)
        for part in parts_to_use:
            if part not in masks:
                continue
            # Dilate each mask so the inpainted region extends slightly past the garment edge
            dilated = cv2.dilate(np.array(masks[part]), kernel, iterations=dilation_iterations)
            combined_mask = Image.composite(
                Image.new('L', (512, 512), 255),
                combined_mask,
                Image.fromarray(dilated),
            )

        # Run the model (autocast on the active device, so CPU runs don't request CUDA autocast)
        guidance_scale = 7.5
        num_samples = 3
        with autocast(self.device), torch.inference_mode():
            images = self.pipe(
                prompt=prompt['pos'],
                image=image,
                mask_image=combined_mask,
                num_inference_steps=50,
                guidance_scale=guidance_scale,
                num_images_per_prompt=num_samples,
            ).images

        images_output = []
        for img in images:
            # Paste the generated region back onto the input and undo the square padding
            composited = Image.composite(img, image, combined_mask)
            images_output.append(self.unmake_square(init_image, composited))
        return images_output


def process_segmentation(image, dilation_iterations=2):
    try:
        if image is None:
            raise gr.Error("Please upload an image")

        # Store the original image
        inpainter.original_image = image.copy()

        # Create a processing copy at 512x512
        proc_image = image.resize((512, 512), Image.Resampling.LANCZOS)

        # Get all part masks
        all_masks = inpainter.parser.get_all_masks(proc_image)
        if not all_masks:
            logger.error("No clothing detected in the image")
            raise gr.Error("No clothing detected in the image. Please try a different image.")
        inpainter.last_mask = all_masks

        # Only show main clothing parts for selection
        main_parts = ['upper-clothes', 'dress', 'coat', 'pants', 'skirt']
        masks = {k: v for k, v in all_masks.items() if k in main_parts}
        vis_image = inpainter.visualize_segmentation(image, masks, selected_parts=None)
        detected_parts = list(masks.keys())
        return vis_image, gr.update(choices=detected_parts, value=[])
    except gr.Error:
        raise
    except Exception as e:
        logger.error(f"Error processing segmentation: {e}")
        raise gr.Error("Error processing the image. Please try a different image.")


def _dilated_main_masks(dilation_iterations):
    """Re-dilate the stored masks for the main clothing parts."""
    main_parts = ['upper-clothes', 'dress', 'coat', 'pants', 'skirt']
    kernel = np.ones((5, 5), np.uint8)
    masks = {}
    for part in main_parts:
        if part in inpainter.last_mask:
            dilated = cv2.dilate(np.array(inpainter.last_mask[part]), kernel,
                                 iterations=dilation_iterations)
            masks[part] = Image.fromarray(dilated)
    return masks


def update_dilation(image, selected_parts, dilation_iterations):
    try:
        if image is None or inpainter.last_mask is None:
            return image
        masks = _dilated_main_masks(dilation_iterations)
        selected_parts = [p.lower() for p in selected_parts] if selected_parts else []
        # Use the original image for visualization
        return inpainter.visualize_segmentation(
            inpainter.original_image, masks, selected_parts=selected_parts
        )
    except Exception as e:
        logger.error(f"Error updating dilation: {e}")
        return image


def process_image(prompt, image, selected_parts, dilation_iterations):
    start_time = time.time()
    logger.info(f"Processing new request - Prompt: {prompt}, "
                f"Image size: {image.size if image else 'None'}")
    try:
        if image is None:
            logger.error("No image provided")
            raise gr.Error("Please upload an image")
        if not prompt:
            logger.error("No prompt provided")
            raise gr.Error("Please enter a prompt")
        if not selected_parts:
            logger.error("No parts selected")
            raise gr.Error("Please select at least one clothing part to modify")

        prompt_dict = {'pos': prompt}
        logger.info("Starting inpainting process")

        # Normalize part names to the parser's lowercase labels
        selected_parts = [p.lower() for p in selected_parts]
        images = inpainter.inpaint(prompt_dict, image, selected_parts, dilation_iterations)
        if not images:
            logger.error("Inpainting failed to produce results")
            raise gr.Error("Failed to generate images. Please try again.")

        logger.info(f"Request processed in {time.time() - start_time:.2f} seconds")
        return images
    except Exception as e:
        logger.error(f"Error processing image: {e}")
        raise gr.Error(f"Error processing image: {e}")


def update_selected_parts(image, selected_parts, dilation_iterations):
    try:
        if image is None or inpainter.last_mask is None:
            return image
        masks = _dilated_main_masks(dilation_iterations)
        # Normalize part names for comparison
        selected_parts = [p.lower() for p in selected_parts] if selected_parts else []
        # Use the original image for visualization
        return inpainter.visualize_segmentation(
            inpainter.original_image, masks, selected_parts=selected_parts
        )
    except Exception as e:
        logger.error(f"Error updating selected parts: {e}")
        return image


# Initialize the models
init()

# Create Gradio interface
with gr.Blocks(title="ClothQuill - AI Clothing Inpainting") as demo:
    gr.Markdown("# ClothQuill - AI Clothing Inpainting")
    gr.Markdown(
        "Upload an image to see segmented clothing parts, then select parts to modify "
        "and describe your changes"
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="Upload Image",
                scale=1,      # Maintain aspect ratio
                height=None,  # Allow dynamic height based on content
            )
            dilation_slider = gr.Slider(
                minimum=0,
                maximum=5,
                value=2,
                step=1,
                label="Mask Dilation",
                info="Adjust the mask dilation to control the area of modification",
            )
            selected_parts = gr.CheckboxGroup(
                choices=[],
                label="Select parts to modify",
                value=[],
            )
            prompt = gr.Textbox(
                label="Describe the clothing you want to generate",
                placeholder="e.g., A stylish black leather jacket",
            )
            generate_btn = gr.Button("Generate")
        with gr.Column():
            gallery = gr.Gallery(
                label="Generated Results",
                show_label=False,
                columns=2,
                height=None,           # Allow dynamic height
                object_fit="contain",  # Maintain aspect ratio
            )

    # Image upload: run segmentation and populate the part checkboxes
    input_image.upload(
        fn=process_segmentation,
        inputs=[input_image, dilation_slider],
        outputs=[input_image, selected_parts],
    )

    # Dilation changes: redraw the overlay
    dilation_slider.change(
        fn=update_dilation,
        inputs=[input_image, selected_parts, dilation_slider],
        outputs=input_image,
    )

    # Generate button: run inpainting
    generate_btn.click(
        fn=process_image,
        inputs=[prompt, input_image, selected_parts, dilation_slider],
        outputs=gallery,
    )

    # Part selection changes: redraw the overlay
    selected_parts.change(
        fn=update_selected_parts,
        inputs=[input_image, selected_parts, dilation_slider],
        outputs=input_image,
    )

if __name__ == "__main__":
    demo.launch(share=True)
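# The square-padding round trip (make_square -> resize to 512 -> unmake_square) can be
# exercised in isolation. The sketch below copies that logic into standalone helpers
# (the standalone names are mine, not part of the app) and checks that a 300x200 input
# comes back from the pipeline size with its aspect ratio intact.

"""
from PIL import Image

def make_square(im, min_size=256, fill_color=(0, 0, 0, 0)):
    # Pad the image to a centered square before resizing for the 512x512 pipeline
    x, y = im.size
    size = max(min_size, x, y)
    new_im = Image.new('RGBA', (size, size), fill_color)
    new_im.paste(im, ((size - x) // 2, (size - y) // 2))
    return new_im.convert('RGB')

def unmake_square(init_im, op_im, min_size=256, rs_size=512):
    # Crop the model output back to the original image's aspect ratio
    x, y = init_im.size
    size = max(min_size, x, y)
    factor = rs_size / size
    return op_im.crop((int((size - x) * factor / 2), int((size - y) * factor / 2),
                       int((size + x) * factor / 2), int((size + y) * factor / 2)))

src = Image.new('RGB', (300, 200), (120, 40, 40))
square = make_square(src)               # 300x300, source centered
model_sized = square.resize((512, 512)) # what the pipeline sees
restored = unmake_square(src, model_sized)
print(square.size, restored.size)
"""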
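# The mask-accumulation step in inpaint() relies on Image.composite(white, combined, mask)
# acting as a per-pixel OR over the selected part masks. A minimal check of that behavior:
# combine_masks is a hypothetical standalone helper, the two synthetic squares stand in for
# parser output, and the cv2.dilate step is omitted to keep the sketch dependency-light.

"""
import numpy as np
from PIL import Image

def combine_masks(masks, size=(64, 64)):
    # Accumulate binary part masks into one L-mode inpainting mask:
    # wherever a part mask is white, take white; elsewhere keep the accumulator.
    combined = Image.new('L', size, 0)
    white = Image.new('L', size, 255)
    for mask in masks:
        combined = Image.composite(white, combined, mask)
    return combined

# Two disjoint 10x10 synthetic "part" masks
a = np.zeros((64, 64), dtype=np.uint8); a[10:20, 10:20] = 255
b = np.zeros((64, 64), dtype=np.uint8); b[30:40, 30:40] = 255
combined = np.array(combine_masks([Image.fromarray(a), Image.fromarray(b)]))
print(int((combined == 255).sum()))  # both squares survive the OR
"""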