File size: 5,476 Bytes
eb20c24
efb991e
6bf925b
 
 
 
 
 
 
 
efb991e
71f1758
 
 
 
 
 
 
 
 
 
 
 
6bf925b
10bd531
 
 
 
 
 
 
6bf925b
93be815
 
6bf925b
 
93be815
10bd531
2c740bc
 
 
93be815
 
2c740bc
 
93be815
 
 
 
2c740bc
10bd531
93be815
2c740bc
 
 
 
10bd531
93be815
2c740bc
 
6bf925b
93be815
 
 
 
 
10bd531
6bf925b
 
 
 
 
 
 
 
 
 
 
79e7646
2c740bc
6bf925b
2c740bc
 
6bf925b
eb20c24
6bf925b
79e7646
6bf925b
2c740bc
 
 
94d7430
93be815
2c740bc
ab47115
 
93be815
 
 
 
 
 
ab47115
 
2c740bc
93be815
2c740bc
6bf925b
 
ab47115
 
6bf925b
 
 
 
2c740bc
 
 
6bf925b
 
2c740bc
 
94d7430
2c740bc
6bf925b
93be815
6bf925b
 
 
 
93be815
6bf925b
eb20c24
 
2c740bc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os

# Ensure the DeepFashion2 YOLOv8 segmentation weights exist on disk before the
# heavier imports below run. NOTE(review): presumably background_edit loads
# this file from MODEL_PATH (relative to the working directory) — confirm.
MODEL_URL = "https://huggingface.co/Bingsu/adetailer/resolve/main/deepfashion2_yolov8s-seg.pt"
MODEL_PATH = "deepfashion2_yolov8s-seg.pt"
if not os.path.exists(MODEL_PATH):
    import urllib.request

    print("[INFO] Downloading DeepFashion2 YOLOv8 model...")
    # Download to a sibling temp file and atomically rename it into place.
    # Writing straight to MODEL_PATH would let an interrupted/failed download
    # leave a truncated file that the existence check above would then accept
    # on every subsequent start.
    _partial_path = MODEL_PATH + ".part"
    try:
        urllib.request.urlretrieve(MODEL_URL, _partial_path)
        os.replace(_partial_path, MODEL_PATH)
    finally:
        # Only still present if the download or rename failed.
        if os.path.exists(_partial_path):
            os.remove(_partial_path)
    print("[INFO] Model downloaded.")



import traceback
from datetime import datetime
import torch, gc
from PIL import Image
import gradio as gr

from inference import generate_with_lora
from background_edit import run_background_removal_and_inpaint, run_clothing_inpaint


# ─────────────── Helpers ───────────────
def _print_trace():
    traceback.print_exc()

def unload_models():
    """Free unreachable Python objects, then release PyTorch's unused cached CUDA memory.

    Despite the name, this does not unload any model: anything still
    referenced keeps its GPU memory. It only (1) runs the cyclic garbage
    collector so that objects which are no longer reachable (and any tensors
    they own) are actually destroyed, and then (2) returns the now-unused
    blocks held by PyTorch's CUDA caching allocator to the driver.
    """
    # Collect first: empty_cache() can only release blocks that are already
    # free, so tensors kept alive solely by reference cycles must be
    # destroyed before the cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()

def safe_generate_all_steps(
    image,
    prompt_1, neg_1, strength_1, guidance_1,
    prompt_2, neg_2, guidance_2,
    prompt_3, neg_3, guidance_3
):
    """Run the three-stage pipeline: headshot refine -> background -> clothing.

    Used as the Gradio click handler, so it never raises. It always returns a
    4-tuple ``(refined, with_bg, final, message)`` matching the UI outputs
    ``[output_refined, output_bg, output_final, error_box]``. On failure,
    images from stages that already finished are still returned and
    ``message`` carries a user-facing error string; on success it is "".

    Args:
        image: Uploaded image from ``gr.Image(type="pil")``, or None if nothing
            was uploaded.
        prompt_1, neg_1, strength_1, guidance_1: Step 1 (LoRA headshot
            refinement) prompt, negative prompt, strength and guidance scale.
        prompt_2, neg_2, guidance_2: Step 2 (background inpainting) settings.
        prompt_3, neg_3, guidance_3: Step 3 (clothing inpainting) settings.
    """
    # Keep stage outputs outside the try so a later failure can still show
    # the results of the stages that succeeded instead of blanking them all.
    refined = None
    with_bg = None
    try:
        if image is None:
            raise gr.Error("Please upload an image first.")

        # Step 1: Headshot Refinement
        print("[INFO] Step 1: Refining headshot...", flush=True)
        refined = generate_with_lora(
            image=image,
            prompt=prompt_1,
            negative_prompt=neg_1,
            strength=strength_1,
            guidance_scale=guidance_1,
        )

        # Step 2 takes a file path rather than an image, so persist step 1.
        # Microsecond resolution keeps two runs started in the same second
        # from overwriting (and then reading) each other's intermediate file.
        os.makedirs("./outputs", exist_ok=True)
        ts = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        path = f"./outputs/step1_result_{ts}.png"
        refined.save(path)

        # Step 2: Background Inpainting
        print("[INFO] Step 2: Inpainting background...", flush=True)
        unload_models()
        with_bg = run_background_removal_and_inpaint(
            image_path=path,
            prompt=prompt_2,
            negative_prompt=neg_2,
            guidance_scale=guidance_2
        )

        # Step 3: Clothing Inpainting
        print("[INFO] Step 3: Inpainting clothing...", flush=True)
        # Free cached GPU memory between heavy stages, as done before step 2.
        unload_models()
        # run_clothing_inpaint reports recoverable failures via a message
        # rather than an exception; keep the earlier stages' images in that case.
        final, err = run_clothing_inpaint(
            with_bg,
            prompt_3,
            neg_3,
            guidance_3
        )
        if err:
            return refined, with_bg, None, err
        return refined, with_bg, final, ""

    except gr.Error as e:
        # Expected, user-facing validation errors: no traceback needed.
        return refined, with_bg, None, f"🛑 {e}"
    except Exception as e:
        _print_trace()
        return refined, with_bg, None, f"❌ Unexpected Error: {type(e).__name__}: {e}"

