# Hugging Face Space source: app.py (author: Saitama0510)
# Last commit: f1c3b2e "Update app.py" (verified); file size 5.19 kB
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from PIL import Image, ImageFilter
from torchvision import transforms
from transformers import AutoModelForImageSegmentation, pipeline
def depth_based_blur(orig_image: Image.Image, depth_map: Image.Image, max_blur: float = 15,
                     num_bands: int = 10, invert_depth: bool = True) -> Image.Image:
    """
    Apply a depth-dependent ("lens") blur to an image using a depth map.

    The depth map is min-max normalized to [0, 1] and that range is split into
    ``num_bands`` equal bands. Pixels in each band are taken from a copy of the
    image blurred with radius ``(1 - band_midpoint) * max_blur``. Pixels with a
    low normalized depth therefore receive the strongest blur.

    Args:
        orig_image: Image to blur (converted to RGBA internally).
        depth_map: Single-channel depth image. If its size differs from
            ``orig_image``, it is resized (bilinear) to match.
        max_blur: Gaussian blur radius approached by the lowest-depth band.
        num_bands: Number of depth bands. Each non-empty band costs one
            full-image Gaussian blur; ``0`` returns the image unblurred.
        invert_depth: If True, use ``1 - normalized_depth``, so pixels with
            high raw depth values get the strongest blur. Whether low or high
            raw values mean "far" depends on the depth model's convention.

    Returns:
        PIL.Image.Image: The blurred image in RGB mode, same size as ``orig_image``.
    """
    orig_rgba = orig_image.convert("RGBA")

    # Image.composite requires the mask to match the image size.
    if depth_map.size != orig_rgba.size:
        depth_map = depth_map.resize(orig_rgba.size, resample=Image.BILINEAR)

    # Min-max normalize to [0, 1]; the epsilon guards against a constant depth map.
    depth_array = np.asarray(depth_map, dtype=np.float32)
    d_min, d_max = depth_array.min(), depth_array.max()
    depth_norm = (depth_array - d_min) / (d_max - d_min + 1e-8)
    if invert_depth:
        depth_norm = 1.0 - depth_norm

    band_edges = np.linspace(0, 1, num_bands + 1)
    # Per-pixel band index k such that band_edges[k] <= depth < band_edges[k + 1].
    # Searching only the interior edges makes the top band closed on the right.
    # Pixels whose depth is exactly 1.0 (e.g. the minimum when invert_depth=True)
    # are therefore blurred, instead of falling outside every band.
    band_index = np.searchsorted(band_edges[1:-1], depth_norm, side="right")

    final_image = orig_rgba.copy()
    for i in range(num_bands):
        in_band = band_index == i
        if not in_band.any():
            continue  # nothing to composite, so skip the costly blur
        # The band midpoint sets the blur strength: lower depth -> more blur.
        mid = (band_edges[i] + band_edges[i + 1]) / 2.0
        blur_radius = (1.0 - mid) * max_blur
        blurred_version = orig_rgba.filter(ImageFilter.GaussianBlur(blur_radius))
        # A 2-D uint8 array becomes an "L" image, usable as a composite mask.
        band_mask_pil = Image.fromarray(in_band.astype(np.uint8) * 255)
        final_image = Image.composite(blurred_version, final_image, band_mask_pil)

    return final_image.convert("RGB")
_DEVICE = "cpu"
# Input resolution fed to the RMBG-2.0 segmentation model.
_SEGMENTATION_SIZE = (512, 512)


@st.cache_resource
def _load_segmentation_model():
    """Load the BRIA RMBG-2.0 segmentation model once and reuse it across reruns.

    Streamlit re-executes ``main()`` on every widget interaction (e.g. each
    slider move). Caching avoids reloading the weights every time.
    """
    torch.set_float32_matmul_precision("high")
    model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-2.0", trust_remote_code=True)
    model.to(_DEVICE)
    model.eval()
    return model


@st.cache_resource
def _load_depth_pipeline():
    """Load the Depth-Anything-V2-Small depth-estimation pipeline once and reuse it."""
    return pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")


def _subject_mask(model, image: Image.Image, threshold: float = 0.5) -> Image.Image:
    """Return an "L" mask the size of ``image`` from the segmentation model.

    The mask is 255 where the model's sigmoid output exceeds ``threshold`` and 0
    elsewhere. Edges become soft after bilinear upscaling to the image size.
    """
    transform_image = transforms.Compose([
        transforms.Resize(_SEGMENTATION_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_images = transform_image(image).unsqueeze(0).to(_DEVICE)
    with torch.no_grad():
        # Take the model's last output and squash it to [0, 1].
        preds = model(input_images)[-1].sigmoid().cpu()
    binary_mask = (preds[0].squeeze() > threshold).float()
    # ToPILImage scales a 0/1 float tensor to a 0/255 "L" image.
    mask_pil = transforms.ToPILImage()(binary_mask).convert("L")
    # Fixes the original NameError: the target size was never defined.
    return mask_pil.resize(image.size, resample=Image.BILINEAR)


def main():
    """Streamlit entry point: compare a mask-based Gaussian blur with a depth-based lens blur.

    The user uploads an image, which is shown next to two variants:
    (1) the image composited over a Gaussian-blurred copy using the
    segmentation mask, and (2) ``depth_based_blur`` driven by a depth map.
    """
    st.title("Custom Background Blur Demo")

    # 1. Upload an image
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # 2. Open and display the original image
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Original Image", use_column_width=True)
    st.write("---")
    st.subheader("Blur Settings")
    col1, col2 = st.columns(2)

    mask_pil = _subject_mask(_load_segmentation_model(), image)

    # Run depth estimation on the full image. Aligning the depth map to the image
    # keeps the lens blur at full resolution with the original aspect ratio.
    depth_map_image = _load_depth_pipeline()(image)["depth"]
    if depth_map_image.size != image.size:
        depth_map_image = depth_map_image.resize(image.size, resample=Image.BILINEAR)

    with col1:
        gauss_radius = st.slider("Gaussian Blur Radius", 0, 30, 5, key="gauss")
        blurred_image = image.filter(ImageFilter.GaussianBlur(gauss_radius))
        # White (255) in mask_pil keeps pixels from `image`;
        # black (0) takes them from `blurred_image`.
        final_image = Image.composite(image, blurred_image, mask_pil)
        st.image(
            final_image,
            caption=f"Gaussian Blur (radius={gauss_radius})",
            use_column_width=True,
        )

    with col2:
        blur_max = st.slider("Lens Blur Radius", 0, 30, 10, key="lens")
        output_image = depth_based_blur(image, depth_map_image, max_blur=blur_max,
                                        num_bands=40, invert_depth=False)
        st.image(
            output_image,
            caption=f"Lens Blur (blur={blur_max})",
            use_column_width=True,
        )


if __name__ == "__main__":
    main()