Spaces:
Sleeping
Sleeping
import numpy as np
import streamlit as st
import torch
from PIL import Image, ImageFilter
from torchvision import transforms
from transformers import (
    AutoModelForImageSegmentation,
    DepthProForDepthEstimation,
    DepthProImageProcessorFast,
)

# Page configuration must run before any other Streamlit element is drawn.
st.set_page_config(page_title="Blur App", page_icon=":camera_with_flash:")

# Use CUDA when it is available, otherwise fall back to the CPU. Per the warning
# shown below, this Space has no GPU, so in practice this resolves to "cpu".
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Page header and short explanation of the two blur styles.
st.title("Blur an image!")
st.subheader("Upload your image and choose a blur style")
st.text(
    "A background blur uses segmentation to completely blur all background objects.\n"
    " A lens blur applies gradual blur using depth estimation."
)
st.divider()
st.warning(
    "**Note**: The lens blur option takes a long time to process (>5min) "
    "since this space isn't linked to a GPU."
)
@st.cache_resource
def load_gblur_model():
    """Load the BiRefNet segmentation model used by the background blur.

    Streamlit re-executes the whole script on every widget interaction.
    ``st.cache_resource`` makes the weights load once per server process
    instead of on every rerun.

    Returns:
        The ``ZhengPeng7/BiRefNet`` model, moved to ``device`` and put in
        eval mode.
    """
    birefnet = AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/BiRefNet", trust_remote_code=True
    )
    birefnet.to(device)
    birefnet.eval()
    return birefnet
@st.cache_resource
def load_lblur_model():
    """Load the DepthPro depth-estimation model and its image processor.

    Cached with ``st.cache_resource`` so the weights load once per server
    process instead of on every Streamlit rerun.

    Returns:
        A ``(model, image_processor)`` tuple. The model is on ``device`` and
        in eval mode.
    """
    image_processor = DepthProImageProcessorFast.from_pretrained("apple/DepthPro-hf")
    model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf").to(device)
    # Explicit for consistency with load_gblur_model; inference only.
    model.eval()
    return model, image_processor
