theoracle's picture
πŸ› Fix: ensure DeepFashion2 model is downloaded before loading
6bf925b
raw
history blame
4.97 kB
import os
import traceback
from datetime import datetime
import torch, gc
from PIL import Image
import gradio as gr
from inference import generate_with_lora
from background_edit import run_background_removal_and_inpaint, run_clothing_inpaint
# Ensure the DeepFashion2 YOLOv8 segmentation weights exist on disk before they
# are needed. The file is first downloaded to a temporary ".part" path and only
# renamed to MODEL_PATH once the download has completed. This way an interrupted
# download cannot leave a truncated file at MODEL_PATH, which would make the
# existence check below pass (and skip the download) on every later start.
# NOTE(review): this runs after `background_edit` has been imported above; if
# that module loads the model at import time, move this block above that import.
MODEL_URL = "https://huggingface.co/Bingsu/adetailer/resolve/main/deepfashion2_yolov8s-seg.pt"
MODEL_PATH = "deepfashion2_yolov8s-seg.pt"
if not os.path.exists(MODEL_PATH):
    import urllib.request

    print("[INFO] Downloading DeepFashion2 YOLOv8 model...", flush=True)
    _partial_model_path = MODEL_PATH + ".part"
    try:
        urllib.request.urlretrieve(MODEL_URL, _partial_model_path)
        # Atomic rename on the same filesystem: MODEL_PATH is either absent
        # or complete, never half-written.
        os.replace(_partial_model_path, MODEL_PATH)
    finally:
        # After a successful rename this is a no-op; after a failure it removes
        # the incomplete download so the next start retries cleanly.
        if os.path.exists(_partial_model_path):
            os.remove(_partial_model_path)
    print("[INFO] Model downloaded.", flush=True)
# ─────────────── Helpers ───────────────
def _print_trace():
traceback.print_exc()
def unload_models():
    """Free memory left behind by the previous pipeline step.

    Python's garbage collector runs first, so objects that are only kept alive
    by reference cycles (including any tensors they hold) are actually released.
    Only after that is PyTorch asked to return its unoccupied cached CUDA
    blocks. In the reverse order, ``empty_cache`` would run while those tensors
    still occupied their blocks, and the memory would stay cached.

    Despite the name, this does not unload models that are still referenced
    somewhere; it only frees memory that is no longer in use.
    """
    gc.collect()
    torch.cuda.empty_cache()
def safe_generate_all_steps(
    image,
    prompt_1, neg_1, strength_1, guidance_1,
    prompt_2, neg_2, guidance_2,
    prompt_3, neg_3, guidance_3
):
    """Run the three-stage headshot pipeline, reporting errors instead of raising.

    Stages:
      1. ``generate_with_lora`` refines the uploaded headshot.
      2. ``run_background_removal_and_inpaint`` replaces the background. It
         takes a file path, so the step-1 result is first saved as a PNG under
         ``./outputs``.
      3. ``run_clothing_inpaint`` replaces the clothing in the step-2 result.

    Args:
        image: The image from the ``gr.Image(type="pil")`` upload, or ``None``
            if nothing was uploaded.
        prompt_1, neg_1, strength_1, guidance_1: Prompt, negative prompt,
            strength and guidance scale for step 1.
        prompt_2, neg_2, guidance_2: Prompt, negative prompt and guidance scale
            for step 2.
        prompt_3, neg_3, guidance_3: Prompt, negative prompt and guidance scale
            for step 3.

    Returns:
        A 4-tuple ``(refined, with_bg, final, message)`` matching the UI outputs.
        On success, ``message`` is ``""``. If step 3 returns a truthy error,
        ``final`` is ``None`` and ``message`` is that error. If any step raises,
        all three images are ``None`` and ``message`` describes the exception:
        prefixed with 🛑 for ``gr.Error`` (user-facing problems), or with ❌
        plus the exception type for anything else, whose traceback is also
        printed.
    """
    try:
        if image is None:
            raise gr.Error("Please upload an image first.")

        # Step 1: Headshot Refinement
        print("[INFO] Step 1: Refining headshot...", flush=True)
        refined = generate_with_lora(
            image=image,
            prompt=prompt_1,
            negative_prompt=neg_1,
            strength=strength_1,
            guidance_scale=guidance_1,
        )

        # Save the intermediate result to disk, because step 2 reads from a path.
        # Microseconds are part of the timestamp so that two runs started within
        # the same second do not overwrite (and then read) each other's file.
        os.makedirs("./outputs", exist_ok=True)
        ts = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        path = os.path.join("./outputs", f"step1_result_{ts}.png")
        refined.save(path)

        # Step 2: Background Inpainting
        print("[INFO] Step 2: Inpainting background...", flush=True)
        unload_models()
        with_bg = run_background_removal_and_inpaint(
            image_path=path,
            prompt=prompt_2,
            negative_prompt=neg_2,
            guidance_scale=guidance_2,
        )

        # Step 3: Clothing Inpainting (returns an (image, error) pair rather
        # than raising).
        print("[INFO] Step 3: Inpainting clothing...", flush=True)
        final, err = run_clothing_inpaint(
            with_bg,
            prompt_3,
            neg_3,
            guidance_3,
        )
        if err:
            # Keep the partial results from steps 1 and 2 visible in the UI.
            return refined, with_bg, None, err
        return refined, with_bg, final, ""
    except gr.Error as e:
        # Expected, user-facing problems (e.g. no image uploaded).
        return None, None, None, f"🛑 {str(e)}"
    except Exception as e:
        # Top-level UI boundary: log the full traceback, show a short message.
        _print_trace()
        return None, None, None, f"❌ Unexpected Error: {type(e).__name__}: {str(e)}"
