Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
from inference import generate_with_lora | |
from background_edit import run_background_removal_and_inpaint | |
import traceback, torch, gc | |
# ───────────────────── Helpers ─────────────────────
def _print_trace(): | |
traceback.print_exc() | |
def safe_generate_with_lora(*a, **kw):
    """Run Step 1 (LoRA image generation) and surface failures in the UI.

    All positional and keyword arguments are forwarded unchanged to
    ``generate_with_lora`` and its result is returned.

    Raises:
        gr.Error: propagated unchanged if the generator raises one; any other
            exception is logged and re-raised as a ``gr.Error`` carrying
            "Image generation failed: <message>".
    """
    try:
        result = generate_with_lora(*a, **kw)
    except gr.Error:
        # Already user-facing: log it and let it through untouched.
        _print_trace()
        raise
    except Exception as exc:
        _print_trace()
        raise gr.Error(f"Image generation failed: {exc}")
    return result
def unload_models():
    """Release GPU memory that PyTorch is caching but no longer using.

    Despite the name, this does not drop any model references itself; it
    only reclaims memory that is already unreferenced.

    Garbage collection runs first so that objects kept alive only by
    reference cycles (and the tensors they own) are destroyed. Only then is
    the CUDA caching allocator emptied. In the opposite order, the blocks
    freed by ``gc.collect()`` would stay parked in PyTorch's cache instead
    of being returned to the driver.
    """
    gc.collect()
    # Safe without CUDA: empty_cache() does nothing if CUDA was never initialized.
    torch.cuda.empty_cache()
def safe_run_background(*args, **kwargs):
    """Run Step 2 (background removal + inpainting) and surface failures in the UI.

    First calls ``unload_models()`` to free cached GPU memory. Then it
    forwards all arguments to ``run_background_removal_and_inpaint`` and
    returns that function's result.

    Raises:
        gr.Error: re-raised unchanged if the pipeline (or the memory cleanup)
            raises one. This keeps an already user-facing message from being
            re-wrapped as "...failed: Error: '...'", and matches the handling
            in ``safe_generate_with_lora``.
            Any other exception is logged and wrapped in a ``gr.Error`` whose
            message is prefixed with "[Step 2]" and names the exception type.
    """
    try:
        unload_models()  # free VRAM before loading the inpainting model
        return run_background_removal_and_inpaint(*args, **kwargs)
    except gr.Error:
        _print_trace()
        raise
    except Exception as e:
        _print_trace()
        raise gr.Error(
            f"[Step 2] Background replacement failed: {type(e).__name__}: {e}"
        ) from e
# ───────────────────── UI ─────────────────────
with gr.Blocks() as demo:
    # Session state must be created inside the Blocks context: Gradio only
    # accepts components that belong to this app as event inputs/outputs.
    shared_output = gr.State()   # holds the Step 1 output image
    original_input = gr.State()  # holds the original upload (optional)

    demo.queue()  # route events through Gradio's request queue

    # ─────────── STEP 1: Headshot Refinement ───────────
    with gr.Tab("Step 1: Headshot Refinement"):
        with gr.Row():
            input_image = gr.Image(type="pil", label="Upload Headshot")
            output_image = gr.Image(type="pil", label="Refined Output")
        with gr.Row():
            prompt = gr.Textbox(
                label="Prompt",
                value="a professional corporate headshot of a confident woman in her 30s with blonde hair",
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="deformed, cartoon, anime, illustration, painting, drawing, sketch, low resolution, blurry, out of focus, pixelated",
            )
        with gr.Row():
            strength = gr.Slider(0.1, 1.0, value=0.20, step=0.05, label="Strength")
            guidance = gr.Slider(1, 20, value=17.0, step=0.5, label="Guidance Scale")
        run_btn = gr.Button("Generate")

        def _save_to_state(img):
            """Store the refined image for Step 2; leave the state untouched if there is none.

            The image itself is stored, not a wrapper dict: Step 2 hands the
            state value straight to the inpainting pipeline as its image.
            """
            return img if img is not None else gr.skip()

        event = (
            run_btn.click(
                fn=safe_generate_with_lora,
                inputs=[input_image, prompt, negative_prompt, strength, guidance],
                outputs=output_image,
            )
            # .success() rather than .then(): a .then() step also runs after
            # the generation raised, which would re-save a stale output.
            .success(_save_to_state, output_image, shared_output)
            .then(lambda x: x, input_image, original_input)
        )

    # ─────────── STEP 2: Background Replacement ───────────
    with gr.Tab("Step 2: Replace Background"):
        # Shows formatted error messages (cleared on a successful run)
        error_box = gr.Markdown(value="", visible=True)
        with gr.Row():
            inpaint_prompt = gr.Textbox(
                label="New Background Prompt",
                value="modern open-plan startup office background, natural lighting, glass walls, clean design, minimalistic decor",
            )
            inpaint_negative = gr.Textbox(
                label="Negative Prompt",
                value="dark lighting, cluttered background, fantasy elements, cartoon, anime, painting, low quality, distorted shapes",
            )
        with gr.Row():
            inpaint_result = gr.Image(type="pil", label="Inpainted Image")
        with gr.Row():
            # Disabled until the Step 1 event chain has run (enabled below).
            inpaint_btn = gr.Button("Remove Background & Inpaint", interactive=False)

        def guarded_inpaint(img, prompt_bg, neg_bg):
            """Run Step 2 and report failures in ``error_box`` instead of raising.

            Args:
                img: the value held in ``shared_output``; None if Step 1 has
                    not stored an image yet.
                prompt_bg: prompt describing the new background.
                neg_bg: negative prompt for the inpainting step.

            Returns:
                ``(result_image, markdown_message)``. The message is "" on
                success so that a previously shown error is cleared.
            """
            if img is None:
                return None, "**π Error:** No headshot found β please run Step 1 first."
            try:
                print("[DEBUG] Starting background removal and inpaintingβ¦", flush=True)
                result = safe_run_background(img, prompt_bg, neg_bg)
                return result, ""  # Clear error on success
            except gr.Error as e:
                print(f"[Step 2 gr.Error] {e}", flush=True)
                # gr.Error.__str__ returns repr(message), which would show the
                # text in quotes; prefer the raw message when it is available.
                return None, f"**π Step 2 Failed:** {getattr(e, 'message', str(e))}"
            except Exception as e:
                print(f"[Step 2 UNEXPECTED ERROR] {type(e).__name__}: {e}", flush=True)
                return None, f"**β Unexpected Error:** {type(e).__name__}: {e}"

        inpaint_btn.click(
            fn=guarded_inpaint,
            inputs=[shared_output, inpaint_prompt, inpaint_negative],
            outputs=[inpaint_result, error_box],
        )

    # Enable Step 2 at the end of the Step 1 event chain
    event.then(lambda: gr.update(interactive=True), None, inpaint_btn)

demo.launch(debug=True)