File size: 2,072 Bytes
f86bc23
411ded6
 
f86bc23
da8d67c
 
 
 
 
 
 
 
 
 
 
 
f86bc23
da8d67c
411ded6
da8d67c
 
 
411ded6
 
 
 
da8d67c
 
 
 
 
 
 
 
 
411ded6
f86bc23
411ded6
0055ad6
da8d67c
0055ad6
 
 
da8d67c
f86bc23
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import gradio as gr
import torch
from PIL import Image, ImageFilter

def load_segmentation_model():
    """
    Load the BEN2 segmentation model once and cache it in module globals.

    On the first call the ``PramaLLC/BEN2`` weights are loaded through
    ``ben2.BEN_Base.from_pretrained``. The model is moved to CUDA when
    ``torch.cuda.is_available()`` is true (otherwise CPU) and switched to
    eval mode. Later calls return the cached pair without reloading.
    Requires the ``ben2`` package to be installed and importable.

    Returns:
        tuple: ``(seg_model, seg_device)``, the cached model and the
        ``torch.device`` it was moved to.
    """
    global seg_model, seg_device
    if "seg_model" not in globals() or "seg_device" not in globals():
        # Imported lazily so ben2 is only imported on first use.
        from ben2 import BEN_Base
        model = BEN_Base.from_pretrained("PramaLLC/BEN2")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device).eval()
        # Publish to the module-level cache only after every step succeeded,
        # so a failure (e.g. while moving to the device) is retried on the
        # next call instead of leaving a half-initialized model cached.
        seg_model, seg_device = model, device
    return seg_model, seg_device

def process_segmentation_blur(uploaded_image, seg_blur_radius: float):
    """
    Blur the background of an image while keeping the segmented foreground sharp.

    The input is converted to RGB and resized to 512x512 (the aspect ratio is
    not preserved). A Gaussian blur with radius ``seg_blur_radius`` is applied
    to a copy. The BEN2 model then produces a foreground cut-out, and its alpha
    channel is thresholded into a binary mask. The sharp image is composited
    over the blurred one through that mask.

    Args:
        uploaded_image: A ``PIL.Image.Image``, or anything ``Image.open``
            accepts (a file path or binary file object).
        seg_blur_radius: Gaussian blur radius passed to
            ``ImageFilter.GaussianBlur``.

    Returns:
        The composited 512x512 RGB ``PIL.Image.Image``.

    Raises:
        ValueError: If ``uploaded_image`` is ``None`` (no image provided).
    """
    # Gradio passes None when the button is clicked without an upload; fail
    # with a clear message instead of an opaque error from Image.open(None).
    if uploaded_image is None:
        raise ValueError("No image provided; please upload an image first.")
    if not isinstance(uploaded_image, Image.Image):
        uploaded_image = Image.open(uploaded_image)
    image = uploaded_image.convert("RGB").resize((512, 512))
    seg_model, _ = load_segmentation_model()
    blurred_image = image.filter(ImageFilter.GaussianBlur(seg_blur_radius))

    # Generate segmentation mask (foreground) from the cut-out's alpha channel.
    foreground = seg_model.inference(image, refine_foreground=False)
    alpha = foreground.convert("RGBA").getchannel("A")
    # Hard threshold: alpha above 128 is treated as foreground.
    binary_mask = alpha.point(lambda x: 255 if x > 128 else 0, mode="L")
    # Image.composite requires the mask and both images to share one size.
    # Guard against a cut-out returned at a different resolution; NEAREST
    # keeps the mask strictly binary.
    if binary_mask.size != image.size:
        binary_mask = binary_mask.resize(image.size, Image.NEAREST)
    return Image.composite(image, blurred_image, binary_mask)

# Gradio UI: upload an image, pick a blur radius, and view the result of
# process_segmentation_blur in a separate output image.
with gr.Blocks() as demo:
    gr.Markdown("# Gaussian Blur using Image Segmentation BEN2 Model.")
    seg_img = gr.Image(label="Upload Image", type="pil")
    seg_blur = gr.Slider(
        minimum=5,
        maximum=30,
        step=1,
        value=15,
        label="Gaussian Blur Radius",
    )
    seg_out = gr.Image(label="Gaussian-Based Blurred Image")
    seg_button = gr.Button(value="Process Gaussian Blur")
    # Clicking the button runs the blur on (image, radius) and shows the result.
    seg_button.click(
        fn=process_segmentation_blur,
        inputs=[seg_img, seg_blur],
        outputs=seg_out,
    )

if __name__ == "__main__":
    # Launch the Gradio app only when this file is run as a script.
    demo.launch()