Spaces:
Running
Running
import torch | |
from torch import autocast | |
from diffusers import StableDiffusionInpaintPipeline | |
import gradio as gr | |
import traceback | |
import base64 | |
from io import BytesIO | |
import os | |
import PIL | |
import json | |
import requests | |
import logging | |
import time | |
import warnings | |
# Suppress every warning category for the lifetime of the process.
warnings.filterwarnings("ignore")

# Root logging setup: INFO and above, timestamped, with logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('looks.studio')

# Identifiers handed to the model loaders in init().
SEGFORMER_MODEL = "mattmdjaga/segformer_b2_clothes"
STABLE_DIFFUSION_MODEL = "stabilityai/stable-diffusion-2-inpainting"

# Module-level handles; they stay None until init() assigns them.
parser = model = inpainter = None
def get_device():
    """Choose the torch device string for this process.

    Returns:
        "cuda" when ``torch.cuda.is_available()`` is true, otherwise "cpu".
        The choice is logged at INFO level on every call.
    """
    if not torch.cuda.is_available():
        logger.info("Using CPU")
        return "cpu"
    logger.info("Using GPU")
    return "cuda"
def init():
    """Load the models and build the module-level ``inpainter``.

    Assigns the globals ``parser`` (a ``SegformerParser`` built from
    ``SEGFORMER_MODEL``), ``model`` (a ``StableDiffusionInpaintPipeline``
    loaded from ``STABLE_DIFFUSION_MODEL`` and moved to the selected device)
    and ``inpainter`` (a ``ClothingInpainter`` wrapping both). Half precision
    is requested only when the device is "cuda".

    Raises:
        Exception: whatever the parser import or a model load raises; it is
            logged together with its traceback and re-raised unchanged.
    """
    global parser
    global model
    global inpainter

    start_time = time.time()
    logger.info("Starting application initialization")
    try:
        device = get_device()
        on_gpu = device == "cuda"

        # Initialize Segformer parser
        logger.info("Initializing Segformer parser...")
        from parser.segformer_parser import SegformerParser
        parser = SegformerParser(SEGFORMER_MODEL)

        # Initialize Stable Diffusion model
        logger.info("Initializing Stable Diffusion model...")
        # NOTE(review): recent diffusers releases select half-precision weights
        # with variant="fp16" rather than revision="fp16" - confirm against the
        # pinned diffusers version before changing this.
        model = StableDiffusionInpaintPipeline.from_pretrained(
            STABLE_DIFFUSION_MODEL,
            safety_checker=None,
            revision="fp16" if on_gpu else None,
            torch_dtype=torch.float16 if on_gpu else torch.float32,
        ).to(device)

        # Initialize inpainter
        logger.info("Initializing inpainter...")
        inpainter = ClothingInpainter(model=model, parser=parser)

        logger.info(f"Application initialized in {time.time() - start_time:.2f} seconds")
    except Exception as e:
        # logger.exception keeps the same message but also records the
        # traceback; the bare raise re-raises the active exception as-is.
        logger.exception("Error initializing application: %s", e)
        raise
