shanty2301's picture
text changes
73a8aca verified
raw
history blame contribute delete
5.84 kB
import streamlit as st
# Kept directly after the Streamlit import: set_page_config must run before any
# other Streamlit command.
st.set_page_config(page_title="Blur App", page_icon=":camera_with_flash:")
import numpy as np
import torch
from PIL import Image, ImageFilter, ImageOps
from torchvision import transforms
from transformers import AutoModelForImageSegmentation, DepthProImageProcessorFast, DepthProForDepthEstimation
# Use CUDA when it is available. The warning shown below states this Space has
# no GPU, so in that deployment this resolves to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Page header and explanatory text.
st.title("Blur an image!")
st.subheader("Upload your image and choose a blur style")
st.text(
    "A background blur uses segmentation to completely blur all background objects.\n"
    " A lens blur applies gradual blur using depth estimation."
)
st.divider()
st.warning(
    "**Note**: The lens blur option takes a long time to process (>5min) "
    "since this space isn't linked to a GPU."
)
@st.cache_resource(show_spinner="Pushing pixels...")
def load_gblur_model():
    """Load the BiRefNet segmentation model used by the background blur.

    Wrapped in ``st.cache_resource``, so Streamlit reruns reuse the same
    model object instead of loading it again.

    Returns:
        The model, moved to ``device`` and switched to evaluation mode.
    """
    # NOTE: trust_remote_code=True executes Python code shipped in the model repository.
    segmenter = AutoModelForImageSegmentation.from_pretrained(
        'ZhengPeng7/BiRefNet',
        trust_remote_code=True,
    )
    segmenter.to(device)
    segmenter.eval()
    return segmenter
@st.cache_resource(show_spinner="Running with scissors...")
def load_lblur_model():
    """Load the DepthPro depth-estimation model and its image processor.

    Wrapped in ``st.cache_resource``, so Streamlit reruns reuse the cached pair.

    Returns:
        A ``(model, image_processor)`` tuple; the model has been moved to ``device``.
    """
    checkpoint = "apple/DepthPro-hf"
    processor = DepthProImageProcessorFast.from_pretrained(checkpoint)
    depth_model = DepthProForDepthEstimation.from_pretrained(checkpoint).to(device)
    return depth_model, processor
# Build both models up front. The loaders are cached with st.cache_resource,
# so Streamlit's script reruns get the same objects back instead of reloading.
gblur_model = load_gblur_model()                  # segmentation, for "Background"
(lblur_model, lblur_img_proc) = load_lblur_model()  # depth + processor, for "Lens"
def gaussian_blur(image, blur_str):
    """Blur the background of *image* while keeping the segmented foreground sharp.

    BiRefNet (``gblur_model``) predicts a per-pixel foreground probability on
    a 512x512 resize of the input. That map is scaled back to the original
    size and used as an alpha mask to blend the original image over a
    Gaussian-blurred copy.

    Args:
        image: PIL image to blur. The caller passes an RGB image
            (uploads are converted with ``.convert("RGB")``).
        blur_str: Gaussian blur radius, in pixels, applied to the background.

    Returns:
        A new PIL image with the same size and mode as *image*.
    """
    image_size = (512, 512)
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        # The usual ImageNet mean/std normalisation.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    input_image = transform_image(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # The model returns a sequence of predictions; the last one is used.
        preds = gblur_model(input_image)[-1].sigmoid().cpu()

    # Foreground probability in [0, 1] becomes an 8-bit "L" mask at full resolution.
    mask = transforms.ToPILImage()(preds[0].squeeze()).resize(image.size)

    blurred = image.filter(ImageFilter.GaussianBlur(radius=blur_str))
    # Alpha-composite instead of np.where(mask, ...). np.where treated every
    # non-zero mask value (probability >= 1/255) as fully foreground, which kept
    # low-confidence background pixels sharp and hard-cut the mask edges.
    # Where the mask is 255 the original pixel is kept; where it is 0 the blurred one is.
    return Image.composite(image, blurred, mask)
