# ClothQuill / app.py
# Bismay
# Rollback changes
# adc77ba
import torch
from torch import autocast
from diffusers import StableDiffusionInpaintPipeline
import gradio as gr
import traceback
import base64
from io import BytesIO
import os
# import sys
import PIL
import json
import requests
import logging
import time
import warnings
import numpy as np
from PIL import Image, ImageDraw
import cv2
warnings.filterwarnings("ignore")
# sys.path.insert(1, './parser')
# from parser.schp_masker import *
from parser.segformer_parser import SegformerParser
# Configure logging: INFO level, each record formatted as
# "<timestamp> - <logger name> - <level> - <message>".
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Named logger used by the functions in this module.
logger = logging.getLogger('clothquill')
# Model identifiers handed to the loaders in init():
# SEGFORMER_MODEL goes to SegformerParser, STABLE_DIFFUSION_MODEL to
# StableDiffusionInpaintPipeline.from_pretrained.
SEGFORMER_MODEL = "mattmdjaga/segformer_b2_clothes"
STABLE_DIFFUSION_MODEL = "stabilityai/stable-diffusion-2-inpainting"
# Global model handles. All start as None; init() assigns parser, model and
# inpainter.
parser = None
model = None
inpainter = None
# NOTE(review): this module-level name is not assigned anywhere else in the
# visible code; the handlers keep the upload on inpainter.original_image.
original_image = None # Store the original uploaded image
# Color mapping for different clothing parts.
# Each value is an (R, G, B, A) tuple. Every part except 'Background' uses
# alpha 128 (roughly half-transparent on a 0-255 scale). Left/right pairs of
# arms, legs and shoes share one color. ClothingInpainter.visualize_segmentation
# reads this table for selected parts and falls back to magenta
# (255, 0, 255, 128) when a key is missing.
CLOTHING_COLORS = {
'Background': (0, 0, 0, 0), # Transparent
'Hat': (255, 0, 0, 128), # Red
'Hair': (0, 255, 0, 128), # Green
'Glove': (0, 0, 255, 128), # Blue
'Sunglasses': (255, 255, 0, 128), # Yellow
'Upper-clothes': (255, 0, 255, 128), # Magenta
'Dress': (0, 255, 255, 128), # Cyan
'Coat': (128, 0, 0, 128), # Dark Red
'Socks': (0, 128, 0, 128), # Dark Green
'Pants': (0, 0, 128, 128), # Dark Blue
'Jumpsuits': (128, 128, 0, 128), # Dark Yellow
'Scarf': (128, 0, 128, 128), # Dark Magenta
'Skirt': (0, 128, 128, 128), # Dark Cyan
'Face': (192, 192, 192, 128), # Light Gray
'Left-arm': (64, 64, 64, 128), # Dark Gray
'Right-arm': (64, 64, 64, 128), # Dark Gray
'Left-leg': (32, 32, 32, 128), # Very Dark Gray
'Right-leg': (32, 32, 32, 128), # Very Dark Gray
'Left-shoe': (16, 16, 16, 128), # Almost Black
'Right-shoe': (16, 16, 16, 128), # Almost Black
}
def get_device():
    """Return the torch device string to run on and log the choice.

    Returns:
        "cuda" when ``torch.cuda.is_available()`` is true, otherwise "cpu".
    """
    if not torch.cuda.is_available():
        logger.info("Using CPU")
        return "cpu"
    logger.info("Using GPU")
    return "cuda"
def init():
    """Load the segmentation parser and inpainting pipeline into module globals.

    Assigns the module-level ``parser``, ``model`` and ``inpainter``. When no
    ``models`` path exists in the working directory, ``download_models()``
    from the project's ``download_models`` module is run first. The Stable
    Diffusion pipeline is loaded as float16 from the "fp16" revision on CUDA
    and as float32 on CPU, with the safety checker disabled.

    Raises:
        Exception: any failure during loading is logged together with its
            traceback and then re-raised unchanged.
    """
    global parser
    global model
    global inpainter
    start_time = time.time()
    logger.info("Starting application initialization")
    try:
        device = get_device()
        on_gpu = device == "cuda"

        # First run: fetch the model files if the directory is not there yet.
        if not os.path.exists("models"):
            logger.info("Creating models directory...")
            # Imported lazily because it is only needed on this path.
            from download_models import download_models
            download_models()

        logger.info("Initializing Segformer parser...")
        parser = SegformerParser(SEGFORMER_MODEL)

        logger.info("Initializing Stable Diffusion model...")
        model = StableDiffusionInpaintPipeline.from_pretrained(
            STABLE_DIFFUSION_MODEL,
            safety_checker=None,
            revision="fp16" if on_gpu else None,
            torch_dtype=torch.float16 if on_gpu else torch.float32,
        ).to(device)

        logger.info("Initializing inpainter...")
        inpainter = ClothingInpainter(model=model, parser=parser)
        logger.info(f"Application initialized in {time.time() - start_time:.2f} seconds")
    except Exception as e:
        # logger.exception keeps the original message and appends the stack
        # trace; the bare raise preserves the original traceback for callers.
        logger.exception(f"Error initializing application: {str(e)}")
        raise