# ─────────────── Gradio UI ───────────────
# Layout: one section of controls per pipeline step, a single run button, one
# output image per step, and a Markdown area for status/error messages.
with gr.Blocks() as demo:
    gr.Markdown(value="## 🧠 Full Headshot + Background + Clothing Generator (One Click)")

    with gr.Row():
        input_image = gr.Image(label="Upload Headshot", type="pil")

    # Step 1 controls: LoRA headshot refinement.
    gr.Markdown(value="### Step 1: Headshot Refinement (LoRA)")
    with gr.Row():
        prompt_1 = gr.Textbox(
            value="a professional headshot of a confident woman in her 30s with blonde hair",
            label="Headshot Prompt",
        )
        neg_1 = gr.Textbox(
            value="deformed, cartoon, anime, sketch, blurry, low quality",
            label="Headshot Negative Prompt",
        )
    with gr.Row():
        strength_1 = gr.Slider(
            minimum=0.1, maximum=1.0, value=0.2, step=0.05,
            label="Refinement Strength",
        )
        guidance_1 = gr.Slider(
            minimum=1, maximum=20, value=17, step=0.5,
            label="Guidance Scale (Headshot)",
        )

    # Step 2 controls: background replacement.
    gr.Markdown(value="### Step 2: Background Inpainting (SDXL)")
    with gr.Row():
        prompt_2 = gr.Textbox(
            value="modern hospital background, clean, soft lighting",
            label="Background Prompt",
        )
        neg_2 = gr.Textbox(
            value="fantasy, cartoon, cluttered, sketch",
            label="Background Negative Prompt",
        )
    with gr.Row():
        guidance_2 = gr.Slider(
            minimum=1, maximum=20, value=10, step=0.5,
            label="Guidance Scale (Background)",
        )

    # Step 3 controls: clothing replacement.
    gr.Markdown(value="### Step 3: Clothing Replacement")
    with gr.Row():
        prompt_3 = gr.Textbox(
            value="white female CEO professional blazer, clean look",
            label="Clothing Prompt",
        )
        neg_3 = gr.Textbox(
            value="hoodie, casual wear, fantasy, cartoon, jeans, distorted, blurry",
            label="Clothing Negative Prompt",
        )
    with gr.Row():
        guidance_3 = gr.Slider(
            minimum=1, maximum=20, value=17.0, step=0.5,
            label="Clothing Guidance Scale",
        )

    go_btn = gr.Button(value="✨ Run Full Pipeline (All 3 Steps)")

    with gr.Row():
        output_refined = gr.Image(label="Step 1: Refined Headshot", type="pil")
        output_bg = gr.Image(label="Step 2: With New Background", type="pil")
        output_final = gr.Image(label="Step 3: Final with New Clothing", type="pil")
    error_box = gr.Markdown(value="", label="Error", visible=True)

    # The input order must match safe_generate_all_steps' positional parameters,
    # and the output order must match the 4-tuple it returns.
    pipeline_inputs = [
        input_image,
        prompt_1, neg_1, strength_1, guidance_1,
        prompt_2, neg_2, guidance_2,
        prompt_3, neg_3, guidance_3,
    ]
    pipeline_outputs = [output_refined, output_bg, output_final, error_box]
    go_btn.click(safe_generate_all_steps, inputs=pipeline_inputs, outputs=pipeline_outputs)
# Start the server only when this file is executed as a script
# (`python <this file>`). Importing the module, as Gradio's reload mode does
# when it looks up `demo` itself, no longer launches a second, blocking server.
if __name__ == "__main__":
    demo.launch(debug=True)