"""Looks.Studio — Gradio app that inpaints clothing regions of an uploaded
photo with Stable Diffusion, using a Segformer-based parser to build the
clothing mask."""

import torch
from torch import autocast
from diffusers import StableDiffusionInpaintPipeline
import gradio as gr
import traceback
import base64
from io import BytesIO
import os
import PIL
import PIL.Image  # `import PIL` alone does not reliably expose the Image submodule
import json
import requests
import logging
import time
import warnings
from contextlib import nullcontext

warnings.filterwarnings("ignore")

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('looks.studio')

# Model paths (Hugging Face hub identifiers)
SEGFORMER_MODEL = "mattmdjaga/segformer_b2_clothes"
STABLE_DIFFUSION_MODEL = "stabilityai/stable-diffusion-2-inpainting"

# Global variables for models, populated by init()
parser = None
model = None
inpainter = None


def get_device():
    """Return ``"cuda"`` when a GPU is available, otherwise ``"cpu"``.

    Logs the chosen device on every call.
    """
    if torch.cuda.is_available():
        device = "cuda"
        logger.info("Using GPU")
    else:
        device = "cpu"
        logger.info("Using CPU")
    return device


def init():
    """Load all models and populate the module globals.

    Sets ``parser`` (Segformer clothing segmenter), ``model`` (Stable
    Diffusion inpainting pipeline moved to the chosen device) and
    ``inpainter`` (a :class:`ClothingInpainter` combining both).

    Raises:
        Exception: re-raised unchanged if any model fails to load.
    """
    global parser
    global model
    global inpainter

    start_time = time.time()
    logger.info("Starting application initialization")
    try:
        device = get_device()

        # Initialize Segformer parser (project-local module; imported lazily
        # so the import cost is paid only at startup, inside the try block).
        logger.info("Initializing Segformer parser...")
        from parser.segformer_parser import SegformerParser
        parser = SegformerParser(SEGFORMER_MODEL)

        # Initialize Stable Diffusion model. fp16 weights only make sense on
        # GPU; CPU falls back to full precision.
        # NOTE(review): recent diffusers deprecates `revision="fp16"` in
        # favor of `variant="fp16"` — confirm against the pinned version.
        logger.info("Initializing Stable Diffusion model...")
        model = StableDiffusionInpaintPipeline.from_pretrained(
            STABLE_DIFFUSION_MODEL,
            safety_checker=None,
            revision="fp16" if device == "cuda" else None,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        ).to(device)

        # Initialize inpainter
        logger.info("Initializing inpainter...")
        inpainter = ClothingInpainter(model=model, parser=parser)

        logger.info(f"Application initialized in {time.time() - start_time:.2f} seconds")
    except Exception as e:
        logger.error(f"Error initializing application: {str(e)}")
        # Bare raise preserves the original traceback (vs. `raise e`).
        raise