class ClothingInpainter:
    """Repaint selected clothing regions of a photo with an inpainting pipeline.

    The instance wraps a ``StableDiffusionInpaintPipeline`` (``self.pipe``)
    and a segmentation parser (``self.parser``) whose ``get_all_masks(image)``
    returns a ``{part_name: mask_image}`` mapping. It also carries two pieces
    of state that the module-level Gradio handlers write: ``original_image``
    and ``last_mask``.
    """

    # Parts repainted by inpaint() when the caller selects nothing.
    DEFAULT_PARTS = ('upper-clothes', 'dress', 'coat', 'pants', 'skirt')

    def __init__(self, model_path=None, model=None, parser=None):
        """Wrap an existing pipeline or load one from ``model_path``.

        Args:
            model_path: identifier/path given to
                ``StableDiffusionInpaintPipeline.from_pretrained``. Takes
                precedence over ``model`` when both are supplied.
            model: an already constructed pipeline, used as-is.
            parser: segmentation parser; may be None, in which case
                ``inpaint`` raises.

        Raises:
            ValueError: neither ``model_path`` nor ``model`` was given.
        """
        # Validate before doing any other work so bad arguments fail fast.
        if model_path is None and model is None:
            raise ValueError('No model provided!')
        self.device = get_device()
        self.last_mask = None  # Store the last generated mask
        self.original_image = None  # Store the original image
        if model_path is not None:
            on_gpu = self.device == "cuda"
            self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                model_path,
                safety_checker=None,
                revision="fp16" if on_gpu else None,
                torch_dtype=torch.float16 if on_gpu else torch.float32,
            ).to(self.device)
        else:
            self.pipe = model
        self.parser = parser

    def make_square(self, im, min_size=256, fill_color=(0, 0, 0, 0)):
        """Center ``im`` on a square canvas and return it as an RGB image.

        The canvas side is ``max(min_size, width, height)``; the padding is
        filled with ``fill_color`` before the final conversion to RGB.
        """
        x, y = im.size
        size = max(min_size, x, y)
        new_im = PIL.Image.new('RGBA', (size, size), fill_color)
        new_im.paste(im, ((size - x) // 2, (size - y) // 2))
        return new_im.convert('RGB')

    def unmake_square(self, init_im, op_im, min_size=256, rs_size=512):
        """Crop the padding added by ``make_square`` back out of ``op_im``.

        ``op_im`` is expected to be the squared image after it was resized to
        ``rs_size`` x ``rs_size``; the crop box is the footprint of
        ``init_im`` scaled by the same factor.
        """
        x, y = init_im.size
        size = max(min_size, x, y)
        factor = rs_size / size
        return op_im.crop((
            int((size - x) * factor / 2),
            int((size - y) * factor / 2),
            int((size + x) * factor / 2),
            int((size + y) * factor / 2),
        ))

    def _overlay_array(self, masks, selected_parts=None, size=512):
        """Paint every mask onto a transparent ``size`` x ``size`` RGBA array.

        Selected parts use their ``CLOTHING_COLORS`` entry (magenta when the
        part has none); all other parts use a faint gray. Masks are painted in
        mapping order, so later masks overwrite earlier ones where they
        overlap. Mask pixels outside the canvas are ignored.
        """
        overlay = np.zeros((size, size, 4), dtype=np.uint8)
        selected = set(selected_parts) if selected_parts else set()
        for part_name, mask in masks.items():
            if part_name in selected:
                # Table keys capitalize only the first letter
                # ('Upper-clothes', 'Left-arm'), which is what capitalize()
                # produces; title() would give 'Upper-Clothes' and miss.
                color = CLOTHING_COLORS.get(part_name.capitalize(), (255, 0, 255, 128))
            else:
                color = (180, 180, 180, 80)  # Faint gray for unselected
            hit = np.asarray(mask) > 0
            if hit.ndim == 3:
                # Multi-channel mask: a pixel counts when any channel is set.
                hit = hit.any(axis=2)
            rows = min(hit.shape[0], size)
            cols = min(hit.shape[1], size)
            overlay[:rows, :cols][hit[:rows, :cols]] = color
        return overlay

    def visualize_segmentation(self, image, masks, selected_parts=None):
        """Visualize segmentation with colored overlays for selected parts and gray for unselected.

        Args:
            image: fallback image, used only when ``self.original_image`` is
                not set.
            masks: ``{part_name: mask}`` mapping; masks are interpreted on a
                512x512 grid, any value > 0 counts as "inside".
            selected_parts: part names to draw in color; may be None/empty.

        Returns:
            An RGBA image at the size of the underlying picture with the
            overlay alpha-composited on top.
        """
        # Always use original image if available
        image_to_use = self.original_image if self.original_image is not None else image
        vis_image = image_to_use.convert('RGBA')
        # Build the overlay in one vectorized pass instead of drawing it
        # point by point, then stretch it to the picture's size.
        overlay = Image.fromarray(self._overlay_array(masks, selected_parts))
        overlay = overlay.resize(image_to_use.size, Image.Resampling.LANCZOS)
        return Image.alpha_composite(vis_image, overlay)

    def _combine_masks(self, masks, parts, dilation_iterations):
        """Return the union of the dilated masks of ``parts`` as a 512x512 'L' image.

        Parts missing from ``masks`` are skipped. Each mask is dilated with a
        5x5 kernel ``dilation_iterations`` times before being merged.
        NOTE(review): masks are assumed to be single-channel 8-bit images,
        since they are fed to cv2.dilate and used as composite masks directly
        - confirm against the parser.
        """
        kernel = np.ones((5, 5), np.uint8)
        # UI sliders may deliver the count as a float; cv2 needs an int.
        iterations = int(dilation_iterations)
        white = Image.new('L', (512, 512), 255)
        combined = Image.new('L', (512, 512), 0)
        for part in parts:
            if part not in masks:
                continue
            dilated = cv2.dilate(np.array(masks[part]), kernel, iterations=iterations)
            combined = Image.composite(white, combined, Image.fromarray(dilated))
        return combined

    def inpaint(self, prompt, init_image, selected_parts=None, dilation_iterations=2) -> list:
        """Generate repainted variants of ``init_image``.

        Args:
            prompt: mapping whose ``'pos'`` entry is the text prompt.
            init_image: source picture (PIL image).
            selected_parts: part names to repaint; when empty/None the parts
                in ``DEFAULT_PARTS`` are used.
            dilation_iterations: how many times each mask is dilated.

        Returns:
            A list of three images: the generated content inside the mask,
            the squared input elsewhere, cropped back to the input's aspect
            ratio on the 512-pixel working scale.

        Raises:
            ValueError: the instance has no parser.
        """
        if self.parser is None:
            raise ValueError('Image Parser is Missing')

        image = self.make_square(init_image).resize((512, 512))
        masks = self.parser.get_all_masks(image)
        masks = {k: v.resize((512, 512)) for k, v in masks.items()}
        logger.info(f'[generated required mask(s) at {time.time()}]')

        # Create combined mask for selected parts (or the default clothing
        # parts when nothing is selected).
        parts = selected_parts if selected_parts else self.DEFAULT_PARTS
        combined_mask = self._combine_masks(masks, parts, dilation_iterations)

        # Run the model. Autocast is only enabled when the pipeline actually
        # lives on CUDA, so CPU runs do not request a CUDA autocast context.
        guidance_scale = 7.5
        num_samples = 3
        with autocast("cuda", enabled=self.device == "cuda"), torch.inference_mode():
            images = self.pipe(
                num_inference_steps=50,
                prompt=prompt['pos'],
                image=image,
                mask_image=combined_mask,
                guidance_scale=guidance_scale,
                num_images_per_prompt=num_samples,
            ).images

        # Keep generated pixels only inside the mask, then undo the padding.
        return [
            self.unmake_square(init_image, PIL.Image.composite(img, image, combined_mask))
            for img in images
        ]
def process_segmentation(image, dilation_iterations=2):
    """Segment a freshly uploaded image and build the part-selection preview.

    Stores a copy of the upload on ``inpainter.original_image`` and all
    returned masks on ``inpainter.last_mask``. Only the main clothing parts
    are offered for selection. ``dilation_iterations`` is accepted because
    the UI passes the slider value, but it is not used here.

    Returns:
        ``(preview_image, checkbox_update)``: the preview with every main
        part drawn as unselected, and a ``gr.update`` listing the detected
        main parts with nothing ticked.

    Raises:
        gr.Error: no image, no masks, or any failure while processing.
    """
    try:
        if image is None:
            raise gr.Error("Please upload an image")

        # Keep the untouched upload; the preview replaces it in the UI.
        inpainter.original_image = image.copy()

        # Segmentation runs on a 512x512 copy.
        resized = image.resize((512, 512), Image.Resampling.LANCZOS)
        all_masks = inpainter.parser.get_all_masks(resized)
        if not all_masks:
            logger.error("No clothing detected in the image")
            raise gr.Error("No clothing detected in the image. Please try a different image.")
        inpainter.last_mask = all_masks

        # Only show main clothing parts for selection
        selectable = ('upper-clothes', 'dress', 'coat', 'pants', 'skirt')
        masks = {}
        for name, mask in all_masks.items():
            if name in selectable:
                masks[name] = mask

        preview = inpainter.visualize_segmentation(image, masks, selected_parts=None)
        return preview, gr.update(choices=list(masks), value=[])
    except gr.Error:
        # User-facing errors raised above pass through untouched.
        raise
    except Exception as e:
        logger.error(f"Error processing segmentation: {str(e)}")
        raise gr.Error("Error processing the image. Please try a different image.")
def update_dilation(image, selected_parts, dilation_iterations):
    """Redraw the segmentation preview after the dilation slider moved.

    Re-dilates the stored masks of the main clothing parts and renders them
    over ``inpainter.original_image``, coloring the currently selected parts.

    Args:
        image: current content of the image component; returned unchanged
            when there is nothing to redraw or when redrawing fails.
        selected_parts: part names ticked in the UI; may be None or empty.
        dilation_iterations: slider value; coerced to ``int`` for OpenCV.

    Returns:
        The new preview image, or ``image`` as a fallback.
    """
    try:
        if image is None or inpainter.last_mask is None:
            return image

        main_parts = ('upper-clothes', 'dress', 'coat', 'pants', 'skirt')
        kernel = np.ones((5, 5), np.uint8)
        # cv2.dilate needs an integer count; guard against a float slider value.
        iterations = int(dilation_iterations)

        # Redilate all stored masks
        masks = {
            part: Image.fromarray(
                cv2.dilate(np.array(inpainter.last_mask[part]), kernel, iterations=iterations)
            )
            for part in main_parts
            if part in inpainter.last_mask
        }

        # Normalize names the same way update_selected_parts does so both
        # handlers color exactly the same parts.
        selected = [p.lower() for p in selected_parts] if selected_parts else []

        # Use original image for visualization
        return inpainter.visualize_segmentation(
            inpainter.original_image, masks, selected_parts=selected
        )
    except Exception as e:
        # Best effort: keep the current picture, but record the traceback.
        logger.exception(f"Error updating dilation: {str(e)}")
        return image
def process_image(prompt, image, selected_parts, dilation_iterations):
    """Validate a generation request and run the inpainter.

    Args:
        prompt: text describing the clothing to generate.
        image: current content of the image component. After an upload this
            is the segmentation preview, not the raw photo.
        selected_parts: part names ticked in the UI.
        dilation_iterations: mask dilation count from the slider.

    Returns:
        The list of images produced by ``inpainter.inpaint``.

    Raises:
        gr.Error: a required input is missing (raised with its own message,
            not re-wrapped), inpainting returned nothing, or any other
            failure (wrapped as "Error processing image: ...").
    """
    start_time = time.time()
    logger.info(f"Processing new request - Prompt: {prompt}, Image size: {image.size if image else 'None'}")
    try:
        if image is None:
            logger.error("No image provided")
            raise gr.Error("Please upload an image")
        if not prompt:
            logger.error("No prompt provided")
            raise gr.Error("Please enter a prompt")
        if not selected_parts:
            logger.error("No parts selected")
            raise gr.Error("Please select at least one clothing part to modify")

        # The upload handler writes the colored preview back into the image
        # component, so ``image`` carries overlay tint. Inpaint the stored
        # original instead; the size check guards against a stale original
        # left over from a different picture.
        source_image = image
        stored = inpainter.original_image
        if stored is not None and stored.size == image.size:
            source_image = stored

        prompt_dict = {'pos': prompt}
        logger.info("Starting inpainting process")
        # Convert selected_parts to lowercase/dash format
        parts = [p.lower() for p in selected_parts]
        images = inpainter.inpaint(prompt_dict, source_image, parts, dilation_iterations)
        if not images:
            logger.error("Inpainting failed to produce results")
            raise gr.Error("Failed to generate images. Please try again.")
        logger.info(f"Request processed in {time.time() - start_time:.2f} seconds")
        return images
    except gr.Error:
        # Already user-facing: pass through instead of wrapping it again.
        raise
    except Exception as e:
        logger.exception(f"Error processing image: {str(e)}")
        raise gr.Error(f"Error processing image: {str(e)}") from e
def update_selected_parts(image, selected_parts, dilation_iterations):
    """Redraw the segmentation preview after the part checkboxes changed.

    Dilates the stored masks of the main clothing parts with the current
    slider value and renders them over ``inpainter.original_image``, with the
    ticked parts in color. Returns ``image`` unchanged when there is no image
    or no stored masks, and also when anything goes wrong (the error is
    logged).
    """
    try:
        nothing_to_draw = image is None or inpainter.last_mask is None
        if nothing_to_draw:
            return image

        stored = inpainter.last_mask
        kernel = np.ones((5,5), np.uint8)
        masks = {}
        for part in ('upper-clothes', 'dress', 'coat', 'pants', 'skirt'):
            if part not in stored:
                continue
            grown = cv2.dilate(np.array(stored[part]), kernel, iterations=dilation_iterations)
            masks[part] = Image.fromarray(grown)

        # Lowercase the selected_parts for comparison
        if selected_parts:
            selected_parts = [name.lower() for name in selected_parts]
        else:
            selected_parts = []

        # Use original image for visualization
        return inpainter.visualize_segmentation(
            inpainter.original_image, masks, selected_parts=selected_parts
        )
    except Exception as e:
        logger.error(f"Error updating selected parts: {str(e)}")
        return image
# Load the models once at import time; the handlers below rely on the
# module-level ``inpainter`` that init() creates.
init()

# Gradio interface: controls in the left column, results in the right one.
with gr.Blocks(title="ClothQuill - AI Clothing Inpainting") as demo:
    gr.Markdown("# ClothQuill - AI Clothing Inpainting")
    gr.Markdown("Upload an image to see segmented clothing parts, then select parts to modify and describe your changes")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Upload Image",
                type="pil",
                scale=1,  # This ensures the image maintains its aspect ratio
                height=None,  # Allow dynamic height based on content
            )
            dilation_slider = gr.Slider(
                label="Mask Dilation",
                info="Adjust the mask dilation to control the area of modification",
                minimum=0,
                maximum=5,
                step=1,
                value=2,
            )
            selected_parts = gr.CheckboxGroup(
                label="Select parts to modify",
                choices=[],
                value=[],
            )
            prompt = gr.Textbox(
                label="Describe the clothing you want to generate",
                placeholder="e.g., A stylish black leather jacket",
            )
            generate_btn = gr.Button("Generate")

        with gr.Column():
            gallery = gr.Gallery(
                label="Generated Results",
                show_label=False,
                columns=2,
                height=None,  # Allow dynamic height
                object_fit="contain",  # Maintain aspect ratio
            )

    # Both preview-refresh handlers take the same three inputs.
    preview_inputs = [input_image, selected_parts, dilation_slider]

    # Upload: segment the picture, show the preview, fill the checkbox list.
    input_image.upload(
        fn=process_segmentation,
        inputs=[input_image, dilation_slider],
        outputs=[input_image, selected_parts],
    )

    # Slider moved: redraw the preview with re-dilated masks.
    dilation_slider.change(
        fn=update_dilation,
        inputs=preview_inputs,
        outputs=input_image,
    )

    # Generate button: run inpainting and show the results.
    generate_btn.click(
        fn=process_image,
        inputs=[prompt, input_image, selected_parts, dilation_slider],
        outputs=gallery,
    )

    # Checkbox selection changed: recolor the preview.
    selected_parts.change(
        fn=update_selected_parts,
        inputs=preview_inputs,
        outputs=input_image,
    )

if __name__ == "__main__":
    demo.launch(share=True)