# ─────────────── Gradio UI ───────────────
# ─────────────── Gradio UI ───────────────
# Layout: one input image, three groups of per-step controls, a single button
# that runs all steps, and one output image per step plus a message area.
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Full Headshot + Background + Clothing Generator (One Click)")

    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload Headshot")

    # Step 1 controls → generate_with_lora
    gr.Markdown("### Step 1: Headshot Refinement (LoRA)")
    with gr.Row():
        prompt_1 = gr.Textbox(label="Headshot Prompt", value="a professional corporate headshot of a confident woman in her 30s with blow dried hair, natural smile, soft lighting, clean studio background, realistic photo, high detail, shallow depth of field")
        neg_1 = gr.Textbox(label="Headshot Negative Prompt", value="cartoon, anime, painting, illustration, low quality, overexposed, distorted face, exaggerated features, blurry background")
    with gr.Row():
        strength_1 = gr.Slider(0.1, 1.0, value=0.2, step=0.05, label="Refinement Strength")
        guidance_1 = gr.Slider(1, 20, value=17, step=0.5, label="Guidance Scale (Headshot)")

    # Step 2 controls → run_background_removal_and_inpaint
    gr.Markdown("### Step 2: Background Inpainting (SDXL)")
    with gr.Row():
        prompt_2 = gr.Textbox(label="Background Prompt", value="modern startup office, open-plan layout, natural daylight, glass walls, minimalistic decor, desks with computers, warm soft lighting, realistic environment")
        neg_2 = gr.Textbox(label="Background Negative Prompt", value="cluttered space, fantasy architecture, cartoon, low-res textures, empty background, distorted shapes, harsh shadows")
    with gr.Row():
        guidance_2 = gr.Slider(1, 20, value=10, step=0.5, label="Guidance Scale (Background)")

    # Step 3 controls → run_clothing_inpaint
    gr.Markdown("### Step 3: Clothing Replacement")
    with gr.Row():
        prompt_3 = gr.Textbox(label="Clothing Prompt", value="tailored women's business suit, white blouse, blazer and pencil skirt, elegant corporate style, modern, professional lighting")
        neg_3 = gr.Textbox(label="Clothing Negative Prompt", value="casual clothes, hoodie, jeans, fantasy outfit, cartoon, distorted textures, glitch, unrealistic proportions")
    with gr.Row():
        guidance_3 = gr.Slider(1, 20, value=17.0, step=0.5, label="Clothing Guidance Scale")

    go_btn = gr.Button("✨ Run Full Pipeline (All 3 Steps)")

    with gr.Row():
        output_refined = gr.Image(type="pil", label="Step 1: Refined Headshot")
        output_bg = gr.Image(type="pil", label="Step 2: With New Background")
        output_final = gr.Image(type="pil", label="Step 3: Final with New Clothing")

    # Receives the 4th return value of safe_generate_all_steps ("" on success).
    error_box = gr.Markdown(label="Error", value="", visible=True)

    # Input order must match safe_generate_all_steps' positional parameters,
    # and output order must match its 4-tuple return value.
    go_btn.click(
        fn=safe_generate_all_steps,
        inputs=[
            input_image,
            prompt_1, neg_1, strength_1, guidance_1,
            prompt_2, neg_2, guidance_2,
            prompt_3, neg_3, guidance_3
        ],
        outputs=[output_refined, output_bg, output_final, error_box]
    )

demo.launch(debug=True)