# Instantiate both models at module level. gaussian_blur reads gblur_model,
# and lens_blur reads lblur_model / lblur_img_proc, as globals.
gblur_model = load_gblur_model()
(lblur_model, lblur_img_proc) = load_lblur_model()
def gaussian_blur(image, blur_str):
    """Blur the background of *image* while keeping the segmented subject sharp.

    BiRefNet predicts a per-pixel foreground probability at 512x512. That map
    is resized back to the image size and used as a soft alpha matte to blend
    the original pixels (foreground) over a Gaussian-blurred copy
    (background).

    Args:
        image: RGB ``PIL.Image`` to process.
        blur_str: Radius passed to ``ImageFilter.GaussianBlur`` for the
            background.

    Returns:
        A new RGB ``PIL.Image`` with the same size as *image*.
    """
    # Model input: fixed 512x512 resize plus ImageNet mean/std normalization.
    transform_image = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_image = transform_image(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # Take the last model output and squash it to [0, 1] with a sigmoid.
        preds = gblur_model(input_image)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # ToPILImage scales the [0, 1] float map to a uint8 "L" image. Resize it
    # back to the original resolution, then return to [0, 1] floats with a
    # trailing channel axis so it broadcasts over RGB.
    mask = transforms.ToPILImage()(pred).resize(image.size)
    alpha = np.asarray(mask, dtype=np.float32)[..., np.newaxis] / 255.0

    original = np.asarray(image, dtype=np.float32)
    blurred = np.asarray(
        image.filter(ImageFilter.GaussianBlur(radius=blur_str)), dtype=np.float32
    )
    # Soft compositing instead of a truthiness test on the uint8 mask. Testing
    # "mask != 0" would count any probability >= 1/255 as foreground, leaving
    # faint background regions unblurred and producing hard, jagged edges.
    output = alpha * original + (1.0 - alpha) * blurred
    return Image.fromarray(np.clip(np.rint(output), 0, 255).astype(np.uint8))
def _estimate_normalized_depth(image):
    """Run DepthPro on *image* and return its depth map rescaled to [0, 1]."""
    inputs = lblur_img_proc(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = lblur_model(**inputs)
    post_processed_output = lblur_img_proc.post_process_depth_estimation(
        outputs, target_sizes=[(image.height, image.width)],
    )
    depth = post_processed_output[0]["predicted_depth"]
    depth = depth.detach().cpu().numpy().astype(np.float32)

    d_min = float(depth.min())
    d_max = float(depth.max())
    span = d_max - d_min
    if not np.isfinite(span) or span <= 0.0:
        # A flat or non-finite depth map has no range to spread blur across.
        # Dividing by it would produce NaNs and invalid level indices, so keep
        # every pixel at the sharpest level instead.
        return np.zeros_like(depth)
    # Clip guards against tiny float overshoot outside [0, 1].
    return np.clip((depth - d_min) / span, 0.0, 1.0)


def _blend_blur_levels(image, depth, max_radius, num_levels):
    """Blend pre-blurred copies of *image* per pixel according to *depth* in [0, 1]."""
    source = np.asarray(image)
    levels = []
    for i in range(num_levels):
        radius = (i / (num_levels - 1)) * max_radius
        # Radii this small are visually a no-op; reuse the original pixels.
        if radius < 0.1:
            levels.append(source)
        else:
            levels.append(np.asarray(image.filter(ImageFilter.GaussianBlur(radius))))
    blurred_stack = np.stack(levels, axis=0)  # (num_levels, H, W, C)

    # Linear interpolation between the two blur levels bracketing each pixel.
    depth_levels = depth * (num_levels - 1)
    low = np.floor(depth_levels).astype(np.intp)
    high = np.minimum(low + 1, num_levels - 1)
    alpha = (depth_levels - low)[..., np.newaxis]

    rows, cols = np.indices(depth.shape)
    pixel_low = blurred_stack[low, rows, cols].astype(np.float32)
    pixel_high = blurred_stack[high, rows, cols].astype(np.float32)
    output = (1.0 - alpha) * pixel_low + alpha * pixel_high
    return Image.fromarray(np.clip(output, 0, 255).astype(np.uint8))


def lens_blur(image, blur_str, num_levels=15):
    """Apply a depth-dependent ("lens") blur to *image*.

    DepthPro estimates a depth map at the image's resolution, and the map is
    normalized to [0, 1]. *num_levels* copies of the image are pre-blurred
    with radii growing linearly from 0 to *blur_str*. Each output pixel
    interpolates between the two levels that bracket its normalized depth.
    Pixels with the smallest predicted depth value stay sharp; those with the
    largest receive the full blur.

    Args:
        image: RGB ``PIL.Image`` to process.
        blur_str: Maximum Gaussian blur radius.
        num_levels: Number of discrete blur levels to pre-compute (at least 2).

    Returns:
        A new RGB ``PIL.Image`` with the same size as *image*.

    Raises:
        ValueError: If *num_levels* is less than 2.
    """
    if num_levels < 2:
        raise ValueError(f"num_levels must be at least 2, got {num_levels}")
    depth = _estimate_normalized_depth(image)
    return _blend_blur_levels(image, depth, blur_str, num_levels)
from PIL import ImageOps


def _open_uploaded_image(uploaded):
    """Return *uploaded* as an upright RGB PIL image, or None if Pillow cannot read it."""
    try:
        with Image.open(uploaded) as raw:
            # Cameras (especially phones) usually record rotation in EXIF
            # instead of rotating the pixels; apply it so the preview and the
            # blurred result are shown upright.
            return ImageOps.exif_transpose(raw).convert("RGB")
    except OSError:
        # Covers UnidentifiedImageError and decoder failures, e.g. raw DNG
        # data that Pillow cannot decode.
        return None


# Input row: uploader on the left, blur settings form on the right.
left, right = st.columns(2)
with left:
    st.header("Upload your image")
    up_img = st.file_uploader("Upload image", type=["jpg", "jpeg", "png", "dng", "tiff"])
with right:
    st.header("Set blur settings.")
    st.text("Choose a blur style, set the blur strength, and hit 'Apply'!")
    with st.form("Blur_form"):
        options = ["Background", "Lens"]
        selection = st.radio(
            "Choose a blur type:",
            options,
            index=None,
            captions=[
                "Use segmentation",
                "Use depth estimation",
            ],
            horizontal=True,
        )
        blur_level = st.select_slider(
            "Blur Strength", options=["very low", "low", "medium", "high", "very high"]
        )
        submitted = st.form_submit_button("Apply Blur", disabled=(up_img is None), type="primary")
st.divider()

# Slider label -> blur radius passed to gaussian_blur / lens_blur.
blur_levels = {"very low": 5, "low": 10, "medium": 15, "high": 20, "very high": 30}

if up_img is not None:
    image = _open_uploaded_image(up_img)
    if image is None:
        st.error("Could not read this image file. Please upload a JPEG, PNG or TIFF image.")
    else:
        # Output row: original on the left, blurred result on the right.
        disp_left, disp_right = st.columns(2)
        with disp_left:
            st.image(image, caption="Original Image", width=300)
        with disp_right:
            if submitted and selection in options:
                blur_str = blur_levels[blur_level]
                if selection == "Background":
                    with st.spinner("Spinning violently around the y-axis..."):
                        result = gaussian_blur(image, blur_str)
                else:  # selection == "Lens" (guaranteed by the membership check above)
                    with st.spinner("One mississippi, two mississippi..."):
                        result = lens_blur(image, blur_str)
                st.image(result, caption="Blurred Image", width=300)
            else:
                st.write("Waiting for you to select a blur type...")