# Author: theoracle
# Commit 2c740bc: "Fix step error handling and wire up error_box"
# (residue from the hosting site's file view: raw / history / blame, 3.27 kB)
import os
import traceback
from datetime import datetime
import torch, gc
from PIL import Image
import gradio as gr
from inference import generate_with_lora
from background_edit import run_background_removal_and_inpaint
# ───────────────────── Helpers ─────────────────────
def _print_trace():
traceback.print_exc()
def unload_models():
    """Free memory held by unreachable objects and PyTorch's CUDA cache.

    Despite the name, this does not drop references to any loaded model or
    pipeline; it only releases memory that is no longer referenced. Anything
    still reachable (e.g. a pipeline cached inside another module) stays
    allocated.
    """
    # Collect garbage FIRST: tensors kept alive only by reference cycles are
    # freed here, returning their blocks to PyTorch's caching allocator ...
    gc.collect()
    # ... so that empty_cache() can then hand those now-unoccupied cached
    # blocks back to the driver. Emptying before collecting would miss them.
    torch.cuda.empty_cache()
def safe_generate_and_inpaint(image, prompt, negative_prompt, strength, guidance_scale):
    """Refine a headshot with the LoRA model, then replace its background.

    Args:
        image: Uploaded image from the ``gr.Image(type="pil")`` input, or
            ``None`` when nothing was uploaded.
        prompt: Positive prompt, forwarded to both steps.
        negative_prompt: Negative prompt, forwarded to both steps.
        strength: Refinement strength forwarded to ``generate_with_lora``.
        guidance_scale: Guidance scale forwarded to ``generate_with_lora``.

    Returns:
        A 3-tuple ``(refined, result, error_message)`` matching the
        ``[output_refined, output_final, error_box]`` outputs. On success the
        message is ``""`` (clearing the error box). On failure both images are
        ``None`` and the message describes the problem; exceptions derived
        from ``Exception`` are caught here rather than propagated to Gradio.
    """
    try:
        if image is None:
            raise gr.Error("Please upload an image first.")

        # Step 1: Refinement
        print("[INFO] Step 1: Generating refined image...", flush=True)
        refined = generate_with_lora(
            image=image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            strength=strength,
            guidance_scale=guidance_scale,
        )

        # Step 2 consumes a file path, so persist the intermediate result.
        # Microseconds in the name stop two requests landing in the same
        # second from overwriting (and then reading) each other's file.
        os.makedirs("./outputs", exist_ok=True)
        ts = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        path = os.path.join("./outputs", f"step1_result_{ts}.png")
        refined.save(path)

        # Step 2: Background removal and inpainting
        print("[INFO] Step 2: Inpainting background...", flush=True)
        # Release memory freed by step 1 before the second model runs.
        unload_models()
        result = run_background_removal_and_inpaint(path, prompt, negative_prompt)
        return refined, result, ""
    except gr.Error as e:
        # Gradio's gr.Error defines __str__ as repr(self.message), which would
        # wrap the text in quotes; use the raw message when it is available.
        message = getattr(e, "message", str(e))
        return None, None, f"🛑 {message}"
    except Exception as e:
        # Unexpected failure: keep the full traceback in the server log and
        # show only a short summary in the UI.
        _print_trace()
        return None, None, f"❌ Unexpected Error: {type(e).__name__}: {e}"
# ───────────────────── Gradio UI ─────────────────────
# Build the UI at module level so `demo` is importable (e.g. by `gradio app.py`
# reload mode); only launching the server is deferred to script execution.
with gr.Blocks() as demo:
    gr.Markdown("## 🖼️ AI Headshot Enhancer + Background Replacer")
    gr.Markdown("Upload a headshot, adjust the prompt, and click one button. We'll refine the image and replace the background automatically.")

    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload Headshot")

    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            value="a professional corporate headshot of a confident woman in her 30s with blonde hair",
        )
        negative_prompt = gr.Textbox(
            label="Negative Prompt",
            value="deformed, cartoon, anime, illustration, painting, drawing, sketch, low resolution, blurry, out of focus, pixelated",
        )

    with gr.Row():
        strength = gr.Slider(0.1, 1.0, value=0.20, step=0.05, label="Refinement Strength")
        guidance = gr.Slider(1, 20, value=17.0, step=0.5, label="Guidance Scale")

    go_btn = gr.Button("✨ Run Full Process (Refine + Inpaint)")

    with gr.Row():
        output_refined = gr.Image(type="pil", label="Step 1: Refined Headshot")
        output_final = gr.Image(type="pil", label="Step 2: Final Image with Background")

    # Receives the third return value of safe_generate_and_inpaint: "" on
    # success (clears the box) or a short error message on failure.
    error_box = gr.Markdown(label="Error", value="", visible=True)

    go_btn.click(
        fn=safe_generate_and_inpaint,
        inputs=[input_image, prompt, negative_prompt, strength, guidance],
        outputs=[output_refined, output_final, error_box],
    )

# Launch only when run as a script, so importing this module has no
# server-starting side effect.
if __name__ == "__main__":
    demo.launch(debug=True)