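"""FitGen: Fashion Studio & Virtual Try-on.

A Gradio app that generates fashion-model and clothing images with FLUX.1-dev
LoRAs, and performs virtual try-on and pose transfer with Leffa (checkpoints
downloaded from franciszzj/Leffa).
"""
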
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download, login
from leffa.transform import LeffaTransform
from leffa.model import LeffaModel
from leffa.inference import LeffaInference
from utils.garment_agnostic_mask_predictor import AutoMasker
from utils.densepose_predictor import DensePosePredictor
from utils.utils import resize_and_center
import spaces
import torch
from diffusers import DiffusionPipeline
from transformers import pipeline
import gradio as gr
import os
import random
import gc

# Constants
MAX_SEED = 2**32 - 1
BASE_MODEL = "black-forest-labs/FLUX.1-dev"
MODEL_LORA_REPO = "Motas/Flux_Fashion_Photography_Style"
CLOTHES_LORA_REPO = "prithivMLmods/Canopus-Clothing-Flux-LoRA"

# Memory-management decorator
def safe_model_call(func):
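    """Clear GPU memory before and after the wrapped call; log and re-raise errors."""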
    def wrapper(*args, **kwargs):
        try:
            clear_memory()
            result = func(*args, **kwargs)
            clear_memory()
            return result
        except Exception as e:
            clear_memory()
            print(f"Error in {func.__name__}: {str(e)}")
            raise
    return wrapper

# Memory-management helper
def clear_memory():
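    """Empty the CUDA cache and run Python garbage collection."""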
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()

def setup_environment():
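    """Configure CUDA/memory settings, read HF_TOKEN, and log in to the Hugging Face Hub."""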
    # Memory-management settings
    torch.cuda.empty_cache()
    gc.collect()
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    # Hugging Face token setup
    global HF_TOKEN
    HF_TOKEN = os.getenv("HF_TOKEN")
    if HF_TOKEN is None:
        raise ValueError("Please set the HF_TOKEN environment variable")
    login(token=HF_TOKEN)

    # CUDA setup
    global device
    device = "cuda" if torch.cuda.is_available() else "cpu"

# Global state
fashion_pipe = None
translator = None
mask_predictor = None
densepose_predictor = None
vt_model = None
pt_model = None
vt_inference = None
pt_inference = None
device = None
HF_TOKEN = None

# Run environment setup
setup_environment()

# Model-management functions
def initialize_fashion_pipe():
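    """Lazily load the FLUX.1-dev base pipeline with memory-saving options."""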
    global fashion_pipe
    if fashion_pipe is None:
        clear_memory()
        fashion_pipe = DiffusionPipeline.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            token=HF_TOKEN
        )
        try:
            fashion_pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            print(f"Warning: Could not enable memory efficient attention: {e}")
        fashion_pipe.enable_sequential_cpu_offload()
    return fashion_pipe

@safe_model_call
def get_translator():
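    """Lazily load the Korean-to-English translation pipeline (Helsinki-NLP/opus-mt-ko-en)."""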
    global translator
    if translator is None:
        translator = pipeline("translation",
                              model="Helsinki-NLP/opus-mt-ko-en",
                              device=device if device == "cuda" else -1)
    return translator

@safe_model_call
def get_mask_predictor():
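    """Lazily construct the garment-agnostic mask predictor (AutoMasker)."""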
    global mask_predictor
    if mask_predictor is None:
        mask_predictor = AutoMasker(
            densepose_path="./ckpts/densepose",
            schp_path="./ckpts/schp",
        )
    return mask_predictor

@safe_model_call
def get_densepose_predictor():
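    """Lazily construct the DensePose predictor."""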
    global densepose_predictor
    if densepose_predictor is None:
        densepose_predictor = DensePosePredictor(
            config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
            weights_path="./ckpts/densepose/model_final_162be9.pkl",
        )
    return densepose_predictor

@safe_model_call
def get_vt_model():
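    """Lazily load the Leffa virtual try-on model and its inference wrapper."""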
    global vt_model, vt_inference
    if vt_model is None:
        vt_model = LeffaModel(
            pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
            pretrained_model="./ckpts/virtual_tryon.pth"
        )
        vt_model = vt_model.half().to(device)
        vt_inference = LeffaInference(model=vt_model)
    return vt_model, vt_inference

@safe_model_call
def get_pt_model():
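    """Lazily load the Leffa pose-transfer model and its inference wrapper."""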
    global pt_model, pt_inference
    if pt_model is None:
        pt_model = LeffaModel(
            pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
            pretrained_model="./ckpts/pose_transfer.pth"
        )
        pt_model = pt_model.half().to(device)
        pt_inference = LeffaInference(model=pt_model)
    return pt_model, pt_inference

def load_lora(pipe, lora_path):
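    """Unload any existing LoRA weights, then load the adapter at lora_path."""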
    try:
        pipe.unload_lora_weights()
    except Exception:
        pass
    try:
        pipe.load_lora_weights(lora_path)
        return pipe
    except Exception as e:
        print(f"Warning: Failed to load LoRA weights from {lora_path}: {e}")
        return pipe

# Initial setup function
def setup():
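    """Download the Leffa checkpoints and warm up the base pipeline."""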
    # Download Leffa checkpoints
    snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
    # Initialize the base model
    initialize_fashion_pipe()

# Utility functions
def contains_korean(text):
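    """Return True if the text contains any Hangul syllable (가-힣)."""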
    return any(ord('가') <= ord(char) <= ord('힣') for char in text)

# Main feature functions
@spaces.GPU()
@safe_model_call
def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
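    """Generate a fashion-model or clothing image; Korean prompts are translated to English first."""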
    try:
        # Handle Korean prompts
        if contains_korean(prompt):
            translator = get_translator()
            translated = translator(prompt)[0]['translation_text']
            actual_prompt = translated
        else:
            actual_prompt = prompt

        # Get the pipeline
        pipe = initialize_fashion_pipe()

        # LoRA selection
        if mode == "Generate Model":
            pipe = load_lora(pipe, MODEL_LORA_REPO)
            trigger_word = "fashion photography, professional model"
        else:
            pipe = load_lora(pipe, CLOTHES_LORA_REPO)
            trigger_word = "upper clothing, fashion item"

        # Clamp parameters
        width = min(width, 768)
        height = min(height, 768)
        steps = min(steps, 30)

        # Seed handling
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        generator = torch.Generator(device=device).manual_seed(seed)

        # Progress display
        progress(0, "Starting fashion generation...")

        # Image generation
        image = pipe(
            prompt=f"{actual_prompt} {trigger_word}",
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]

        return image, seed
    except Exception as e:
        print(f"Error in generate_fashion: {str(e)}")
        raise

@safe_model_call
def leffa_predict(src_image_path, ref_image_path, control_type):
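    """Shared Leffa pipeline: mask and DensePose prediction, then try-on or pose-transfer inference."""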
    try:
        # Initialize models
        if control_type == "virtual_tryon":
            model, inference = get_vt_model()
        else:
            model, inference = get_pt_model()
        mask_pred = get_mask_predictor()
        dense_pred = get_densepose_predictor()

        # Load and preprocess images
        src_image = Image.open(src_image_path)
        ref_image = Image.open(ref_image_path)
        src_image = resize_and_center(src_image, 768, 1024)
        ref_image = resize_and_center(ref_image, 768, 1024)
        src_image_array = np.array(src_image)
        ref_image_array = np.array(ref_image)

        # Generate the mask
        if control_type == "virtual_tryon":
            src_image = src_image.convert("RGB")
            mask = mask_pred(src_image, "upper")["mask"]
        else:
            mask = Image.fromarray(np.ones_like(src_image_array) * 255)

        # DensePose prediction
        src_image_iuv_array = dense_pred.predict_iuv(src_image_array)
        src_image_seg_array = dense_pred.predict_seg(src_image_array)
        if control_type == "virtual_tryon":
            densepose = Image.fromarray(src_image_seg_array)
        else:
            densepose = Image.fromarray(src_image_iuv_array)

        # Leffa transform and inference
        transform = LeffaTransform()
        data = {
            "src_image": [src_image],
            "ref_image": [ref_image],
            "mask": [mask],
            "densepose": [densepose],
        }
        data = transform(data)
        output = inference(data)
        return np.array(output["generated_image"][0])
    except Exception as e:
        print(f"Error in leffa_predict: {str(e)}")
        raise

@safe_model_call
def leffa_predict_vt(src_image_path, ref_image_path):
    return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")

@safe_model_call
def leffa_predict_pt(src_image_path, ref_image_path):
    return leffa_predict(src_image_path, ref_image_path, "pose_transfer")

# Run initial setup
setup()

# Gradio interface
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as demo:
gr.Markdown("# 🎭 FitGen:Fashion Studio & Virtual Try-on")
with gr.Tabs():
# 패션 생성 탭
# 패션 생성 탭
with gr.Tab("Fashion Generation"):
with gr.Column():
mode = gr.Radio(
choices=["Generate Model", "Generate Clothes"],
label="Generation Mode",
value="Generate Model"
)
# 예제 프롬프트 설정
example_model_prompts = [
"professional fashion model, full body shot, standing pose, natural lighting, studio background, high fashion, elegant pose",
"fashion model portrait, upper body, confident pose, fashion photography, neutral background, professional lighting",
"stylish fashion model, three-quarter view, editorial pose, high-end fashion magazine style, minimal background"
]
example_clothes_prompts = [
"luxury designer sweater, cashmere material, cream color, cable knit pattern, high-end fashion, product photography",
"elegant business blazer, tailored fit, charcoal grey, premium wool fabric, professional wear",
"modern streetwear hoodie, oversized fit, minimalist design, premium cotton, urban style"
]
prompt = gr.TextArea(
label="Fashion Description (한글 또는 영어)",
placeholder="패션 모델이나 의류를 설명하세요..."
)
# 예제 섹션 추가
gr.Examples(
examples=example_model_prompts + example_clothes_prompts,
inputs=prompt,
label="Example Prompts"
)
            with gr.Row():
                with gr.Column():
                    result = gr.Image(label="Generated Result")
                    generate_button = gr.Button("Generate Fashion")
            with gr.Accordion("Advanced Options", open=False):
                with gr.Group():
                    with gr.Row():
                        with gr.Column():
                            cfg_scale = gr.Slider(
                                label="CFG Scale",
                                minimum=1,
                                maximum=20,
                                step=0.5,
                                value=7.0
                            )
                            steps = gr.Slider(
                                label="Steps",
                                minimum=1,
                                maximum=50,  # reduced maximum
                                step=1,
                                value=30
                            )
                            lora_scale = gr.Slider(
                                label="LoRA Scale",
                                minimum=0,
                                maximum=1,
                                step=0.01,
                                value=0.85
                            )
                    with gr.Row():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=1024,  # reduced maximum
                            step=64,
                            value=512
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=1024,  # reduced maximum
                            step=64,
                            value=768
                        )
                    with gr.Row():
                        randomize_seed = gr.Checkbox(
                            True,
                            label="Randomize seed"
                        )
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=42
                        )
        # Virtual try-on tab
        with gr.Tab("Virtual Try-on"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Person Image")
                    vt_src_image = gr.Image(
                        sources=["upload"],
                        type="filepath",
                        label="Person Image",
                        width=512,
                        height=512,
                    )
                    gr.Examples(
                        inputs=vt_src_image,
                        examples_per_page=5,
                        examples=["./ckpts/examples/person1/01350_00.jpg",
                                  "./ckpts/examples/person1/01376_00.jpg",
                                  "./ckpts/examples/person1/01416_00.jpg",
                                  "./ckpts/examples/person1/05976_00.jpg",
                                  "./ckpts/examples/person1/06094_00.jpg"]
                    )
                with gr.Column():
                    gr.Markdown("#### Garment Image")
                    vt_ref_image = gr.Image(
                        sources=["upload"],
                        type="filepath",
                        label="Garment Image",
                        width=512,
                        height=512,
                    )
                    gr.Examples(
                        inputs=vt_ref_image,
                        examples_per_page=5,
                        examples=["./ckpts/examples/garment/01449_00.jpg",
                                  "./ckpts/examples/garment/01486_00.jpg",
                                  "./ckpts/examples/garment/01853_00.jpg",
                                  "./ckpts/examples/garment/02070_00.jpg",
                                  "./ckpts/examples/garment/03553_00.jpg"]
                    )
                with gr.Column():
                    gr.Markdown("#### Generated Image")
                    vt_gen_image = gr.Image(
                        label="Generated Image",
                        width=512,
                        height=512,
                    )
                    vt_gen_button = gr.Button("Try-on")
        # Pose transfer tab
        with gr.Tab("Pose Transfer"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Person Image")
                    pt_ref_image = gr.Image(
                        sources=["upload"],
                        type="filepath",
                        label="Person Image",
                        width=512,
                        height=512,
                    )
                    gr.Examples(
                        inputs=pt_ref_image,
                        examples_per_page=5,
                        examples=["./ckpts/examples/person1/01350_00.jpg",
                                  "./ckpts/examples/person1/01376_00.jpg",
                                  "./ckpts/examples/person1/01416_00.jpg",
                                  "./ckpts/examples/person1/05976_00.jpg",
                                  "./ckpts/examples/person1/06094_00.jpg"]
                    )
                with gr.Column():
                    gr.Markdown("#### Target Pose Person Image")
                    pt_src_image = gr.Image(
                        sources=["upload"],
                        type="filepath",
                        label="Target Pose Person Image",
                        width=512,
                        height=512,
                    )
                    gr.Examples(
                        inputs=pt_src_image,
                        examples_per_page=5,
                        examples=["./ckpts/examples/person2/01850_00.jpg",
                                  "./ckpts/examples/person2/01875_00.jpg",
                                  "./ckpts/examples/person2/02532_00.jpg",
                                  "./ckpts/examples/person2/02902_00.jpg",
                                  "./ckpts/examples/person2/05346_00.jpg"]
                    )
                with gr.Column():
                    gr.Markdown("#### Generated Image")
                    pt_gen_image = gr.Image(
                        label="Generated Image",
                        width=512,
                        height=512,
                    )
                    pose_transfer_gen_button = gr.Button("Generate")

    # Event handlers
    generate_button.click(
        generate_fashion,
        inputs=[prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale],
        outputs=[result, seed]
    )
    vt_gen_button.click(
        fn=leffa_predict_vt,
        inputs=[vt_src_image, vt_ref_image],
        outputs=[vt_gen_image]
    )
    pose_transfer_gen_button.click(
        fn=leffa_predict_pt,
        inputs=[pt_src_image, pt_ref_image],
        outputs=[pt_gen_image]
    )

# Launch the app
demo.launch(share=True, server_port=7860)