"""Gradio demo offering two blur effects for an uploaded image.

* "Gaussian Blur Background" keeps the segmented foreground sharp and blurs
  everything else.
* "Depth-based Lens Blur" blurs the image in bands whose strength depends on
  an estimated depth map.
"""

import io
from typing import Optional

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageFilter
from torchvision import transforms
from transformers import AutoModelForImageSegmentation, pipeline

# ----------------------------
# Global Setup and Model Loading
# ----------------------------

# Set device (GPU if available, else CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the segmentation model (RMBG-2.0)
segmentation_model = AutoModelForImageSegmentation.from_pretrained(
    'briaai/RMBG-2.0',
    trust_remote_code=True,
)
segmentation_model.to(device)
segmentation_model.eval()

# Define the image transformation for segmentation (resize to 512x512)
image_size = (512, 512)
segmentation_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Load the depth estimation pipeline (Depth-Anything) on the same device as
# the segmentation model so a GPU is used for both when one is available.
depth_pipeline = pipeline(
    "depth-estimation",
    model="depth-anything/Depth-Anything-V2-Small-hf",
    device=device,
)

# Effect names shared by the UI choices and the dispatcher so the two cannot
# drift apart.
EFFECT_GAUSSIAN_BACKGROUND = "Gaussian Blur Background"
EFFECT_DEPTH_LENS_BLUR = "Depth-based Lens Blur"


# ----------------------------
# Processing Functions
# ----------------------------

def segment_and_blur_background(input_image: Image.Image,
                                blur_radius: int = 15,
                                threshold: float = 0.5) -> Image.Image:
    """Blur the background of an image while keeping the foreground sharp.

    The image is segmented with the RMBG-2.0 model; the thresholded mask is
    then used to composite the original image over a Gaussian-blurred copy.

    Args:
        input_image: Image to process; it is converted to RGB first.
        blur_radius: Radius passed to the Gaussian blur of the background.
        threshold: Sigmoid score above which a pixel counts as foreground.

    Returns:
        An RGB image with the same dimensions as ``input_image``.
    """
    # Ensure the image is in RGB and remember its original dimensions
    image = input_image.convert("RGB")
    orig_width, orig_height = image.size

    # Preprocess image for segmentation
    input_tensor = segmentation_transform(image).unsqueeze(0).to(device)

    # Run inference; the last model output is used as the prediction
    with torch.no_grad():
        preds = segmentation_model(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Threshold straight to a 0/255 uint8 mask ("L" mode image)
    mask_array = (pred > threshold).numpy().astype(np.uint8) * 255
    mask_pil = Image.fromarray(mask_array)

    # Resize mask back to the original image dimensions; bilinear resampling
    # leaves a slightly soft edge between foreground and background
    mask_pil = mask_pil.resize((orig_width, orig_height), resample=Image.BILINEAR)

    # Apply Gaussian blur to the entire image for the background
    blurred_image = image.filter(ImageFilter.GaussianBlur(blur_radius))

    # Foreground (mask = 255) comes from the original, the rest from the blur
    return Image.composite(image, blurred_image, mask_pil)


def depth_based_lens_blur(input_image: Image.Image,
                          max_blur: float = 2,
                          num_bands: int = 40,
                          invert_depth: bool = False) -> Image.Image:
    """Apply a banded, depth-dependent blur to simulate lens blur.

    The image is resized to 512x512, a depth map is estimated with
    Depth-Anything and normalized to [0, 1], and the depth range is split
    into ``num_bands`` equal bands. Each band is filled from a copy of the
    image blurred with radius ``(1 - band_midpoint) * max_blur``.

    Args:
        input_image: Image to process.
        max_blur: Blur radius approached by the band at normalized depth 0.
        num_bands: Number of equal-width depth bands.
        invert_depth: If true, use ``1 - depth`` so the blur ordering flips.

    Returns:
        A 512x512 RGB image (the output keeps the resized dimensions, not the
        dimensions of ``input_image``).
    """
    # Resize the input image to 512x512 for the depth estimation model
    image_resized = input_image.resize((512, 512))

    # Run depth estimation to obtain the depth map (as a PIL image)
    results = depth_pipeline(image_resized)
    depth_map_image = results['depth']

    # Convert the depth map to a NumPy array and normalize to [0, 1]
    # NOTE(review): assumes larger depth-map values mean "closer to the
    # camera", so near content stays sharp — confirm against the model.
    depth_array = np.array(depth_map_image, dtype=np.float32)
    d_min, d_max = depth_array.min(), depth_array.max()
    depth_norm = (depth_array - d_min) / (d_max - d_min + 1e-8)
    if invert_depth:
        depth_norm = 1.0 - depth_norm

    # Convert the resized image to RGBA for compositing
    orig_rgba = image_resized.convert("RGBA")
    final_image = orig_rgba.copy()

    # Divide the normalized depth range into bands
    band_edges = np.linspace(0, 1, num_bands + 1)
    last_band = num_bands - 1
    for i in range(num_bands):
        band_min = band_edges[i]
        band_max = band_edges[i + 1]

        # Bands are half-open [min, max) except the last one, which is
        # closed so pixels at exactly 1.0 are not left out of every band.
        in_band = depth_norm >= band_min
        if i == last_band:
            in_band &= depth_norm <= band_max
        else:
            in_band &= depth_norm < band_max

        # Nothing to composite for an empty band; skip the costly blur.
        if not in_band.any():
            continue

        # Use the midpoint of the band to determine the blur strength.
        mid = (band_min + band_max) / 2.0
        blur_radius_band = (1 - mid) * max_blur

        # Create a blurred version of the image for this band.
        blurred_version = orig_rgba.filter(ImageFilter.GaussianBlur(blur_radius_band))

        # uint8 2-D array is interpreted as an "L" mode mask.
        band_mask_pil = Image.fromarray(in_band.astype(np.uint8) * 255)

        # Composite the blurred version into the running result.
        final_image = Image.composite(blurred_version, final_image, band_mask_pil)

    # Return the final composited image as RGB.
    return final_image.convert("RGB")


def process_image(input_image: Optional[Image.Image],
                  effect: str) -> Optional[Image.Image]:
    """Dispatch to the effect selected in the UI.

    Args:
        input_image: Uploaded image, or ``None`` when nothing was uploaded.
        effect: "Gaussian Blur Background" (segmentation + Gaussian blur) or
            "Depth-based Lens Blur" (depth-banded blur). Any other value,
            including ``None``, leaves the image untouched.

    Returns:
        The processed image, the unchanged input for an unknown effect, or
        ``None`` when no image was provided.
    """
    # Submitting without an upload passes None; avoid an AttributeError.
    if input_image is None:
        return None

    if effect == EFFECT_GAUSSIAN_BACKGROUND:
        return segment_and_blur_background(input_image)
    if effect == EFFECT_DEPTH_LENS_BLUR:
        return depth_based_lens_blur(input_image)
    return input_image


# ----------------------------
# Gradio Interface
# ----------------------------

iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Radio(
            choices=[EFFECT_GAUSSIAN_BACKGROUND, EFFECT_DEPTH_LENS_BLUR],
            label="Select Effect",
        ),
    ],
    outputs=gr.Image(type="pil", label="Output Image"),
    title="Blur Effects Demo",
    description=(
        "Upload an image and choose an effect: "
        "apply segmentation + Gaussian blurred background, or a depth-based lens blur effect."
    ),
)

if __name__ == "__main__":
    iface.launch()