# Spaces: Running  (page-status text captured with the source; not part of the program)
import os

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from PIL import Image, ImageFilter, ImageOps
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from transformers import pipeline
def depth_based_blur(orig_image: Image.Image, depth_map: Image.Image, max_blur: float = 15,
                     num_bands: int = 10, invert_depth: bool = True) -> Image.Image:
    """Blur an image by an amount that varies with a depth map.

    The depth map is min-max normalized to [0, 1] and that range is split into
    ``num_bands`` equal-width bands. For each band the whole image is
    Gaussian-blurred with radius ``(1 - band_midpoint) * max_blur`` and the
    blurred copy is pasted into the result wherever the normalized depth falls
    inside the band. Normalized values near 0 therefore receive the strongest
    blur and values near 1 almost none. Which raw depth values end up near 0
    depends on the depth map's convention; ``invert_depth`` flips it.

    Args:
        orig_image: Image to blur.
        depth_map: Depth image for ``orig_image``. If its size differs from
            ``orig_image`` the normalized map is resized to match; a
            multi-channel map is reduced to one channel by averaging.
        max_blur: Upper bound on the Gaussian blur radius.
        num_bands: Number of depth bands (distinct blur levels).
        invert_depth: If True, use ``1 - normalized_depth`` for banding.

    Returns:
        PIL.Image.Image: The blurred result in RGB mode.
    """
    # Convert depth map to a NumPy array (float32) and normalize to [0, 1].
    depth_array = np.asarray(depth_map, dtype=np.float32)
    if depth_array.ndim == 3:
        # Collapse channels so the band masks below are 2-D.
        depth_array = depth_array.mean(axis=2)
    d_min, d_max = depth_array.min(), depth_array.max()
    # The epsilon keeps a constant depth map from dividing by zero.
    depth_norm = ((depth_array - d_min) / (d_max - d_min + 1e-8)).astype(np.float32)
    if invert_depth:
        depth_norm = 1.0 - depth_norm

    orig_rgba = orig_image.convert("RGBA")

    # Image.composite needs the mask and both images to share one size.
    if depth_norm.shape != (orig_rgba.height, orig_rgba.width):
        resized = Image.fromarray(depth_norm).resize(orig_rgba.size, resample=Image.BILINEAR)
        depth_norm = np.asarray(resized, dtype=np.float32)

    final_image = orig_rgba.copy()
    band_edges = np.linspace(0.0, 1.0, num_bands + 1)
    last_band = num_bands - 1

    for index, (band_min, band_max) in enumerate(zip(band_edges[:-1], band_edges[1:])):
        in_band = depth_norm >= band_min
        if index == last_band:
            # Close the final band on the right so pixels at exactly 1.0 are
            # assigned to a band instead of being skipped by every band.
            in_band &= depth_norm <= band_max
        else:
            in_band &= depth_norm < band_max
        if not in_band.any():
            # No pixels here: skip the (expensive) full-image blur.
            continue

        # Use the midpoint of the band to determine the blur strength:
        # a lower midpoint means a larger radius.
        mid = (band_min + band_max) / 2.0
        blur_radius = (1 - mid) * max_blur
        blurred_version = orig_rgba.filter(ImageFilter.GaussianBlur(blur_radius))

        # A 2-D uint8 array is read as an 8-bit grayscale ("L") mask.
        band_mask_pil = Image.fromarray(in_band.astype(np.uint8) * 255)
        final_image = Image.composite(blurred_version, final_image, band_mask_pil)

    # Convert back to RGB and return.
    return final_image.convert("RGB")
def _load_segmentation_model(hf_token, device):
    """Load the briaai/RMBG-2.0 segmentation model on ``device`` in eval mode."""
    # This is a process-wide PyTorch setting, so applying it at load time is enough.
    torch.set_float32_matmul_precision("high")
    model = AutoModelForImageSegmentation.from_pretrained(
        "briaai/RMBG-2.0", trust_remote_code=True, token=hf_token
    )
    model.to(device)
    model.eval()
    return model


def _load_depth_pipeline():
    """Build the Depth-Anything-V2-Small depth-estimation pipeline."""
    return pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")


def _foreground_mask(model, image, device, image_size=(512, 512), threshold=0.5):
    """Return a grayscale mask of ``image`` (white = kept sharp), sized like ``image``."""
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_images = transform_image(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # Take the final output and apply sigmoid to obtain values in [0, 1].
        preds = model(input_images)[-1].sigmoid().cpu()
    # Thresholding gives a 0/1 float tensor, which ToPILImage scales to 0/255.
    binary_mask = (preds[0].squeeze() > threshold).float()
    mask_pil = transforms.ToPILImage()(binary_mask)
    # Scale the mask from the model's input resolution back to the image size.
    return mask_pil.resize(image.size, resample=Image.BILINEAR)


def main():
    """Run the Streamlit background-blur demo.

    Reads the Hugging Face token from the ``HF_ACCESS_TOKEN`` environment
    variable, lets the user upload an image, and shows two results side by
    side: a Gaussian blur applied outside the segmentation mask, and a
    depth-dependent blur produced by ``depth_based_blur``.

    Raises:
        RuntimeError: If ``HF_ACCESS_TOKEN`` is not set.
    """
    hf_token = os.environ.get("HF_ACCESS_TOKEN")
    if hf_token is None:
        raise RuntimeError("HF_ACCESS_TOKEN is not set. Please add it as a secret.")

    st.title("Custom Background Blur Demo")

    # 1. Upload an image
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # 2. Open and display the original image. Apply the EXIF orientation tag
    #    first so rotated camera photos are processed and shown upright.
    image = ImageOps.exif_transpose(Image.open(uploaded_file)).convert("RGB")
    st.image(image, caption="Original Image", use_container_width=True)

    st.write("---")
    st.subheader("Blur Settings")
    col1, col2 = st.columns(2)

    device = "cpu"

    # Streamlit re-runs this script on every widget interaction; caching the
    # loaders keeps the models from being rebuilt each time a slider moves.
    load_model = st.cache_resource(_load_segmentation_model)
    load_depth = st.cache_resource(_load_depth_pipeline)
    model = load_model(hf_token, device)
    depth_pipeline = load_depth()

    mask_pil = _foreground_mask(model, image, device)
    depth_map_image = depth_pipeline(image)["depth"]

    with col1:
        gauss_radius = st.slider("Gaussian Blur Radius", 0, 30, 10, key="gauss")
        blurred_image = image.filter(ImageFilter.GaussianBlur(gauss_radius))
        # White (255) in mask_pil takes pixels from the original image,
        # black (0) takes them from the blurred image.
        final_image = Image.composite(image, blurred_image, mask_pil)
        st.image(
            final_image,
            caption=f"Gaussian Blur (radius={gauss_radius})",
            use_container_width=True,
        )

    with col2:
        blur_max = st.slider("Lens Blur Radius", 0, 5, 1, key="lens")
        output_image = depth_based_blur(
            image, depth_map_image, max_blur=blur_max, num_bands=40, invert_depth=False
        )
        st.image(
            output_image,
            caption=f"Lens Blur (blur={blur_max})",
            use_container_width=True,
        )
# Start the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()