File size: 3,273 Bytes
2c740bc
94d7430
10bd531
e0eea92
10bd531
94d7430
79e7646
2c740bc
10bd531
 
 
 
 
 
 
 
 
 
 
2c740bc
10bd531
2c740bc
 
 
 
 
 
 
 
 
 
 
 
10bd531
2c740bc
 
 
 
 
10bd531
2c740bc
 
 
 
10bd531
2c740bc
79e7646
2c740bc
 
 
 
 
 
 
79e7646
2c740bc
 
 
 
 
94d7430
2c740bc
 
 
 
10bd531
2c740bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94d7430
2c740bc
 
 
 
 
79e7646
2c740bc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91

import os
import traceback
from datetime import datetime
import torch, gc
from PIL import Image
import gradio as gr

from inference import generate_with_lora
from background_edit import run_background_removal_and_inpaint

# ───────────────────── Helpers ─────────────────────
def _print_trace():
    """Write the traceback of the exception currently being handled to stderr.

    Intended to be called from inside an ``except`` block so the full stack
    trace reaches the console/log while the UI shows a short message.
    """
    traceback.print_exc(limit=None, file=None, chain=True)

def unload_models():
    """Release cached GPU memory between pipeline stages.

    Runs Python garbage collection first, so objects kept alive only by
    reference cycles (including any tensors they hold) are freed. It then
    asks PyTorch's CUDA caching allocator to return its now-unused cached
    blocks to the driver.

    Note: despite the name, this function does not drop any model objects
    itself. Memory held by models that are still referenced elsewhere stays
    allocated; whether ``inference`` / ``background_edit`` cache models
    globally is not visible here.
    """
    # Collect first: empty_cache() can only release blocks whose tensors are
    # already dead, and cyclic garbage is not freed until the collector runs.
    gc.collect()
    torch.cuda.empty_cache()

def safe_generate_and_inpaint(image, prompt, negative_prompt, strength, guidance_scale):
    """Run the two-stage pipeline: LoRA refinement, then background replacement.

    This is used as a Gradio click handler. Ordinary exceptions are caught
    and reported through the third return value instead of propagating.

    Args:
        image: The uploaded headshot, or ``None`` when nothing was uploaded.
            The UI wires this to a ``gr.Image(type="pil")`` input.
        prompt: Positive prompt, passed to both stages.
        negative_prompt: Negative prompt, passed to both stages.
        strength: Refinement strength forwarded to ``generate_with_lora``.
        guidance_scale: Guidance scale forwarded to ``generate_with_lora``.

    Returns:
        A 3-tuple ``(refined, result, message)``.

        On success, ``refined`` is the step-1 image, ``result`` is whatever
        ``run_background_removal_and_inpaint`` returns, and ``message`` is
        ``""``.

        On failure, both images are ``None`` and ``message`` holds a
        user-facing error string. Unexpected errors also have their
        traceback printed to stderr.
    """
    try:
        if image is None:
            raise gr.Error("Please upload an image first.")

        # Step 1: Refinement
        print("[INFO] Step 1: Generating refined image...", flush=True)
        refined = generate_with_lora(
            image=image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            strength=strength,
            guidance_scale=guidance_scale,
        )

        # Persist step 1 to disk, because step 2 takes a file path, not an image.
        os.makedirs("./outputs", exist_ok=True)
        # Microsecond resolution so two runs within the same second do not
        # overwrite (and then inpaint) each other's intermediate file.
        ts = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        path = f"./outputs/step1_result_{ts}.png"
        refined.save(path)

        # Step 2: Background removal and inpainting
        print("[INFO] Step 2: Inpainting background...", flush=True)
        # Release cached GPU memory from step 1 before step 2 runs.
        unload_models()
        result = run_background_removal_and_inpaint(path, prompt, negative_prompt)

        return refined, result, ""

    except gr.Error as e:
        # gr.Error.__str__ returns repr(message), which would wrap the text in
        # quotes. Show the raw message instead.
        message = getattr(e, "message", str(e))
        return None, None, f"🛑 {message}"
    except Exception as e:
        _print_trace()
        return None, None, f"❌ Unexpected Error: {type(e).__name__}: {str(e)}"

# ───────────────────── Gradio UI ─────────────────────
# Layout, top to bottom:
#   1. title and instructions
#   2. image upload
#   3. prompt / negative prompt
#   4. refinement sliders
#   5. run button
#   6. step-1 and step-2 outputs side by side
#   7. a Markdown area for error messages
with gr.Blocks() as demo:
    gr.Markdown("## 🖼️ AI Headshot Enhancer + Background Replacer")
    gr.Markdown("Upload a headshot, adjust the prompt, and click one button. We'll refine the image and replace the background automatically.")

    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload Headshot")

    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            value="a professional corporate headshot of a confident woman in her 30s with blonde hair"
        )
        negative_prompt = gr.Textbox(
            label="Negative Prompt",
            value="deformed, cartoon, anime, illustration, painting, drawing, sketch, low resolution, blurry, out of focus, pixelated"
        )

    with gr.Row():
        # Sliders feed the `strength` and `guidance_scale` arguments of the handler.
        strength = gr.Slider(0.1, 1.0, value=0.20, step=0.05, label="Refinement Strength")
        guidance = gr.Slider(1, 20, value=17.0, step=0.5, label="Guidance Scale")

    go_btn = gr.Button("✨ Run Full Process (Refine + Inpaint)")

    with gr.Row():
        output_refined = gr.Image(type="pil", label="Step 1: Refined Headshot")
        output_final   = gr.Image(type="pil", label="Step 2: Final Image with Background")

    # Receives the third return value of the handler: "" on success,
    # otherwise an error message.
    error_box = gr.Markdown(label="Error", value="", visible=True)

    # Input order must match the handler's positional parameters:
    # (image, prompt, negative_prompt, strength, guidance_scale).
    go_btn.click(
        fn=safe_generate_and_inpaint,
        inputs=[input_image, prompt, negative_prompt, strength, guidance],
        outputs=[output_refined, output_final, error_box]
    )

# NOTE(review): this launches the server whenever the module is imported, not
# only when it is run as a script. Consider an `if __name__ == "__main__":`
# guard if anything imports this file.
demo.launch(debug=True)