class ClothingInpainter:
    """Inpaint a masked region of an image with a Stable Diffusion pipeline.

    The mask comes from a parser object exposing ``get_image_mask(image)``
    (presumably returning a PIL image - confirm against the parser class).
    Input images are padded to a square, resized to 512x512 for the pipeline,
    and the results are cropped back to the original aspect ratio.
    """

    def __init__(self, model_path=None, model=None, parser=None):
        """Wrap an inpainting pipeline and an optional mask parser.

        Args:
            model_path: identifier passed to
                ``StableDiffusionInpaintPipeline.from_pretrained``; takes
                precedence over ``model`` when both are given.
            model: an already constructed pipeline, used when ``model_path``
                is None.
            parser: object with ``get_image_mask``; may be None and supplied
                per call to ``inpaint`` instead.

        Raises:
            ValueError: if neither ``model_path`` nor ``model`` is given.
        """
        self.device = get_device()
        if model_path is None and model is None:
            raise ValueError('No model provided!')
        if model_path is not None:
            on_gpu = self.device == "cuda"
            self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                model_path,
                safety_checker=None,
                revision="fp16" if on_gpu else None,
                torch_dtype=torch.float16 if on_gpu else torch.float32,
            ).to(self.device)
        else:
            self.pipe = model
        self.parser = parser

    def make_square(self, im, min_size=256, fill_color=(0, 0, 0, 0)):
        """Return ``im`` centred on a square RGB canvas.

        The canvas side is ``max(min_size, width, height)``; the padding is
        filled with ``fill_color`` before the final conversion to RGB.
        """
        # A bare `import PIL` does not load the Image submodule by itself.
        from PIL import Image

        x, y = im.size
        size = max(min_size, x, y)
        new_im = Image.new('RGBA', (size, size), fill_color)
        new_im.paste(im, (int((size - x) / 2), int((size - y) / 2)))
        return new_im.convert('RGB')

    def unmake_square(self, init_im, op_im, min_size=256, rs_size=512):
        """Crop ``op_im`` back to the area ``init_im`` occupied in the square.

        Args:
            init_im: the original image; only its ``size`` is read.
            op_im: the square output image, ``rs_size`` pixels per side.
            min_size: must match the value given to ``make_square``.
            rs_size: side length the padded square was resized to.

        Returns:
            The result of ``op_im.crop`` on the centred box that corresponds
            to the original image, scaled by ``rs_size / padded_side``.
        """
        x, y = init_im.size
        size = max(min_size, x, y)
        factor = rs_size / size
        box = (
            int((size - x) * factor / 2),
            int((size - y) * factor / 2),
            int((size + x) * factor / 2),
            int((size + y) * factor / 2),
        )
        return op_im.crop(box)

    def inpaint(self, prompt, init_image, parser=None, *, num_samples=3,
                guidance_scale=7.5, num_inference_steps=50) -> list:
        """Generate inpainted variants of ``init_image``.

        Args:
            prompt: mapping whose ``'pos'`` entry is the text prompt.
            init_image: source image (PIL); it is not modified.
            parser: fallback mask parser, used only when the instance was
                built without one.
            num_samples: number of images requested from the pipeline.
            guidance_scale: guidance scale forwarded to the pipeline.
            num_inference_steps: number of denoising steps.

        Returns:
            A list of images, one per sample, each cropped back to the
            original aspect ratio.

        Raises:
            ValueError: if no parser is available on the instance or call.
        """
        # Fail fast: without a parser no mask can be produced.
        mask_source = self.parser if self.parser is not None else parser
        if mask_source is None:
            raise ValueError('Image Parser is Missing')

        # A bare `import PIL` does not load the Image submodule by itself.
        from PIL import Image

        side = 512
        image = self.make_square(init_image).resize((side, side))
        mask = mask_source.get_image_mask(image).resize((side, side))
        logger.info(f'[generated required mask(s) at {time.time()}]')

        # Run the model; CUDA autocast is enabled only when running on CUDA.
        use_cuda = self.device == "cuda"
        with autocast("cuda", enabled=use_cuda), torch.inference_mode():
            images = self.pipe(
                num_inference_steps=num_inference_steps,
                prompt=prompt['pos'],
                image=image,
                mask_image=mask,
                guidance_scale=guidance_scale,
                num_images_per_prompt=num_samples,
            ).images

        # Keep original pixels outside the mask, then undo the square padding.
        blend_mask = mask.convert('L')
        return [
            self.unmake_square(
                init_image,
                Image.composite(img, image, blend_mask),
                rs_size=side,
            )
            for img in images
        ]
def process_image(prompt, image):
    """Gradio click handler: validate the inputs and run the inpainter.

    Args:
        prompt: text entered by the user.
        image: uploaded image, or None when nothing was uploaded.

    Returns:
        The list of images produced by ``inpainter.inpaint``.

    Raises:
        gr.Error: with a plain validation message when the image or prompt is
            missing, when the inpainter returns no images, or wrapping any
            exception raised during inpainting.
    """
    start_time = time.time()
    logger.info(f"Processing new request - Prompt: {prompt}, Image size: {image.size if image is not None else 'None'}")

    # Validation runs outside the try block so these messages reach the user
    # as written instead of being re-wrapped as "Error processing image: ...".
    if image is None:
        logger.error("No image provided")
        raise gr.Error("Please upload an image")
    if not prompt:
        logger.error("No prompt provided")
        raise gr.Error("Please enter a prompt")

    prompt_dict = {'pos': prompt}
    logger.info("Starting inpainting process")
    try:
        images = inpainter.inpaint(prompt_dict, image)
    except Exception as e:
        # Record the traceback and keep the original error as the cause.
        logger.exception("Error processing image: %s", e)
        raise gr.Error(f"Error processing image: {str(e)}") from e

    if not images:
        logger.error("Inpainting failed to produce results")
        raise gr.Error("Failed to generate images. Please try again.")

    logger.info(f"Request processed in {time.time() - start_time:.2f} seconds")
    return images
# Load the models before the interface is built.
init()

# Gradio interface: one row with an input column and an output column.
with gr.Blocks(title="Looks.Studio - AI Clothing Inpainting") as demo:
    gr.Markdown("# Looks.Studio - AI Clothing Inpainting")
    gr.Markdown("Upload an image and describe the clothing you want to generate")

    with gr.Row():
        # Input column: source image, prompt text and the trigger button.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image", height=512)
            prompt = gr.Textbox(label="Describe the clothing you want to generate")
            generate_btn = gr.Button("Generate")

        # Output column: gallery that receives the generated images.
        with gr.Column():
            gallery = gr.Gallery(
                label="Generated Images",
                show_label=False,
                columns=2,
                height=512,
            )

    # Clicking the button sends (prompt, image) to process_image.
    generate_btn.click(
        fn=process_image,
        inputs=[prompt, input_image],
        outputs=gallery,
    )

if __name__ == "__main__":
    demo.launch(share=True)