Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from PIL import Image, ImageFilter | |
def load_segmentation_model():
    """Load the BEN2 segmentation model once and return it with its device.

    The first call imports ``ben2``, loads ``BEN_Base.from_pretrained("PramaLLC/BEN2")``,
    moves the model to CUDA when ``torch.cuda.is_available()`` (CPU otherwise)
    and switches it to eval mode. The pair is cached in the module globals
    ``seg_model`` and ``seg_device``. Later calls return the cached pair
    without reloading.

    Returns:
        tuple: ``(seg_model, seg_device)``.

    Raises:
        ImportError: If the ``ben2`` package is not installed.
    """
    global seg_model, seg_device
    if "seg_model" not in globals():
        # Deferred import: ben2 is only needed the first time the model is built.
        from ben2 import BEN_Base

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = BEN_Base.from_pretrained("PramaLLC/BEN2")
        # Module.to()/eval() return the module, so rebinding is safe either way.
        model = model.to(device).eval()
        # Publish the cache only once the model is fully prepared. If .to()/.eval()
        # fails (e.g. CUDA out of memory), nothing half-initialised stays cached
        # for subsequent calls.
        seg_device = device
        seg_model = model
    return seg_model, seg_device
def _load_rgb(uploaded_image):
    """Return an RGB copy of a PIL image, or of an image path/file object."""
    if isinstance(uploaded_image, Image.Image):
        return uploaded_image.convert("RGB")
    # Image.open keeps the file handle open for lazy loading; the context
    # manager closes it once convert() has read the pixels.
    with Image.open(uploaded_image) as opened:
        return opened.convert("RGB")


def _binary_alpha_mask(foreground, threshold):
    """Threshold the alpha channel of *foreground* into a 0/255 mode-"L" mask."""
    alpha = foreground.convert("RGBA").getchannel("A")
    return alpha.point(lambda value: 255 if value > threshold else 0, mode="L")


def process_segmentation_blur(
    uploaded_image,
    seg_blur_radius: float,
    size=(512, 512),
    mask_threshold: int = 128,
):
    """Blur the background of an image while keeping the segmented foreground sharp.

    The image is converted to RGB and, by default, resized to 512x512. The
    aspect ratio is not preserved. A Gaussian blur of ``seg_blur_radius`` is
    applied to a copy. The alpha channel of the BEN2 model's ``inference``
    output is thresholded into a binary mask. The sharp image is then
    composited over the blurred one through that mask.

    Args:
        uploaded_image: A ``PIL.Image.Image`` or anything ``Image.open`` accepts
            (file path or binary file object).
        seg_blur_radius: Radius passed to ``ImageFilter.GaussianBlur``.
        size: ``(width, height)`` to resize to before processing. ``None`` keeps
            the input's own size. Defaults to ``(512, 512)``.
        mask_threshold: Alpha values strictly greater than this are treated as
            foreground. Defaults to ``128``.

    Returns:
        PIL.Image.Image: The composited RGB image.

    Raises:
        ValueError: If ``uploaded_image`` is ``None``.
    """
    if uploaded_image is None:
        # Without this, Image.open(None) fails with an unhelpful AttributeError.
        raise ValueError("Please upload an image before processing.")

    image = _load_rgb(uploaded_image)
    if size is not None:
        image = image.resize(tuple(size))

    seg_model, _ = load_segmentation_model()
    blurred_image = image.filter(ImageFilter.GaussianBlur(seg_blur_radius))

    # Foreground segmentation -> hard mask (255 keeps the sharp pixel).
    foreground = seg_model.inference(image, refine_foreground=False)
    binary_mask = _binary_alpha_mask(foreground, mask_threshold)
    return Image.composite(image, blurred_image, binary_mask)
# Gradio UI: upload an image, pick a blur radius, and view the composited result.
# Components render in the order they are created inside the Blocks context.
with gr.Blocks() as demo:
    gr.Markdown(value="# Gaussian Blur using Image Segmentation BEN2 Model.")

    # Inputs: the image (delivered to the callback as a PIL image) and the radius.
    seg_img = gr.Image(label="Upload Image", type="pil")
    seg_blur = gr.Slider(
        minimum=5,
        maximum=30,
        value=15,
        step=1,
        label="Gaussian Blur Radius",
    )

    # Output panel and the button that triggers processing.
    seg_out = gr.Image(label="Gaussian-Based Blurred Image")
    seg_button = gr.Button(value="Process Gaussian Blur")
    seg_button.click(
        fn=process_segmentation_blur,
        inputs=[seg_img, seg_blur],
        outputs=seg_out,
    )

if __name__ == "__main__":
    demo.launch()