def _estimate_normalized_depth(image):
    """Run DepthPro on *image* and return an (H, W) float32 depth map scaled to [0, 1]."""
    inputs = lblur_img_proc(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = lblur_model(**inputs)
    post_processed = lblur_img_proc.post_process_depth_estimation(
        outputs, target_sizes=[(image.height, image.width)],
    )
    depth = post_processed[0]["predicted_depth"].detach().cpu().numpy().astype(np.float32)

    d_min = float(depth.min())
    span = float(depth.max()) - d_min
    if not np.isfinite(span) or span <= 0:
        # Flat or degenerate (NaN/inf) depth. Plain min-max normalisation would
        # divide by zero here, and the resulting NaNs break the integer level
        # indexing later. Treat the whole frame as being at the sharpest level.
        return np.zeros_like(depth)
    return (depth - d_min) / span


def _blend_blur_levels(image, depth, max_radius, num_levels=15):
    """Blend progressively blurred copies of *image* according to *depth*.

    Level ``i`` (of ``num_levels``, which must be >= 2) uses Gaussian radius
    ``i / (num_levels - 1) * max_radius``. Each pixel linearly interpolates
    between the two levels that bracket ``depth * (num_levels - 1)``.
    """
    original = np.array(image)
    levels = []
    for i in range(num_levels):
        radius = (i / (num_levels - 1)) * max_radius
        # Radii this small make no visible difference; reuse the unblurred pixels.
        if radius < 0.1:
            levels.append(original)
        else:
            levels.append(np.array(image.filter(ImageFilter.GaussianBlur(radius))))
    # (num_levels, H, W, C). This holds num_levels full-resolution copies in memory.
    stack = np.stack(levels, axis=0)

    scaled = depth * (num_levels - 1)
    low = np.floor(scaled).astype(np.intp)
    high = np.clip(low + 1, 0, num_levels - 1)
    alpha = (scaled - low)[..., np.newaxis]

    rows, cols = np.indices(depth.shape)
    pixel_low = stack[low, rows, cols, :].astype(np.float32)
    pixel_high = stack[high, rows, cols, :].astype(np.float32)
    blended = (1.0 - alpha) * pixel_low + alpha * pixel_high
    return Image.fromarray(np.clip(blended, 0, 255).astype(np.uint8))


def lens_blur(image, blur_str):
    """Apply a depth-dependent ("lens") blur to *image*.

    DepthPro estimates a depth map at the image's resolution, which is
    min-max normalised to [0, 1]. Pixels with the smallest predicted depth
    value stay sharp; pixels with the largest get the full *blur_str* radius.
    (DepthPro's ``predicted_depth`` is presumably metric distance, i.e.
    nearer = smaller — confirm against the model docs.) The blur is
    interpolated between 15 pre-blurred levels.

    Args:
        image: RGB PIL image.
        blur_str: Maximum Gaussian blur radius in pixels.

    Returns:
        A new PIL image with the same size as *image*.
    """
    depth = _estimate_normalized_depth(image)
    return _blend_blur_levels(image, depth, max_radius=blur_str)
# Two-column layout: the upload widget on the left, blur settings on the right.
left, right = st.columns(2)
with left:
    st.header("Upload your image")
    up_img = st.file_uploader("Upload image", type=["jpg", "jpeg", "png", "dng", "tiff"])
with right:
    st.header("Set blur settings.")
    st.text("Choose a blur style, set the blur strength, and hit 'Apply'!")
    # Inside a form, changing the widgets does not rerun the app until "Apply Blur" is pressed.
    with st.form("Blur_form"):
        options = ["Background", "Lens"]
        selection = st.radio(
            "Choose a blur type:",
            options,
            index=None,
            captions=["Use segmentation", "Use depth estimation"],
            horizontal=True,
        )
        blur_level = st.select_slider(
            "Blur Strength",
            options=["very low", "low", "medium", "high", "very high"],
        )
        submitted = st.form_submit_button("Apply Blur", disabled=(up_img is None), type="primary")
st.divider()

# Slider label -> blur radius in pixels (the maximum radius for the lens blur).
blur_levels = {"very low": 5, "low": 10, "medium": 15, "high": 20, "very high": 30}

if up_img is not None:
    try:
        # Apply the EXIF Orientation tag so camera photos are shown and processed
        # upright; Image.open alone does not rotate them.
        image = ImageOps.exif_transpose(Image.open(up_img)).convert("RGB")
    except OSError:
        # Pillow's UnidentifiedImageError subclasses OSError, e.g. a raw/DNG file it cannot decode.
        st.error("Could not read this image file. Please upload a JPG, PNG or TIFF image.")
        st.stop()

    disp_left, disp_right = st.columns(2)
    with disp_left:
        st.image(image, caption="Original Image", width=300)
    with disp_right:
        if submitted and selection in options:
            blur_str = blur_levels[blur_level]
            if selection == "Background":
                with st.spinner("Spinning violently around the y-axis..."):
                    result = gaussian_blur(image, blur_str)
            elif selection == "Lens":
                with st.spinner("One mississippi, two mississippi..."):
                    result = lens_blur(image, blur_str)
            st.image(result, caption="Blurred Image", width=300)
        else:
            st.write("Waiting for you to select a blur type...")