File size: 3,211 Bytes
10bd531
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os

# ── Cache/config locations ──
# Both directories are created first, then exported through the environment.
# This runs ahead of the model-library imports that follow — presumably so
# those libraries see the settings when they are imported.
hf_home = "/data/.cache/huggingface"
yolo_cfg = "/data/ultralytics"
for _cache_dir in (hf_home, yolo_cfg):
    os.makedirs(_cache_dir, exist_ok=True)
os.environ["HF_HOME"], os.environ["YOLO_CONFIG_DIR"] = hf_home, yolo_cfg

from ultralytics import YOLO
import numpy as np
import torch
from PIL import Image
import cv2
from diffusers import StableDiffusionXLInpaintPipeline
import gradio as gr

# ---- utils ----
def pil_to_cv2(pil_img):
    """Convert a PIL image into a BGR array suitable for OpenCV.

    The image is normalised to RGB before the channel swap so that inputs in
    other Pillow modes (grayscale, palette, CMYK, RGBA, ...) always produce a
    3-channel array instead of failing or being mis-read by ``cvtColor``.

    Args:
        pil_img: A PIL image in any mode Pillow can convert to RGB.

    Returns:
        A numpy array holding the image with channels in BGR order.
    """
    rgb_pixels = np.array(pil_img.convert("RGB"))
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)

def cv2_to_pil(cv_img):
    """Convert an OpenCV array into an RGB PIL image.

    Colour arrays are treated as BGR and have their channels swapped.
    Single-channel arrays (2-D, or 3-D with a trailing axis of length 1),
    such as masks, are expanded to RGB instead, because the BGR->RGB
    conversion is not defined for one-channel input.

    Args:
        cv_img: A numpy image array, either single-channel or BGR(A).

    Returns:
        A PIL image built from the RGB-ordered pixel data.
    """
    is_single_channel = cv_img.ndim == 2 or cv_img.shape[2] == 1
    if is_single_channel:
        conversion = cv2.COLOR_GRAY2RGB
    else:
        conversion = cv2.COLOR_BGR2RGB
    return Image.fromarray(cv2.cvtColor(cv_img, conversion))

# ---- load models ----
# Segmentation model, loaded from the "yolov8x-seg.pt" checkpoint.
yolo = YOLO("yolov8x-seg.pt")

# SDXL inpainting pipeline: loaded in float16 from safetensors weights, with
# the HF_TOKEN environment variable forwarded as the auth token, then moved
# to the CUDA device.
# NOTE(review): newer diffusers releases appear to prefer `token=` over
# `use_auth_token=` — confirm against the installed version.
inpaint_pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    use_auth_token=os.getenv("HF_TOKEN"),
    use_safetensors=True,
    torch_dtype=torch.float16,
)
inpaint_pipe = inpaint_pipe.to("cuda")

# ---- processing logic ----
def run_background_removal_and_inpaint(image, prompt, negative_prompt):
    """Keep the segmented subject and repaint the rest of the image from a prompt.

    The image is run through the YOLO segmentation model, the first returned
    mask is inverted to select the background, the background selection is
    dilated, and the result is handed to the SDXL inpainting pipeline together
    with the image (both resized to 1024x1024).

    Args:
        image: PIL image supplied by the user, or None if nothing was uploaded.
        prompt: Text describing the desired new background.
        negative_prompt: Text describing what to avoid; a falsy value is
            passed to the pipeline as an empty string.

    Returns:
        The first image returned by the inpainting pipeline.

    Raises:
        gr.Error: If no image was provided, or the segmentation model
            returned no masks.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    results = yolo(pil_to_cv2(image))

    masks = results[0].masks if results else None
    if masks is None or len(masks.data) == 0:
        raise gr.Error("No subject detected in the image. Please upload a clearer photo.")

    # Only the first mask is used.
    # NOTE(review): presumably this is the most confident detection — confirm.
    subject = (masks.data[0].cpu().numpy() > 0.5).astype(np.uint8)
    background_mask = 1 - subject

    # Grow the background selection with a 15x15 kernel so the repainted
    # region extends slightly past the subject outline.
    kernel = np.ones((15, 15), np.uint8)
    dilated = cv2.dilate(background_mask, kernel, iterations=1)
    inpaint_mask = (dilated * 255).astype(np.uint8)

    # The mask is a single-channel array, so it is wrapped directly as a PIL
    # image rather than going through a BGR->RGB colour conversion, which
    # OpenCV only defines for 3- or 4-channel input.
    # NOTE(review): the mask is stretched straight to 1024x1024; confirm it
    # covers the same field of view as the resized image.
    mask_pil = Image.fromarray(inpaint_mask).resize((1024, 1024)).convert("L")
    img_pil = image.resize((1024, 1024)).convert("RGB")

    result = inpaint_pipe(
        prompt=prompt,
        negative_prompt=negative_prompt or "",
        image=img_pil,
        mask_image=mask_pil,
        guidance_scale=10,
        num_inference_steps=40,
    ).images[0]

    return result

# ---- Gradio interface ----
def safe_run(img, bg_prompt, bg_negative):
    """Run the inpainting pipeline on behalf of the UI.

    Returns ``(result, "")`` on success. On any exception the error is
    printed and ``(None, <markdown message>)`` is returned so the failure
    is shown in the page's error box.
    """
    try:
        result = run_background_removal_and_inpaint(img, bg_prompt, bg_negative)
    except Exception as e:
        # UI boundary: report the failure to the user instead of propagating.
        print(f"[ERROR] {type(e).__name__}: {e}")
        return None, f"**❌ Error:** {type(e).__name__}: {e}"
    return result, ""


with gr.Blocks() as demo:
    # Page heading and short usage hint.
    gr.Markdown("## 🖼️ Remove & Replace Background")
    gr.Markdown("Upload a headshot, and describe the desired new background.")

    # Input image and result side by side.
    with gr.Row():
        input_img = gr.Image(label="Upload Image", type="pil")
        output_img = gr.Image(label="Result", type="pil")

    # Positive and negative prompts, pre-filled with defaults.
    with gr.Row():
        prompt = gr.Textbox(
            value="modern open-plan office, soft natural light, minimalistic decor",
            label="New Background Prompt",
        )
        neg_prompt = gr.Textbox(
            value="cartoon, fantasy, dark lighting, painting, anime",
            label="Negative Prompt",
        )

    # Receives the Markdown error text returned by safe_run (empty on success).
    error_box = gr.Markdown()

    run_btn = gr.Button("Run Background Inpaint")
    run_btn.click(
        fn=safe_run,
        inputs=[input_img, prompt, neg_prompt],
        outputs=[output_img, error_box],
    )

demo.launch(debug=True)