class ClothingInpainter:
    """Inpaints clothing regions of an image with a Stable Diffusion
    inpainting pipeline, using a segmentation parser to derive the mask.

    Either an already-constructed ``model`` (pipeline) or a ``model_path``
    to load one from must be supplied.
    """

    def __init__(self, model_path=None, model=None, parser=None):
        self.device = get_device()
        if model_path is None and model is None:
            raise ValueError('No model provided!')
        if model_path is not None:
            # Load a fresh pipeline from disk/hub; mirror init()'s
            # device-dependent precision choices.
            self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                model_path,
                safety_checker=None,
                revision="fp16" if self.device == "cuda" else None,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            ).to(self.device)
        else:
            self.pipe = model
        self.parser = parser

    def make_square(self, im, min_size=256, fill_color=(0, 0, 0, 0)):
        """Pad `im` to a centered square of side max(min_size, w, h).

        Padding uses `fill_color`; the result is converted to RGB.
        """
        x, y = im.size
        size = max(min_size, x, y)
        new_im = PIL.Image.new('RGBA', (size, size), fill_color)
        new_im.paste(im, (int((size - x) / 2), int((size - y) / 2)))
        return new_im.convert('RGB')

    def unmake_square(self, init_im, op_im, min_size=256, rs_size=512):
        """Crop the model output `op_im` back to `init_im`'s aspect region.

        Inverts make_square() + the resize to `rs_size`: the padded square
        had side max(min_size, w, h), so scale the original extents by
        rs_size/side and crop the centered window.
        """
        x, y = init_im.size
        size = max(min_size, x, y)
        factor = rs_size / size
        return op_im.crop((
            int((size - x) * factor / 2),
            int((size - y) * factor / 2),
            int((size + x) * factor / 2),
            int((size + y) * factor / 2),
        ))

    def inpaint(self, prompt, init_image, parser=None) -> list:
        """Generate inpainted variants of `init_image` guided by `prompt`.

        Args:
            prompt: dict with key 'pos' holding the positive text prompt.
            init_image: PIL image to edit.
            parser: optional fallback mask parser when the instance has none.

        Returns:
            List of PIL images (one per sample), cropped back to the
            original aspect ratio.

        Raises:
            ValueError: if neither self.parser nor `parser` is available.
        """
        image = self.make_square(init_image).resize((512, 512))

        # Prefer the instance parser; fall back to the call-site one.
        active_parser = self.parser if self.parser is not None else parser
        if active_parser is None:
            raise ValueError('Image Parser is Missing')
        mask = active_parser.get_image_mask(image).resize((512, 512))
        logger.info(f'[generated required mask(s) at {time.time()}]')

        # Run the model
        guidance_scale = 7.5
        num_samples = 3
        # BUGFIX: CUDA autocast must not be entered on a CPU-only run;
        # use a no-op context there.
        amp_ctx = autocast("cuda") if self.device == "cuda" else nullcontext()
        with amp_ctx, torch.inference_mode():
            images = self.pipe(
                num_inference_steps=50,
                prompt=prompt['pos'],
                image=image,
                mask_image=mask,
                guidance_scale=guidance_scale,
                num_images_per_prompt=num_samples,
            ).images

        # Composite each sample over the original (only masked pixels change),
        # then crop away the square padding.
        images_output = []
        for img in images:
            ch = PIL.Image.composite(img, image, mask.convert('L'))
            fin_img = self.unmake_square(init_image, ch)
            images_output.append(fin_img)
        return images_output


def process_image(prompt, image):
    """Gradio handler: validate inputs and run the global inpainter.

    Args:
        prompt: text description of the desired clothing.
        image: uploaded PIL image (or None).

    Returns:
        List of generated PIL images for the gallery.

    Raises:
        gr.Error: on missing input or any processing failure.
    """
    start_time = time.time()
    logger.info(f"Processing new request - Prompt: {prompt}, Image size: {image.size if image else 'None'}")
    try:
        if image is None:
            logger.error("No image provided")
            raise gr.Error("Please upload an image")
        if not prompt:
            logger.error("No prompt provided")
            raise gr.Error("Please enter a prompt")

        prompt_dict = {'pos': prompt}
        logger.info("Starting inpainting process")
        images = inpainter.inpaint(prompt_dict, image)
        if not images:
            logger.error("Inpainting failed to produce results")
            raise gr.Error("Failed to generate images. Please try again.")

        logger.info(f"Request processed in {time.time() - start_time:.2f} seconds")
        return images
    except gr.Error:
        # BUGFIX: don't re-wrap our own user-facing validation errors into
        # "Error processing image: ..." — surface them as-is.
        raise
    except Exception as e:
        logger.exception("Error processing image")
        raise gr.Error(f"Error processing image: {str(e)}")


# Initialize the model
init()

# Create Gradio interface
with gr.Blocks(title="Looks.Studio - AI Clothing Inpainting") as demo:
    gr.Markdown("# Looks.Studio - AI Clothing Inpainting")
    gr.Markdown("Upload an image and describe the clothing you want to generate")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="Upload Image",
                height=512
            )
            prompt = gr.Textbox(label="Describe the clothing you want to generate")
            generate_btn = gr.Button("Generate")
        with gr.Column():
            gallery = gr.Gallery(
                label="Generated Images",
                show_label=False,
                columns=2,
                height=512
            )
    generate_btn.click(
        fn=process_image,
        inputs=[prompt, input_image],
        outputs=gallery
    )

if __name__ == "__main__":
    demo.launch(share=True)