"""Streamlit demo: blur an uploaded photo's background two ways.

* Gaussian blur composited through a foreground mask from briaai/RMBG-2.0.
* Depth-banded "lens" blur driven by a Depth-Anything-V2 depth map.
"""

import io

import matplotlib.pyplot as plt  # noqa: F401  (kept: pre-existing import)
import numpy as np
import streamlit as st
import torch
from PIL import Image, ImageFilter
from torchvision import transforms
from transformers import AutoModelForImageSegmentation, pipeline

SEGMENTATION_MODEL_ID = "briaai/RMBG-2.0"
DEPTH_MODEL_ID = "depth-anything/Depth-Anything-V2-Small-hf"
# Both models are fed a 512x512 copy of the upload.
MODEL_INPUT_SIZE = (512, 512)
# Sigmoid probability above which a pixel counts as foreground in the mask.
MASK_THRESHOLD = 0.5
# Longest side of the image the lens blur runs on. It bounds the per-slider-change
# cost (one full-image blur per depth band) while keeping the aspect ratio.
LENS_BLUR_MAX_SIDE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

_SEGMENTATION_TRANSFORM = transforms.Compose([
    transforms.Resize(MODEL_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def depth_based_blur(orig_image: Image.Image, depth_map: Image.Image,
                     max_blur: float = 15, num_bands: int = 10,
                     invert_depth: bool = True) -> Image.Image:
    """Apply a depth-dependent blur to ``orig_image`` using ``depth_map``.

    The depth map is min-max normalized to [0, 1] ("nearness", optionally
    inverted). That range is split into ``num_bands`` equal-width bands. Every
    pixel of a band gets a Gaussian blur of radius ``(1 - band_mid) * max_blur``,
    so higher nearness means less blur.

    Args:
        orig_image: Image to blur. It is converted to RGBA internally.
        depth_map: Depth map for ``orig_image``. It is converted to 32-bit
            float ("F") and bilinearly resized to ``orig_image.size`` if the
            sizes differ.
        max_blur: Blur radius approached by the lowest-nearness band.
        num_bands: Number of equal-width depth intervals. Values below 1
            return the image unblurred.
        invert_depth: If True, larger depth values count as farther (more
            blur). If False, larger values count as nearer (less blur).

    Returns:
        PIL.Image.Image: RGB image of ``orig_image.size`` with the
        low-nearness areas blurred.
    """
    orig_rgba = orig_image.convert("RGBA")
    if num_bands < 1:
        return orig_rgba.convert("RGB")

    # "F" handles L / I / RGB depth maps uniformly. Image.composite needs the
    # mask to match the image size, so align the depth map first.
    depth_f = depth_map.convert("F")
    if depth_f.size != orig_rgba.size:
        depth_f = depth_f.resize(orig_rgba.size, resample=Image.BILINEAR)
    depth = np.asarray(depth_f, dtype=np.float32)

    d_min, d_max = float(depth.min()), float(depth.max())
    nearness = (depth - d_min) / (d_max - d_min + 1e-8)
    if invert_depth:
        nearness = 1.0 - nearness

    # Assign each pixel exactly one band. Clipping puts nearness == 1.0 in the
    # last band instead of dropping it, which a half-open [min, max) test would do.
    band_index = np.clip((nearness * num_bands).astype(np.int64), 0, num_bands - 1)

    final_image = orig_rgba.copy()
    # np.unique skips empty bands, which would otherwise cost a useless blur.
    for band in np.unique(band_index):
        mid = (float(band) + 0.5) / num_bands
        blur_radius = (1.0 - mid) * max_blur
        blurred_version = orig_rgba.filter(ImageFilter.GaussianBlur(blur_radius))
        band_mask = Image.fromarray((band_index == band).astype(np.uint8) * 255)
        final_image = Image.composite(blurred_version, final_image, band_mask)

    return final_image.convert("RGB")


def _open_rgb(image_bytes: bytes) -> Image.Image:
    """Decode uploaded file bytes into an RGB PIL image."""
    return Image.open(io.BytesIO(image_bytes)).convert("RGB")


@st.cache_resource
def _load_segmentation_model():
    """Load the RMBG-2.0 segmentation model once per server process.

    Streamlit re-executes the script on every widget change, so without this
    cache the model would be reloaded on each slider move.
    """
    torch.set_float32_matmul_precision("high")
    model = AutoModelForImageSegmentation.from_pretrained(
        SEGMENTATION_MODEL_ID, trust_remote_code=True)
    model.to(DEVICE)
    model.eval()
    return model


@st.cache_resource
def _load_depth_pipeline():
    """Create the depth-estimation pipeline once per server process."""
    return pipeline("depth-estimation", model=DEPTH_MODEL_ID)


@st.cache_data
def _foreground_mask(image_bytes: bytes) -> Image.Image:
    """Return an "L" mask at the upload's size. White keeps the sharp pixels.

    The model's last output is passed through a sigmoid and thresholded at
    MASK_THRESHOLD. The result is bilinearly upscaled, which softens the mask
    edges. NOTE(review): presumably the output is a foreground probability
    (RMBG is a background-removal model); confirm against the model card.
    """
    image = _open_rgb(image_bytes)
    batch = _SEGMENTATION_TRANSFORM(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        preds = _load_segmentation_model()(batch)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    binary = (pred > MASK_THRESHOLD).numpy().astype(np.uint8) * 255
    return Image.fromarray(binary).resize(image.size, resample=Image.BILINEAR)


@st.cache_data
def _depth_map(image_bytes: bytes) -> Image.Image:
    """Run depth estimation on a MODEL_INPUT_SIZE copy of the upload.

    Returns the pipeline's "depth" image.
    """
    image = _open_rgb(image_bytes)
    results = _load_depth_pipeline()(image.resize(MODEL_INPUT_SIZE))
    return results["depth"]


def main():
    """Streamlit entry point: upload an image and show two background-blur variants."""
    st.title("Custom Background Blur Demo")

    # 1. Upload an image.
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # 2. Open and display the original image.
    image_bytes = uploaded_file.getvalue()
    image = _open_rgb(image_bytes)
    st.image(image, caption="Original Image", use_column_width=True)

    st.write("---")
    st.subheader("Blur Settings")
    col1, col2 = st.columns(2)

    # Model inference is cached per upload, so slider moves only redo the blurs.
    mask_pil = _foreground_mask(image_bytes)
    depth_map_image = _depth_map(image_bytes)

    with col1:
        gauss_radius = st.slider("Gaussian Blur Radius", 0, 30, 5, key="gauss")
        blurred_image = image.filter(ImageFilter.GaussianBlur(gauss_radius))
        # Image.composite takes white (255) mask pixels from `image` (sharp)
        # and black (0) pixels from `blurred_image`.
        final_image = Image.composite(image, blurred_image, mask_pil)
        st.image(
            final_image,
            caption=f"Gaussian Blur (radius={gauss_radius})",
            use_column_width=True,
        )

    with col2:
        blur_max = st.slider("Lens Blur Radius", 0, 30, 10, key="lens")
        # Aspect-preserving downscale. The previous 512x512 resize distorted
        # non-square photos.
        lens_input = image.copy()
        lens_input.thumbnail((LENS_BLUR_MAX_SIDE, LENS_BLUR_MAX_SIDE))
        # NOTE(review): invert_depth=False treats larger depth-pipeline values as
        # nearer (less blur); assumed to match Depth-Anything's output, confirm.
        output_image = depth_based_blur(lens_input, depth_map_image,
                                        max_blur=blur_max, num_bands=40,
                                        invert_depth=False)
        st.image(
            output_image,
            caption=f"Lens Blur (blur={blur_max})",
            use_column_width=True,
        )


if __name__ == "__main__":
    main()